27#include "neml2/base/NEML2Object.h"
28#include "neml2/base/Storage.h"
29#include "neml2/tensors/Variable.h"
30#include "neml2/tensors/LabeledVector.h"
31#include "neml2/tensors/LabeledMatrix.h"
32#include "neml2/tensors/LabeledTensor3D.h"
48 template <
typename T = BatchTensor>
55 var_ptr,
"Input variable ", name,
" exist but cannot be cast to the requested type.");
58 template <
typename T = BatchTensor>
61 const auto var_base_ptr = _input_views.query_value(name);
65 var_ptr,
"Input variable ", name,
" exist but cannot be cast to the requested type.");
72 template <
typename T = BatchTensor>
75 return std::as_const(*this).get_output_variable<T>(name);
77 template <
typename T = BatchTensor>
80 const auto var_base_ptr = _output_views.query_value(name);
84 var_ptr,
"Output variable ", name,
" exist but cannot be cast to the requested type.");
157 const torch::TensorOptions & options,
179 template <
typename T,
typename...
S>
182 const auto var_name = variable_name(std::forward<S>(name)...);
188 template <
typename...
S>
191 const auto var_name = variable_name(std::forward<S>(name)...);
197 template <
typename T,
typename...
S>
204 template <
typename T,
typename...
S>
207 const auto var_name = variable_name(std::forward<S>(name)...);
213 template <
typename...
S>
216 const auto var_name = variable_name(std::forward<S>(name)...);
222 template <
typename T,
typename...
S>
229 template <
typename T>
251 template <
typename...
S>
254 using FirstType = std::tuple_element_t<0, std::tuple<
S...>>;
256 if constexpr (
sizeof...(name) == 1 && std::is_convertible_v<FirstType, std::string>)
267 template <
typename T>
268 Variable<T> * create_variable_view(Storage<VariableName, VariableBase> & views,
272 if constexpr (std::is_same_v<T, BatchTensor>)
273 neml_assert(sz > 0,
"Allocating a BatchTensor requires a known storage size.");
276 VariableBase * var_base_ptr = views.query_value(name);
278 "Trying to allocate variable ",
280 ", but a variable with the same name already exists.");
283 if constexpr (std::is_same_v<T, BatchTensor>)
285 auto var = std::make_unique<Variable<BatchTensor>>(name, sz);
286 var_base_ptr = views.set_pointer(name, std::move(var));
290 auto var = std::make_unique<Variable<T>>(name);
291 var_base_ptr = views.set_pointer(name, std::move(var));
295 auto var_ptr =
dynamic_cast<Variable<T> *
>(var_base_ptr);
297 var_ptr,
"Internal error: Failed to cast variable ", name,
" to its concrete type.");
302 NEML2Object * _object;
309 const OptionSet _options;
312 Storage<std::string, LabeledAxis> _axes;
315 Storage<VariableName, VariableBase> _input_views;
318 Storage<VariableName, VariableBase> _output_views;
321 LabeledAxis & _input_axis;
324 LabeledAxis & _output_axis;
333 LabeledMatrix _dout_din;
336 LabeledTensor3D _d2out_din2;
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:44
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
LabeledAxis & add(const LabeledAxisAccessor &accessor)
Add a variable or subaxis.
Definition LabeledAxis.h:67
A single-batched, logically 2D LabeledTensor.
Definition LabeledMatrix.h:38
A single-batched, logically 3D LabeledTensor.
Definition LabeledTensor3D.h:38
A single-batched, logically 1D LabeledTensor.
Definition LabeledVector.h:38
The base class of all "manufacturable" objects in the NEML2 library.
Definition NEML2Object.h:38
A custom map-like data structure. The keys are strings, and the values can be of heterogeneous types.
Definition OptionSet.h:59
Definition VariableStore.h:37
virtual void setup_input_views()
Tell each input variable view which tensor storage(s) to view into.
Definition VariableStore.cxx:107
const Variable< T > & get_input_variable(const VariableName &name) const
Definition VariableStore.h:59
const Variable< T > & get_output_variable(const VariableName &name) const
Definition VariableStore.h:78
LabeledMatrix & derivative_storage()
Definition VariableStore.h:127
const LabeledVector & output_storage() const
Definition VariableStore.h:122
const Storage< VariableName, VariableBase > & output_views() const
Definition VariableStore.h:110
const Variable< T > & get_output_variable(const VariableName &name)
Definition VariableStore.h:73
Storage< VariableName, VariableBase > & output_views()
Definition VariableStore.h:109
Variable< T > & declare_output_variable(S &&... name)
Declare an output variable.
Definition VariableStore.h:205
LabeledVector & output_storage()
Definition VariableStore.h:121
const LabeledTensor3D & second_derivative_storage() const
Definition VariableStore.h:134
VariableName declare_variable(LabeledAxis &axis, const VariableName &var, TorchSize sz) const
Declare an item (with known storage size) recursively on an axis.
Definition VariableStore.h:236
Variable< BatchTensor > & declare_output_variable(TorchSize sz, S &&... name)
Declare an output variable (with unknown base shape at compile time)
Definition VariableStore.h:214
const Variable< T > & declare_input_variable(S &&... name)
Declare an input variable.
Definition VariableStore.h:180
const Variable< BatchTensor > & declare_input_variable_list(TorchSize list_size, S &&... name)
Declare an input variable that is a list of tensors of fixed size.
Definition VariableStore.h:198
virtual void reinit_output_views(bool out, bool dout_din=true, bool d2out_din2=true)
Create the views for output variables, and optionally for the derivative and second derivatives.
Definition VariableStore.cxx:128
LabeledAxis & output_axis()
Definition VariableStore.h:97
VariableBase * input_view(const VariableName &)
Get the view of an input variable.
Definition VariableStore.cxx:57
const Variable< BatchTensor > & declare_input_variable(TorchSize sz, S &&... name)
Declare an input variable (with unknown base shape at compile time)
Definition VariableStore.h:189
Variable< BatchTensor > & declare_output_variable_list(TorchSize list_size, S &&... name)
Declare an output variable that is a list of tensors of fixed size.
Definition VariableStore.h:223
virtual void setup_output_views()
Tell each output variable view which tensor storage(s) to view into.
Definition VariableStore.cxx:114
VariableName declare_variable(LabeledAxis &axis, const VariableName &var) const
Declare an item recursively on an axis.
Definition VariableStore.h:230
const Storage< VariableName, VariableBase > & input_views() const
Definition VariableStore.h:104
virtual void allocate_variables(TorchShapeRef batch_shape, const torch::TensorOptions &options, bool in, bool out, bool dout_din, bool d2out_din2)
Allocate variable storages given the batch shape and tensor options.
Definition VariableStore.cxx:78
virtual void detach_and_zero(bool out, bool dout_din=true, bool d2out_din2=true)
Detach the tensor storages and set each element in the tensor to 0.
Definition VariableStore.cxx:135
Storage< VariableName, VariableBase > & input_views()
Definition VariableStore.h:103
const LabeledAxis & input_axis() const
Definition VariableStore.h:92
const LabeledMatrix & derivative_storage() const
Definition VariableStore.h:128
virtual void setup_layout()
Setup the layouts of all the registered axes.
Definition VariableStore.cxx:50
LabeledAxis & declare_axis(const std::string &name)
Definition VariableStore.cxx:38
const LabeledVector & input_storage() const
Definition VariableStore.h:116
VariableBase * output_view(const VariableName &)
Get the view of an output variable.
Definition VariableStore.cxx:63
VariableName declare_subaxis(LabeledAxis &axis, const VariableName &subaxis) const
Declare a subaxis recursively on an axis.
Definition VariableStore.h:243
const LabeledAxis & output_axis() const
Definition VariableStore.h:98
LabeledVector & input_storage()
Definition VariableStore.h:115
virtual void cache(TorchShapeRef batch_shape)
Cache the variable's batch shape.
Definition VariableStore.cxx:69
LabeledTensor3D & second_derivative_storage()
Definition VariableStore.h:133
LabeledAxis & input_axis()
Definition VariableStore.h:91
virtual void reinit_input_views()
Create the views for input variables.
Definition VariableStore.cxx:121
Variable< T > & get_input_variable(const VariableName &name)
Definition VariableStore.h:49
VariableStore(const OptionSet &options, NEML2Object *object)
Definition VariableStore.cxx:29
Definition CrossRef.cxx:32
int64_t TorchSize
Definition types.h:35
torch::IntArrayRef TorchShapeRef
Definition types.h:37
LabeledAxisAccessor VariableName
Definition Variable.h:35
void neml_assert(bool assertion, Args &&... args)
Definition error.h:73