NEML2 1.4.0
|
#include <VariableStore.h>
Protected Member Functions | |
virtual void | cache (TensorShapeRef batch_shape) |
Cache the variable's batch shape. | |
virtual void | allocate_variables (TensorShapeRef batch_shape, const torch::TensorOptions &options, bool in, bool out, bool dout_din, bool d2out_din2) |
Allocate variable storages given the batch shape and tensor options. | |
virtual void | setup_input_views (VariableStore *host=nullptr) |
Tell each input variable view which tensor storage(s) to view into. | |
virtual void | setup_output_views (bool out, bool dout_din, bool d2out_din2) |
Tell each output variable view which tensor storage(s) to view into. | |
virtual void | zero (bool dout_din, bool d2out_din2) |
Zero out derivative and second derivative storage. | |
template<typename T , typename... S> | |
const Variable< T > & | declare_input_variable (S &&... name) |
Declare an input variable. | |
template<typename... S> | |
const Variable< Tensor > & | declare_input_variable (Size sz, TensorType t, S &&... name) |
Declare an input variable (with unknown base shape at compile time) | |
template<typename T , typename... S> | |
const Variable< Tensor > & | declare_input_variable_list (Size list_size, S &&... name) |
Declare an input variable that is a list of tensors of fixed size. | |
template<typename T , typename... S> | |
Variable< T > & | declare_output_variable (S &&... name) |
Declare an output variable. | |
template<typename... S> | |
Variable< Tensor > & | declare_output_variable (Size sz, TensorType t, S &&... name) |
Declare an output variable (with unknown base shape at compile time) | |
template<typename T , typename... S> | |
Variable< Tensor > & | declare_output_variable_list (Size list_size, S &&... name) |
Declare an output variable that is a list of tensors of fixed size. | |
template<typename T > | |
VariableName | declare_variable (LabeledAxis &axis, const VariableName &var) const |
Declare an item recursively on an axis. | |
VariableName | declare_variable (LabeledAxis &axis, const VariableName &var, Size sz) const |
Declare an item (with known storage size) recursively on an axis. | |
VariableName | declare_subaxis (LabeledAxis &axis, const VariableName &subaxis) const |
Declare a subaxis recursively on an axis. | |
VariableStore | ( | const OptionSet & | options, |
Model * | object ) |
|
protected virtual |
Allocate variable storages given the batch shape and tensor options.
batch_shape | Batch shape of the allocated tensors |
options | Tensor options of the allocated tensors |
in | Whether to allocate tensor storage for input |
out | Whether to allocate tensor storage for output |
dout_din | Whether to allocate tensor storage for the first derivatives |
d2out_din2 | Whether to allocate tensor storage for the second derivatives |
Reimplemented in Model.
|
protected virtual |
Cache the variable's batch shape.
Reimplemented in Model.
LabeledAxis & declare_axis | ( | const std::string & | name | ) |
|
inline protected |
Declare an input variable.
|
inline protected |
Declare an input variable (with unknown base shape at compile time)
|
inline protected |
Declare an input variable that is a list of tensors of fixed size.
|
inline protected |
Declare an output variable.
|
inline protected |
Declare an output variable (with unknown base shape at compile time)
|
inline protected |
Declare an output variable that is a list of tensors of fixed size.
|
inline protected |
Declare a subaxis recursively on an axis.
|
inline protected |
Declare an item recursively on an axis.
|
inline protected |
Declare an item (with known storage size) recursively on an axis.
|
inline |
Derivative storage
|
inline |
|
inline |
Get an input variable
|
inline |
|
inline |
Get an output variable
|
inline |
|
inline |
Definition of the input variables
|
inline |
|
inline |
Input storage
|
inline |
TensorType input_type | ( | const VariableName & | name | ) | const |
Get the variable type of an input variable.
VariableBase * input_variable | ( | const VariableName & | name | ) |
Get the view of an input variable.
|
inline |
Input variable views
|
inline |
|
inline |
Which variables this object defines as output
|
inline |
|
inline |
Output storage
|
inline |
TensorType output_type | ( | const VariableName & | name | ) | const |
Get the variable type of an output variable.
VariableBase * output_variable | ( | const VariableName & | name | ) |
Get the view of an output variable.
|
inline |
Output variable views
|
inline |
|
inline |
Second derivative storage
|
inline |
|
protected virtual |
Tell each input variable view which tensor storage(s) to view into.
Reimplemented in Model.
|
virtual |
Setup the layouts of all the registered axes.
Tell each output variable view which tensor storage(s) to view into.
Reimplemented in Model.