25#include "neml2/models/VariableStore.h"
32 _input_axis(declare_axis(
"input")),
33 _output_axis(declare_axis(
"output"))
41 "Trying to declare an axis named ",
43 ", but an axis with the same name already exists.");
45 auto axis = std::make_unique<LabeledAxis>();
46 return *_axes.set_pointer(name, std::move(axis));
59 return _input_views.query_value(name);
65 return _output_views.query_value(name);
79 const torch::TensorOptions & options,
86 if (
in && _object->
host() == _object)
117 var.setup_views(&_out, &_dout_din, &_d2out_din2);
124 var.reinit_views(
true,
false,
false);
153 if (_dout_din.
tensor().requires_grad())
155 _dout_din.
tensor().detach_();
163 if (_d2out_din2.
tensor().requires_grad())
165 _d2out_din2.
tensor().detach_();
166 d2out_din2_detached =
true;
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:44
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
void setup_layout()
Definition LabeledAxis.cxx:158
void zero_()
Zero out this tensor.
Definition LabeledTensor.cxx:152
static LabeledVector zeros(TorchShapeRef batch_shape, const std::vector< const LabeledAxis * > &axes, const torch::TensorOptions &options=default_tensor_options())
Set up new storage filled with zeros.
Definition LabeledTensor.cxx:109
const BatchTensor & tensor() const
Definition LabeledTensor.h:107
The base class of all "manufacturable" objects in the NEML2 library.
Definition NEML2Object.h:38
const T * host() const
Get a read-only pointer to the host.
Definition NEML2Object.h:90
A custom map-like data structure. The keys are strings, and the values can be of heterogeneous types.
Definition OptionSet.h:59
Definition VariableStore.h:37
virtual void setup_input_views()
Tell each input variable view which tensor storage(s) to view into.
Definition VariableStore.cxx:107
Storage< VariableName, VariableBase > & output_views()
Definition VariableStore.h:109
virtual void reinit_output_views(bool out, bool dout_din=true, bool d2out_din2=true)
Create the views for output variables, and optionally for the derivative and second derivatives.
Definition VariableStore.cxx:128
LabeledAxis & output_axis()
Definition VariableStore.h:97
VariableBase * input_view(const VariableName &)
Get the view of an input variable.
Definition VariableStore.cxx:57
virtual void setup_output_views()
Tell each output variable view which tensor storage(s) to view into.
Definition VariableStore.cxx:114
virtual void allocate_variables(TorchShapeRef batch_shape, const torch::TensorOptions &options, bool in, bool out, bool dout_din, bool d2out_din2)
Allocate variable storages given the batch shape and tensor options.
Definition VariableStore.cxx:78
virtual void detach_and_zero(bool out, bool dout_din=true, bool d2out_din2=true)
Detach the tensor storages and set each element in the tensor to 0.
Definition VariableStore.cxx:135
Storage< VariableName, VariableBase > & input_views()
Definition VariableStore.h:103
virtual void setup_layout()
Set up the layouts of all the registered axes.
Definition VariableStore.cxx:50
LabeledAxis & declare_axis(const std::string &name)
Definition VariableStore.cxx:38
VariableBase * output_view(const VariableName &)
Get the view of an output variable.
Definition VariableStore.cxx:63
LabeledVector & input_storage()
Definition VariableStore.h:115
virtual void cache(TorchShapeRef batch_shape)
Cache the variable's batch shape.
Definition VariableStore.cxx:69
LabeledAxis & input_axis()
Definition VariableStore.h:91
virtual void reinit_input_views()
Create the views for input variables.
Definition VariableStore.cxx:121
VariableStore(const OptionSet &options, NEML2Object *object)
Definition VariableStore.cxx:29
Definition CrossRef.cxx:32
torch::IntArrayRef TorchShapeRef
Definition types.h:37
void neml_assert(bool assertion, Args &&... args)
Definition error.h:73