NEML2 1.4.0
Loading...
Searching...
No Matches
VariableStore.h
1// Copyright 2023, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/base/NEML2Object.h"
28#include "neml2/base/Storage.h"
29#include "neml2/tensors/Variable.h"
30#include "neml2/tensors/LabeledVector.h"
31#include "neml2/tensors/LabeledMatrix.h"
32#include "neml2/tensors/LabeledTensor3D.h"
33
34namespace neml2
35{
37{
38public:
39 VariableStore(const OptionSet & options, NEML2Object * object);
40
41 LabeledAxis & declare_axis(const std::string & name);
42
44 virtual void setup_layout();
45
48 template <typename T = BatchTensor>
50 {
51 auto var_base_ptr = _input_views.query_value(name);
52 neml_assert(var_base_ptr, "Input variable ", name, " does not exist.");
53 auto var_ptr = dynamic_cast<Variable<T> *>(var_base_ptr);
55 var_ptr, "Input variable ", name, " exist but cannot be cast to the requested type.");
56 return *var_ptr;
57 }
58 template <typename T = BatchTensor>
59 const Variable<T> & get_input_variable(const VariableName & name) const
60 {
61 const auto var_base_ptr = _input_views.query_value(name);
62 neml_assert(var_base_ptr, "Input variable ", name, " does not exist.");
63 const auto var_ptr = dynamic_cast<const Variable<T> *>(var_base_ptr);
65 var_ptr, "Input variable ", name, " exist but cannot be cast to the requested type.");
66 return *var_ptr;
67 }
69
72 template <typename T = BatchTensor>
74 {
75 return std::as_const(*this).get_output_variable<T>(name);
76 }
77 template <typename T = BatchTensor>
78 const Variable<T> & get_output_variable(const VariableName & name) const
79 {
80 const auto var_base_ptr = _output_views.query_value(name);
81 neml_assert(var_base_ptr, "Output variable ", name, " does not exist.");
82 const auto var_ptr = dynamic_cast<const Variable<T> *>(var_base_ptr);
84 var_ptr, "Output variable ", name, " exist but cannot be cast to the requested type.");
85 return *var_ptr;
86 }
88
91 LabeledAxis & input_axis() { return _input_axis; }
92 const LabeledAxis & input_axis() const { return _input_axis; }
94
97 LabeledAxis & output_axis() { return _output_axis; }
98 const LabeledAxis & output_axis() const { return _output_axis; }
100
104 const Storage<VariableName, VariableBase> & input_views() const { return _input_views; }
106
110 const Storage<VariableName, VariableBase> & output_views() const { return _output_views; }
112
115 LabeledVector & input_storage() { return _in; }
116 const LabeledVector & input_storage() const { return _in; }
118
121 LabeledVector & output_storage() { return _out; }
122 const LabeledVector & output_storage() const { return _out; }
124
127 LabeledMatrix & derivative_storage() { return _dout_din; }
128 const LabeledMatrix & derivative_storage() const { return _dout_din; }
130
133 LabeledTensor3D & second_derivative_storage() { return _d2out_din2; }
134 const LabeledTensor3D & second_derivative_storage() const { return _d2out_din2; }
136
141
142protected:
144 virtual void cache(TorchShapeRef batch_shape);
145
157 const torch::TensorOptions & options,
158 bool in,
159 bool out,
160 bool dout_din,
161 bool d2out_din2);
162
164 virtual void setup_input_views();
165
167 virtual void setup_output_views();
168
170 virtual void reinit_input_views();
171
173 virtual void reinit_output_views(bool out, bool dout_din = true, bool d2out_din2 = true);
174
176 virtual void detach_and_zero(bool out, bool dout_din = true, bool d2out_din2 = true);
177
179 template <typename T, typename... S>
181 {
182 const auto var_name = variable_name(std::forward<S>(name)...);
183 declare_variable<T>(_input_axis, var_name);
184 return *create_variable_view<T>(_input_views, var_name);
185 }
186
188 template <typename... S>
190 {
191 const auto var_name = variable_name(std::forward<S>(name)...);
192 declare_variable(_input_axis, var_name, sz);
193 return *create_variable_view<BatchTensor>(_input_views, var_name, sz);
194 }
195
197 template <typename T, typename... S>
199 {
200 return declare_input_variable(list_size * T::const_base_storage, std::forward<S>(name)...);
201 }
202
204 template <typename T, typename... S>
206 {
207 const auto var_name = variable_name(std::forward<S>(name)...);
208 declare_variable<T>(_output_axis, var_name);
209 return *create_variable_view<T>(_output_views, var_name);
210 }
211
213 template <typename... S>
215 {
216 const auto var_name = variable_name(std::forward<S>(name)...);
217 declare_variable(_output_axis, var_name, sz);
218 return *create_variable_view<BatchTensor>(_output_views, var_name, sz);
219 }
220
222 template <typename T, typename... S>
224 {
225 return declare_output_variable(list_size * T::const_base_storage, std::forward<S>(name)...);
226 }
227
229 template <typename T>
231 {
232 return declare_variable(axis, var, T::const_base_storage);
233 }
234
237 {
238 axis.add(var, sz);
239 return var;
240 }
241
244 {
245 axis.add<LabeledAxis>(subaxis);
246 return subaxis;
247 }
248
249private:
250 // Helper method to construct variable name in place
251 template <typename... S>
252 VariableName variable_name(S &&... name) const
253 {
254 using FirstType = std::tuple_element_t<0, std::tuple<S...>>;
255
256 if constexpr (sizeof...(name) == 1 && std::is_convertible_v<FirstType, std::string>)
257 {
258 if (_options.contains<VariableName>(name...))
259 return _options.get<VariableName>(name...);
260 return VariableName(std::forward<S>(name)...);
261 }
262 else
263 return VariableName(std::forward<S>(name)...);
264 }
265
266 // Create a variable view (doesn't setup the view)
267 template <typename T>
268 Variable<T> * create_variable_view(Storage<VariableName, VariableBase> & views,
269 const VariableName & name,
270 TorchSize sz = -1)
271 {
272 if constexpr (std::is_same_v<T, BatchTensor>)
273 neml_assert(sz > 0, "Allocating a BatchTensor requires a known storage size.");
274
275 // Make sure we don't duplicate variable allocation
276 VariableBase * var_base_ptr = views.query_value(name);
277 neml_assert(!var_base_ptr,
278 "Trying to allocate variable ",
279 name,
280 ", but a variable with the same name already exists.");
281
282 // Allocate
283 if constexpr (std::is_same_v<T, BatchTensor>)
284 {
285 auto var = std::make_unique<Variable<BatchTensor>>(name, sz);
286 var_base_ptr = views.set_pointer(name, std::move(var));
287 }
288 else
289 {
290 auto var = std::make_unique<Variable<T>>(name);
291 var_base_ptr = views.set_pointer(name, std::move(var));
292 }
293
294 // Cast it to the concrete type
295 auto var_ptr = dynamic_cast<Variable<T> *>(var_base_ptr);
297 var_ptr, "Internal error: Failed to cast variable ", name, " to its concrete type.");
298
299 return var_ptr;
300 }
301
302 NEML2Object * _object;
303
309 const OptionSet _options;
310
312 Storage<std::string, LabeledAxis> _axes;
313
315 Storage<VariableName, VariableBase> _input_views;
316
318 Storage<VariableName, VariableBase> _output_views;
319
321 LabeledAxis & _input_axis;
322
324 LabeledAxis & _output_axis;
325
327 LabeledVector _in;
328
330 LabeledVector _out;
331
333 LabeledMatrix _dout_din;
334
336 LabeledTensor3D _d2out_din2;
337};
338} // namespace neml2
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:44
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
LabeledAxis & add(const LabeledAxisAccessor &accessor)
Add a variable or subaxis.
Definition LabeledAxis.h:67
A single-batched, logically 2D LabeledTensor.
Definition LabeledMatrix.h:38
A single-batched, logically 3D LabeledTensor.
Definition LabeledTensor3D.h:38
A single-batched, logically 1D LabeledTensor.
Definition LabeledVector.h:38
The base class of all "manufacturable" objects in the NEML2 library.
Definition NEML2Object.h:38
A custom map-like data structure. The keys are strings, and the values can be heterogeneously typed.
Definition OptionSet.h:59
Definition Variable.h:41
Definition VariableStore.h:37
virtual void setup_input_views()
Tell each input variable view which tensor storage(s) to view into.
Definition VariableStore.cxx:107
const Variable< T > & get_input_variable(const VariableName &name) const
Definition VariableStore.h:59
const Variable< T > & get_output_variable(const VariableName &name) const
Definition VariableStore.h:78
LabeledMatrix & derivative_storage()
Definition VariableStore.h:127
const LabeledVector & output_storage() const
Definition VariableStore.h:122
const Storage< VariableName, VariableBase > & output_views() const
Definition VariableStore.h:110
const Variable< T > & get_output_variable(const VariableName &name)
Definition VariableStore.h:73
Storage< VariableName, VariableBase > & output_views()
Definition VariableStore.h:109
Variable< T > & declare_output_variable(S &&... name)
Declare an output variable.
Definition VariableStore.h:205
LabeledVector & output_storage()
Definition VariableStore.h:121
const LabeledTensor3D & second_derivative_storage() const
Definition VariableStore.h:134
VariableName declare_variable(LabeledAxis &axis, const VariableName &var, TorchSize sz) const
Declare an item (with known storage size) recursively on an axis.
Definition VariableStore.h:236
Variable< BatchTensor > & declare_output_variable(TorchSize sz, S &&... name)
Declare an output variable (with unknown base shape at compile time)
Definition VariableStore.h:214
const Variable< T > & declare_input_variable(S &&... name)
Declare an input variable.
Definition VariableStore.h:180
const Variable< BatchTensor > & declare_input_variable_list(TorchSize list_size, S &&... name)
Declare an input variable that is a list of tensors of fixed size.
Definition VariableStore.h:198
virtual void reinit_output_views(bool out, bool dout_din=true, bool d2out_din2=true)
Create the views for output variables, and optionally for the first and second derivatives.
Definition VariableStore.cxx:128
LabeledAxis & output_axis()
Definition VariableStore.h:97
VariableBase * input_view(const VariableName &)
Get the view of an input variable.
Definition VariableStore.cxx:57
const Variable< BatchTensor > & declare_input_variable(TorchSize sz, S &&... name)
Declare an input variable (with unknown base shape at compile time)
Definition VariableStore.h:189
Variable< BatchTensor > & declare_output_variable_list(TorchSize list_size, S &&... name)
Declare an output variable that is a list of tensors of fixed size.
Definition VariableStore.h:223
virtual void setup_output_views()
Tell each output variable view which tensor storage(s) to view into.
Definition VariableStore.cxx:114
VariableName declare_variable(LabeledAxis &axis, const VariableName &var) const
Declare an item recursively on an axis.
Definition VariableStore.h:230
const Storage< VariableName, VariableBase > & input_views() const
Definition VariableStore.h:104
virtual void allocate_variables(TorchShapeRef batch_shape, const torch::TensorOptions &options, bool in, bool out, bool dout_din, bool d2out_din2)
Allocate variable storages given the batch shape and tensor options.
Definition VariableStore.cxx:78
virtual void detach_and_zero(bool out, bool dout_din=true, bool d2out_din2=true)
Detach the tensor storages and set each element in the tensor to 0.
Definition VariableStore.cxx:135
Storage< VariableName, VariableBase > & input_views()
Definition VariableStore.h:103
const LabeledAxis & input_axis() const
Definition VariableStore.h:92
const LabeledMatrix & derivative_storage() const
Definition VariableStore.h:128
virtual void setup_layout()
Setup the layouts of all the registered axes.
Definition VariableStore.cxx:50
LabeledAxis & declare_axis(const std::string &name)
Definition VariableStore.cxx:38
const LabeledVector & input_storage() const
Definition VariableStore.h:116
VariableBase * output_view(const VariableName &)
Get the view of an output variable.
Definition VariableStore.cxx:63
VariableName declare_subaxis(LabeledAxis &axis, const VariableName &subaxis) const
Declare a subaxis recursively on an axis.
Definition VariableStore.h:243
const LabeledAxis & output_axis() const
Definition VariableStore.h:98
LabeledVector & input_storage()
Definition VariableStore.h:115
virtual void cache(TorchShapeRef batch_shape)
Cache the variable's batch shape.
Definition VariableStore.cxx:69
LabeledTensor3D & second_derivative_storage()
Definition VariableStore.h:133
LabeledAxis & input_axis()
Definition VariableStore.h:91
virtual void reinit_input_views()
Create the views for input variables.
Definition VariableStore.cxx:121
Variable< T > & get_input_variable(const VariableName &name)
Definition VariableStore.h:49
VariableStore(const OptionSet &options, NEML2Object *object)
Definition VariableStore.cxx:29
Definition CrossRef.cxx:32
int64_t TorchSize
Definition types.h:35
torch::IntArrayRef TorchShapeRef
Definition types.h:37
LabeledAxisAccessor VariableName
Definition Variable.h:35
void neml_assert(bool assertion, Args &&... args)
Definition error.h:73