NEML2 1.4.0
Loading...
Searching...
No Matches
VariableStore.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/base/NEML2Object.h"
28#include "neml2/base/Storage.h"
29#include "neml2/tensors/Variable.h"
30#include "neml2/tensors/LabeledVector.h"
31#include "neml2/tensors/LabeledMatrix.h"
32#include "neml2/tensors/LabeledTensor3D.h"
33
34namespace neml2
35{
36// Foward decl
37class Model;
38
40{
41public:
42 VariableStore(const OptionSet & options, Model * object);
43
44 LabeledAxis & declare_axis(const std::string & name);
45
47 virtual void setup_layout();
48
51 template <typename T = Tensor>
53 {
54 auto var_base_ptr = _input_views.query_value(name);
55 neml_assert(var_base_ptr, "Input variable ", name, " does not exist.");
56 auto var_ptr = dynamic_cast<Variable<T> *>(var_base_ptr);
58 var_ptr, "Input variable ", name, " exist but cannot be cast to the requested type.");
59 return *var_ptr;
60 }
61 template <typename T = Tensor>
62 const Variable<T> & get_input_variable(const VariableName & name) const
63 {
64 const auto var_base_ptr = _input_views.query_value(name);
65 neml_assert(var_base_ptr, "Input variable ", name, " does not exist.");
66 const auto var_ptr = dynamic_cast<const Variable<T> *>(var_base_ptr);
68 var_ptr, "Input variable ", name, " exist but cannot be cast to the requested type.");
69 return *var_ptr;
70 }
72
75 template <typename T = Tensor>
77 {
78 return std::as_const(*this).get_output_variable<T>(name);
79 }
80 template <typename T = Tensor>
81 const Variable<T> & get_output_variable(const VariableName & name) const
82 {
83 const auto var_base_ptr = _output_views.query_value(name);
84 neml_assert(var_base_ptr, "Output variable ", name, " does not exist.");
85 const auto var_ptr = dynamic_cast<const Variable<T> *>(var_base_ptr);
87 var_ptr, "Output variable ", name, " exist but cannot be cast to the requested type.");
88 return *var_ptr;
89 }
91
94 LabeledAxis & input_axis() { return _input_axis; }
95 const LabeledAxis & input_axis() const { return _input_axis; }
97
100 LabeledAxis & output_axis() { return _output_axis; }
101 const LabeledAxis & output_axis() const { return _output_axis; }
103
107 const Storage<VariableName, VariableBase> & input_variables() const { return _input_views; }
109
113 const Storage<VariableName, VariableBase> & output_variables() const { return _output_views; }
115
118 LabeledVector & input_storage() { return _in; }
119 const LabeledVector & input_storage() const { return _in; }
121
124 LabeledVector & output_storage() { return _out; }
125 const LabeledVector & output_storage() const { return _out; }
127
130 LabeledMatrix & derivative_storage() { return _dout_din; }
131 const LabeledMatrix & derivative_storage() const { return _dout_din; }
133
136 LabeledTensor3D & second_derivative_storage() { return _d2out_din2; }
137 const LabeledTensor3D & second_derivative_storage() const { return _d2out_din2; }
139
144
146 TensorType input_type(const VariableName &) const;
148 TensorType output_type(const VariableName &) const;
149
150protected:
152 virtual void cache(TensorShapeRef batch_shape);
153
165 const torch::TensorOptions & options,
166 bool in,
167 bool out,
168 bool dout_din,
169 bool d2out_din2);
170
172 virtual void setup_input_views(VariableStore * host = nullptr);
173
175 virtual void setup_output_views(bool out, bool dout_din, bool d2out_din2);
176
178 virtual void zero(bool dout_din, bool d2out_din2);
179
181 template <typename T, typename... S>
183 {
184 const auto var_name = variable_name(std::forward<S>(name)...);
185 declare_variable<T>(_input_axis, var_name);
186 return *create_variable_view<T>(_input_views, var_name);
187 }
188
190 template <typename... S>
192 {
193 const auto var_name = variable_name(std::forward<S>(name)...);
194 declare_variable(_input_axis, var_name, sz);
195 return *create_variable_view<Tensor>(_input_views, var_name, t, sz);
196 }
197
199 template <typename T, typename... S>
201 {
203 list_size * T::const_base_storage, TensorType::kTensor, std::forward<S>(name)...);
204 }
205
207 template <typename T, typename... S>
209 {
210 const auto var_name = variable_name(std::forward<S>(name)...);
211 declare_variable<T>(_output_axis, var_name);
212 return *create_variable_view<T>(_output_views, var_name);
213 }
214
216 template <typename... S>
218 {
219 const auto var_name = variable_name(std::forward<S>(name)...);
220 declare_variable(_output_axis, var_name, sz);
221 return *create_variable_view<Tensor>(_output_views, var_name, t, sz);
222 }
223
225 template <typename T, typename... S>
227 {
229 list_size * T::const_base_storage, TensorType::kTensor, std::forward<S>(name)...);
230 }
231
233 template <typename T>
235 {
236 return declare_variable(axis, var, T::const_base_storage);
237 }
238
241 {
242 axis.add(var, sz);
243 return var;
244 }
245
248 {
249 axis.add<LabeledAxis>(subaxis);
250 return subaxis;
251 }
252
253private:
254 // Helper method to construct variable name in place
255 template <typename... S>
256 VariableName variable_name(S &&... name) const
257 {
258 using FirstType = std::tuple_element_t<0, std::tuple<S...>>;
259
260 if constexpr (sizeof...(name) == 1 && std::is_convertible_v<FirstType, std::string>)
261 {
262 if (_options.contains<VariableName>(name...))
263 return _options.get<VariableName>(name...);
264 return VariableName(std::forward<S>(name)...);
265 }
266 else
267 return VariableName(std::forward<S>(name)...);
268 }
269
270 // Create a variable view (doesn't setup the view)
271 template <typename T>
272 Variable<T> * create_variable_view(Storage<VariableName, VariableBase> & views,
273 const VariableName & name,
275 Size sz = -1)
276 {
277 if constexpr (std::is_same_v<T, Tensor>)
278 neml_assert(sz > 0, "Allocating a Tensor requires a known storage size.");
279
280 // Make sure we don't duplicate variable allocation
281 VariableBase * var_base_ptr = views.query_value(name);
282 neml_assert(!var_base_ptr,
283 "Trying to allocate variable ",
284 name,
285 ", but a variable with the same name already exists.");
286
287 // Allocate
288 if constexpr (std::is_same_v<T, Tensor>)
289 {
290 auto var = std::make_unique<Variable<Tensor>>(name, _object, sz, t);
291 var_base_ptr = views.set_pointer(name, std::move(var));
292 }
293 else
294 {
295 (void)t;
296 auto var = std::make_unique<Variable<T>>(name, _object);
297 var_base_ptr = views.set_pointer(name, std::move(var));
298 }
299
300 // Cast it to the concrete type
301 auto var_ptr = dynamic_cast<Variable<T> *>(var_base_ptr);
303 var_ptr, "Internal error: Failed to cast variable ", name, " to its concrete type.");
304
305 return var_ptr;
306 }
307
308 Model * _object;
309
315 const OptionSet _options;
316
318 Storage<std::string, LabeledAxis> _axes;
319
321 Storage<VariableName, VariableBase> _input_views;
322
324 Storage<VariableName, VariableBase> _output_views;
325
327 LabeledAxis & _input_axis;
328
330 LabeledAxis & _output_axis;
331
333 LabeledVector _in;
334
336 LabeledVector _out;
337
339 LabeledMatrix _dout_din;
340
342 LabeledTensor3D _d2out_din2;
343};
344} // namespace neml2
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:56
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:47
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
LabeledAxis & add(const LabeledAxisAccessor &accessor)
Add a variable or subaxis.
Definition LabeledAxis.h:79
A single-batched, logically 2D LabeledTensor.
Definition LabeledMatrix.h:38
A single-batched, logically 3D LabeledTensor.
Definition LabeledTensor3D.h:38
A single-batched, logically 1D LabeledTensor.
Definition LabeledVector.h:38
The base class for all constitutive models.
Definition Model.h:55
A custom map-like data structure. The keys are strings, and the values can be nonhomogeneously typed.
Definition OptionSet.h:100
Definition Variable.h:42
Definition VariableStore.h:40
virtual void setup_input_views(VariableStore *host=nullptr)
Tell each input variable view which tensor storage(s) to view into.
Definition VariableStore.cxx:119
const Variable< T > & get_input_variable(const VariableName &name) const
Definition VariableStore.h:62
const Variable< T > & get_output_variable(const VariableName &name) const
Definition VariableStore.h:81
LabeledMatrix & derivative_storage()
Definition VariableStore.h:130
const LabeledVector & output_storage() const
Definition VariableStore.h:125
const Variable< T > & get_output_variable(const VariableName &name)
Definition VariableStore.h:76
virtual void setup_output_views(bool out, bool dout_din, bool d2out_din2)
Tell each output variable view which tensor storage(s) to view into.
Definition VariableStore.cxx:134
Variable< T > & declare_output_variable(S &&... name)
Declare an output variable.
Definition VariableStore.h:208
LabeledVector & output_storage()
Definition VariableStore.h:124
const LabeledTensor3D & second_derivative_storage() const
Definition VariableStore.h:137
const Variable< T > & declare_input_variable(S &&... name)
Declare an input variable.
Definition VariableStore.h:182
const Variable< Tensor > & declare_input_variable_list(Size list_size, S &&... name)
Declare an input variable that is a list of tensors of fixed size.
Definition VariableStore.h:200
LabeledAxis & output_axis()
Definition VariableStore.h:100
Variable< Tensor > & declare_output_variable_list(Size list_size, S &&... name)
Declare an output variable that is a list of tensors of fixed size.
Definition VariableStore.h:226
Storage< VariableName, VariableBase > & input_variables()
Definition VariableStore.h:106
Storage< VariableName, VariableBase > & output_variables()
Definition VariableStore.h:112
const Storage< VariableName, VariableBase > & input_variables() const
Definition VariableStore.h:107
const Storage< VariableName, VariableBase > & output_variables() const
Definition VariableStore.h:113
const Variable< Tensor > & declare_input_variable(Size sz, TensorType t, S &&... name)
Declare an input variable (with unknown base shape at compile time)
Definition VariableStore.h:191
virtual void allocate_variables(TensorShapeRef batch_shape, const torch::TensorOptions &options, bool in, bool out, bool dout_din, bool d2out_din2)
Allocate variable storages given the batch shape and tensor options.
Definition VariableStore.cxx:95
VariableName declare_variable(LabeledAxis &axis, const VariableName &var, Size sz) const
Declare an item (with known storage size) recursively on an axis.
Definition VariableStore.h:240
VariableName declare_variable(LabeledAxis &axis, const VariableName &var) const
Declare an item recursively on an axis.
Definition VariableStore.h:234
Variable< Tensor > & declare_output_variable(Size sz, TensorType t, S &&... name)
Declare an output variable (with unknown base shape at compile time)
Definition VariableStore.h:217
const LabeledAxis & input_axis() const
Definition VariableStore.h:95
const LabeledMatrix & derivative_storage() const
Definition VariableStore.h:131
virtual void setup_layout()
Setup the layouts of all the registered axes.
Definition VariableStore.cxx:51
VariableBase * input_variable(const VariableName &)
Get the view of an input variable.
Definition VariableStore.cxx:58
LabeledAxis & declare_axis(const std::string &name)
Definition VariableStore.cxx:39
const LabeledVector & input_storage() const
Definition VariableStore.h:119
VariableName declare_subaxis(LabeledAxis &axis, const VariableName &subaxis) const
Declare a subaxis recursively on an axis.
Definition VariableStore.h:247
virtual void zero(bool dout_din, bool d2out_din2)
Zero out derivative and second derivative storage.
Definition VariableStore.cxx:143
TensorType output_type(const VariableName &) const
Get the variable type of an output variable.
Definition VariableStore.cxx:78
const LabeledAxis & output_axis() const
Definition VariableStore.h:101
LabeledVector & input_storage()
Definition VariableStore.h:118
LabeledTensor3D & second_derivative_storage()
Definition VariableStore.h:136
LabeledAxis & input_axis()
Definition VariableStore.h:94
TensorType input_type(const VariableName &) const
Get the variable type of an input variable.
Definition VariableStore.cxx:70
VariableStore(const OptionSet &options, Model *object)
Definition VariableStore.cxx:30
virtual void cache(TensorShapeRef batch_shape)
Cache the variable's batch shape.
Definition VariableStore.cxx:86
VariableBase * output_variable(const VariableName &)
Get the view of an output variable.
Definition VariableStore.cxx:64
Variable< T > & get_input_variable(const VariableName &name)
Definition VariableStore.h:52
Definition CrossRef.cxx:30
LabeledAxisAccessor VariableName
Definition parser_utils.h:33
int64_t Size
Definition types.h:33
TensorType
Definition tensors.h:57
torch::IntArrayRef TensorShapeRef
Definition types.h:35
void neml_assert(bool assertion, Args &&... args)
Definition error.h:64
static constexpr TensorType value
Definition tensors.h:65