NEML2 1.4.0
|
#include <VariableStore.h>
Protected Member Functions | |
virtual void | cache (TorchShapeRef batch_shape) |
Cache the variable's batch shape. | |
virtual void | allocate_variables (TorchShapeRef batch_shape, const torch::TensorOptions &options, bool in, bool out, bool dout_din, bool d2out_din2) |
Allocate variable storages given the batch shape and tensor options. | |
virtual void | setup_input_views () |
Tell each input variable view which tensor storage(s) to view into. | |
virtual void | setup_output_views () |
Tell each output variable view which tensor storage(s) to view into. | |
virtual void | reinit_input_views () |
Create the views for input variables. | |
virtual void | reinit_output_views (bool out, bool dout_din=true, bool d2out_din2=true) |
Create the views for output variables, and optionally for the first and second derivatives. | |
virtual void | detach_and_zero (bool out, bool dout_din=true, bool d2out_din2=true) |
Detach the tensor storages and set every element of each tensor to 0. | |
template<typename T , typename... S> | |
const Variable< T > & | declare_input_variable (S &&... name) |
Declare an input variable. | |
template<typename... S> | |
const Variable< BatchTensor > & | declare_input_variable (TorchSize sz, S &&... name) |
Declare an input variable (with unknown base shape at compile time). | |
template<typename T , typename... S> | |
const Variable< BatchTensor > & | declare_input_variable_list (TorchSize list_size, S &&... name) |
Declare an input variable that is a list of tensors of fixed size. | |
template<typename T , typename... S> | |
Variable< T > & | declare_output_variable (S &&... name) |
Declare an output variable. | |
template<typename... S> | |
Variable< BatchTensor > & | declare_output_variable (TorchSize sz, S &&... name) |
Declare an output variable (with unknown base shape at compile time). | |
template<typename T , typename... S> | |
Variable< BatchTensor > & | declare_output_variable_list (TorchSize list_size, S &&... name) |
Declare an output variable that is a list of tensors of fixed size. | |
template<typename T > | |
VariableName | declare_variable (LabeledAxis &axis, const VariableName &var) const |
Declare an item recursively on an axis. | |
VariableName | declare_variable (LabeledAxis &axis, const VariableName &var, TorchSize sz) const |
Declare an item (with known storage size) recursively on an axis. | |
VariableName | declare_subaxis (LabeledAxis &axis, const VariableName &subaxis) const |
Declare a subaxis recursively on an axis. | |
VariableStore(const OptionSet & options, NEML2Object * object)
|
protected virtual |
Allocate variable storages given the batch shape and tensor options.
batch_shape | Batch shape of the allocated tensors |
options | Tensor options of the allocated tensors |
in | Whether to allocate tensor storage for input |
out | Whether to allocate tensor storage for output |
dout_din | Whether to allocate tensor storage for the first derivatives |
d2out_din2 | Whether to allocate tensor storage for the second derivatives |
Reimplemented in Model.
|
protected virtual |
LabeledAxis & declare_axis(const std::string & name)
|
inline protected |
Declare an input variable.
|
inline protected |
Declare an input variable (with unknown base shape at compile time).
|
inline protected |
Declare an input variable that is a list of tensors of fixed size.
|
inline protected |
Declare an output variable.
|
inline protected |
Declare an output variable (with unknown base shape at compile time).
|
inline protected |
Declare an output variable that is a list of tensors of fixed size.
|
inline protected |
Declare a subaxis recursively on an axis.
|
inline protected |
Declare an item recursively on an axis.
|
inline protected |
Declare an item (with known storage size) recursively on an axis.
|
inline |
Derivative storage
|
inline |
Detach the tensor storages and set every element of each tensor to 0.
Reimplemented in Model.
|
inline |
Get an input variable
|
inline |
|
inline |
Get an output variable
|
inline |
|
inline |
Definition of the input variables
|
inline |
|
inline |
Input storage
|
inline |
VariableBase * input_view(const VariableName & name)
Get the view of an input variable.
|
inline |
Input variable views
|
inline |
|
inline |
Which variables this object defines as output
|
inline |
|
inline |
Output storage
|
inline |
VariableBase * output_view(const VariableName & name)
Get the view of an output variable.
|
inline |
Output variable views
|
inline |
|
protected virtual |
Create the views for input variables.
Reimplemented in Model.
|
protected virtual |
Create the views for output variables, and optionally for the first and second derivatives.
Reimplemented in Model.
|
inline |
Second derivative storage
|
inline |
|
protected virtual |
Tell each input variable view which tensor storage(s) to view into.
Reimplemented in Model.
|
virtual |
Set up the layouts of all the registered axes.