27#include "neml2/base/NEML2Object.h"
28#include "neml2/base/Storage.h"
29#include "neml2/tensors/Variable.h"
30#include "neml2/tensors/LabeledVector.h"
31#include "neml2/tensors/LabeledMatrix.h"
32#include "neml2/tensors/LabeledTensor3D.h"
51 template <
typename T = Tensor>
58 var_ptr,
"Input variable ", name,
" exist but cannot be cast to the requested type.");
61 template <
typename T = Tensor>
64 const auto var_base_ptr = _input_views.query_value(name);
68 var_ptr,
"Input variable ", name,
" exist but cannot be cast to the requested type.");
75 template <
typename T = Tensor>
78 return std::as_const(*this).get_output_variable<T>(name);
80 template <
typename T = Tensor>
83 const auto var_base_ptr = _output_views.query_value(name);
87 var_ptr,
"Output variable ", name,
" exist but cannot be cast to the requested type.");
165 const torch::TensorOptions & options,
181 template <
typename T,
typename...
S>
184 const auto var_name = variable_name(std::forward<S>(name)...);
190 template <
typename...
S>
193 const auto var_name = variable_name(std::forward<S>(name)...);
199 template <
typename T,
typename...
S>
203 list_size * T::const_base_storage, TensorType::kTensor, std::forward<S>(name)...);
207 template <
typename T,
typename...
S>
210 const auto var_name = variable_name(std::forward<S>(name)...);
216 template <
typename...
S>
219 const auto var_name = variable_name(std::forward<S>(name)...);
225 template <
typename T,
typename...
S>
229 list_size * T::const_base_storage, TensorType::kTensor, std::forward<S>(name)...);
233 template <
typename T>
255 template <
typename...
S>
258 using FirstType = std::tuple_element_t<0, std::tuple<
S...>>;
260 if constexpr (
sizeof...(name) == 1 && std::is_convertible_v<FirstType, std::string>)
271 template <
typename T>
272 Variable<T> * create_variable_view(Storage<VariableName, VariableBase> & views,
277 if constexpr (std::is_same_v<T, Tensor>)
278 neml_assert(sz > 0,
"Allocating a Tensor requires a known storage size.");
281 VariableBase * var_base_ptr = views.query_value(name);
283 "Trying to allocate variable ",
285 ", but a variable with the same name already exists.");
288 if constexpr (std::is_same_v<T, Tensor>)
290 auto var = std::make_unique<Variable<Tensor>>(name, _object, sz, t);
291 var_base_ptr = views.set_pointer(name, std::move(var));
296 auto var = std::make_unique<Variable<T>>(name, _object);
297 var_base_ptr = views.set_pointer(name, std::move(var));
301 auto var_ptr =
dynamic_cast<Variable<T> *
>(var_base_ptr);
303 var_ptr,
"Internal error: Failed to cast variable ", name,
" to its concrete type.");
315 const OptionSet _options;
318 Storage<std::string, LabeledAxis> _axes;
321 Storage<VariableName, VariableBase> _input_views;
324 Storage<VariableName, VariableBase> _output_views;
327 LabeledAxis & _input_axis;
330 LabeledAxis & _output_axis;
339 LabeledMatrix _dout_din;
342 LabeledTensor3D _d2out_din2;
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:56
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:47
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
LabeledAxis & add(const LabeledAxisAccessor &accessor)
Add a variable or subaxis.
Definition LabeledAxis.h:79
A single-batched, logically 2D LabeledTensor.
Definition LabeledMatrix.h:38
A single-batched, logically 3D LabeledTensor.
Definition LabeledTensor3D.h:38
A single-batched, logically 1D LabeledTensor.
Definition LabeledVector.h:38
The base class for all constitutive models.
Definition Model.h:55
A custom map-like data structure. The keys are strings, and the values can be of heterogeneous types.
Definition OptionSet.h:100
Definition VariableStore.h:40
virtual void setup_input_views(VariableStore *host=nullptr)
Tell each input variable view which tensor storage(s) to view into.
Definition VariableStore.cxx:119
const Variable< T > & get_input_variable(const VariableName &name) const
Definition VariableStore.h:62
const Variable< T > & get_output_variable(const VariableName &name) const
Definition VariableStore.h:81
LabeledMatrix & derivative_storage()
Definition VariableStore.h:130
const LabeledVector & output_storage() const
Definition VariableStore.h:125
const Variable< T > & get_output_variable(const VariableName &name)
Definition VariableStore.h:76
virtual void setup_output_views(bool out, bool dout_din, bool d2out_din2)
Tell each output variable view which tensor storage(s) to view into.
Definition VariableStore.cxx:134
Variable< T > & declare_output_variable(S &&... name)
Declare an output variable.
Definition VariableStore.h:208
LabeledVector & output_storage()
Definition VariableStore.h:124
const LabeledTensor3D & second_derivative_storage() const
Definition VariableStore.h:137
const Variable< T > & declare_input_variable(S &&... name)
Declare an input variable.
Definition VariableStore.h:182
const Variable< Tensor > & declare_input_variable_list(Size list_size, S &&... name)
Declare an input variable that is a list of tensors of fixed size.
Definition VariableStore.h:200
LabeledAxis & output_axis()
Definition VariableStore.h:100
Variable< Tensor > & declare_output_variable_list(Size list_size, S &&... name)
Declare an output variable that is a list of tensors of fixed size.
Definition VariableStore.h:226
Storage< VariableName, VariableBase > & input_variables()
Definition VariableStore.h:106
Storage< VariableName, VariableBase > & output_variables()
Definition VariableStore.h:112
const Storage< VariableName, VariableBase > & input_variables() const
Definition VariableStore.h:107
const Storage< VariableName, VariableBase > & output_variables() const
Definition VariableStore.h:113
const Variable< Tensor > & declare_input_variable(Size sz, TensorType t, S &&... name)
Declare an input variable (with unknown base shape at compile time)
Definition VariableStore.h:191
virtual void allocate_variables(TensorShapeRef batch_shape, const torch::TensorOptions &options, bool in, bool out, bool dout_din, bool d2out_din2)
Allocate variable storages given the batch shape and tensor options.
Definition VariableStore.cxx:95
VariableName declare_variable(LabeledAxis &axis, const VariableName &var, Size sz) const
Declare an item (with known storage size) recursively on an axis.
Definition VariableStore.h:240
VariableName declare_variable(LabeledAxis &axis, const VariableName &var) const
Declare an item recursively on an axis.
Definition VariableStore.h:234
Variable< Tensor > & declare_output_variable(Size sz, TensorType t, S &&... name)
Declare an output variable (with unknown base shape at compile time)
Definition VariableStore.h:217
const LabeledAxis & input_axis() const
Definition VariableStore.h:95
const LabeledMatrix & derivative_storage() const
Definition VariableStore.h:131
virtual void setup_layout()
Setup the layouts of all the registered axes.
Definition VariableStore.cxx:51
VariableBase * input_variable(const VariableName &)
Get the view of an input variable.
Definition VariableStore.cxx:58
LabeledAxis & declare_axis(const std::string &name)
Definition VariableStore.cxx:39
const LabeledVector & input_storage() const
Definition VariableStore.h:119
VariableName declare_subaxis(LabeledAxis &axis, const VariableName &subaxis) const
Declare a subaxis recursively on an axis.
Definition VariableStore.h:247
virtual void zero(bool dout_din, bool d2out_din2)
Zero out derivative and second derivative storage.
Definition VariableStore.cxx:143
TensorType output_type(const VariableName &) const
Get the variable type of an output variable.
Definition VariableStore.cxx:78
const LabeledAxis & output_axis() const
Definition VariableStore.h:101
LabeledVector & input_storage()
Definition VariableStore.h:118
LabeledTensor3D & second_derivative_storage()
Definition VariableStore.h:136
LabeledAxis & input_axis()
Definition VariableStore.h:94
TensorType input_type(const VariableName &) const
Get the variable type of an input variable.
Definition VariableStore.cxx:70
VariableStore(const OptionSet &options, Model *object)
Definition VariableStore.cxx:30
virtual void cache(TensorShapeRef batch_shape)
Cache the variable's batch shape.
Definition VariableStore.cxx:86
VariableBase * output_variable(const VariableName &)
Get the view of an output variable.
Definition VariableStore.cxx:64
Variable< T > & get_input_variable(const VariableName &name)
Definition VariableStore.h:52
Definition CrossRef.cxx:30
LabeledAxisAccessor VariableName
Definition parser_utils.h:33
int64_t Size
Definition types.h:33
TensorType
Definition tensors.h:57
torch::IntArrayRef TensorShapeRef
Definition types.h:35
void neml_assert(bool assertion, Args &&... args)
Definition error.h:64
static constexpr TensorType value
Definition tensors.h:65