NEML2 1.4.0
Loading...
Searching...
No Matches
VariableStore Class Reference

#include <VariableStore.h>

Inheritance diagram for VariableStore:

Public Member Functions

 VariableStore (const OptionSet &options, NEML2Object *object)
 
LabeledAxis & declare_axis (const std::string &name)
 
virtual void setup_layout ()
 Setup the layouts of all the registered axes.
 
VariableBase * input_view (const VariableName &)
 Get the view of an input variable.
 
VariableBase * output_view (const VariableName &)
 Get the view of an output variable.
 
template<typename T = BatchTensor>
Variable< T > & get_input_variable (const VariableName &name)
 
template<typename T = BatchTensor>
const Variable< T > & get_input_variable (const VariableName &name) const
 
template<typename T = BatchTensor>
const Variable< T > & get_output_variable (const VariableName &name)
 
template<typename T = BatchTensor>
const Variable< T > & get_output_variable (const VariableName &name) const
 
LabeledAxis & input_axis ()
 
const LabeledAxis & input_axis () const
 
LabeledAxis & output_axis ()
 
const LabeledAxis & output_axis () const
 
Storage< VariableName, VariableBase > & input_views ()
 
const Storage< VariableName, VariableBase > & input_views () const
 
Storage< VariableName, VariableBase > & output_views ()
 
const Storage< VariableName, VariableBase > & output_views () const
 
LabeledVector & input_storage ()
 
const LabeledVector & input_storage () const
 
LabeledVector & output_storage ()
 
const LabeledVector & output_storage () const
 
LabeledMatrix & derivative_storage ()
 
const LabeledMatrix & derivative_storage () const
 
LabeledTensor3D & second_derivative_storage ()
 
const LabeledTensor3D & second_derivative_storage () const
 

Protected Member Functions

virtual void cache (TorchShapeRef batch_shape)
 Cache the variable's batch shape.
 
virtual void allocate_variables (TorchShapeRef batch_shape, const torch::TensorOptions &options, bool in, bool out, bool dout_din, bool d2out_din2)
 Allocate variable storages given the batch shape and tensor options.
 
virtual void setup_input_views ()
 Tell each input variable view which tensor storage(s) to view into.
 
virtual void setup_output_views ()
 Tell each output variable view which tensor storage(s) to view into.
 
virtual void reinit_input_views ()
 Create the views for input variables.
 
virtual void reinit_output_views (bool out, bool dout_din=true, bool d2out_din2=true)
 Create the views for output variables, and optionally for the derivative and second derivatives.
 
virtual void detach_and_zero (bool out, bool dout_din=true, bool d2out_din2=true)
 Detach the tensor storages and set each element in the tensor to 0.
 
template<typename T , typename... S>
const Variable< T > & declare_input_variable (S &&... name)
 Declare an input variable.
 
template<typename... S>
const Variable< BatchTensor > & declare_input_variable (TorchSize sz, S &&... name)
 Declare an input variable (with unknown base shape at compile time).
 
template<typename T , typename... S>
const Variable< BatchTensor > & declare_input_variable_list (TorchSize list_size, S &&... name)
 Declare an input variable that is a list of tensors of fixed size.
 
template<typename T , typename... S>
Variable< T > & declare_output_variable (S &&... name)
 Declare an output variable.
 
template<typename... S>
Variable< BatchTensor > & declare_output_variable (TorchSize sz, S &&... name)
 Declare an output variable (with unknown base shape at compile time).
 
template<typename T , typename... S>
Variable< BatchTensor > & declare_output_variable_list (TorchSize list_size, S &&... name)
 Declare an output variable that is a list of tensors of fixed size.
 
template<typename T >
VariableName declare_variable (LabeledAxis &axis, const VariableName &var) const
 Declare an item recursively on an axis.
 
VariableName declare_variable (LabeledAxis &axis, const VariableName &var, TorchSize sz) const
 Declare an item (with known storage size) recursively on an axis.
 
VariableName declare_subaxis (LabeledAxis &axis, const VariableName &subaxis) const
 Declare a subaxis recursively on an axis.
 

Constructor & Destructor Documentation

◆ VariableStore()

VariableStore ( const OptionSet & options,
NEML2Object * object )

Member Function Documentation

◆ allocate_variables()

void allocate_variables ( TorchShapeRef batch_shape,
const torch::TensorOptions & options,
bool in,
bool out,
bool dout_din,
bool d2out_din2 )
protected virtual

Allocate variable storages given the batch shape and tensor options.

Parameters
batch_shape: Batch shape of the allocated tensors
options: Tensor options of the allocated tensors
in: Whether to allocate tensor storage for input
out: Whether to allocate tensor storage for output
dout_din: Whether to allocate tensor storage for the first derivatives
d2out_din2: Whether to allocate tensor storage for the second derivatives

Reimplemented in Model.

◆ cache()

void cache ( TorchShapeRef batch_shape)
protected virtual

Cache the variable's batch shape.

Reimplemented in Model, and Model.

◆ declare_axis()

LabeledAxis & declare_axis ( const std::string & name)

◆ declare_input_variable() [1/2]

template<typename T , typename... S>
const Variable< T > & declare_input_variable ( S &&... name)
inline protected

Declare an input variable.

◆ declare_input_variable() [2/2]

template<typename... S>
const Variable< BatchTensor > & declare_input_variable ( TorchSize sz,
S &&... name )
inline protected

Declare an input variable (with unknown base shape at compile time).

◆ declare_input_variable_list()

template<typename T , typename... S>
const Variable< BatchTensor > & declare_input_variable_list ( TorchSize list_size,
S &&... name )
inline protected

Declare an input variable that is a list of tensors of fixed size.

◆ declare_output_variable() [1/2]

template<typename T , typename... S>
Variable< T > & declare_output_variable ( S &&... name)
inline protected

Declare an output variable.

◆ declare_output_variable() [2/2]

template<typename... S>
Variable< BatchTensor > & declare_output_variable ( TorchSize sz,
S &&... name )
inline protected

Declare an output variable (with unknown base shape at compile time).

◆ declare_output_variable_list()

template<typename T , typename... S>
Variable< BatchTensor > & declare_output_variable_list ( TorchSize list_size,
S &&... name )
inline protected

Declare an output variable that is a list of tensors of fixed size.

◆ declare_subaxis()

VariableName declare_subaxis ( LabeledAxis & axis,
const VariableName & subaxis ) const
inline protected

Declare a subaxis recursively on an axis.

◆ declare_variable() [1/2]

template<typename T >
VariableName declare_variable ( LabeledAxis & axis,
const VariableName & var ) const
inline protected

Declare an item recursively on an axis.

◆ declare_variable() [2/2]

VariableName declare_variable ( LabeledAxis & axis,
const VariableName & var,
TorchSize sz ) const
inline protected

Declare an item (with known storage size) recursively on an axis.

◆ derivative_storage() [1/2]

LabeledMatrix & derivative_storage ( )
inline

Derivative storage

◆ derivative_storage() [2/2]

const LabeledMatrix & derivative_storage ( ) const
inline

◆ detach_and_zero()

void detach_and_zero ( bool out,
bool dout_din = true,
bool d2out_din2 = true )
protected virtual

Detach the tensor storages and set each element in the tensor to 0.

Reimplemented in Model.

◆ get_input_variable() [1/2]

template<typename T = BatchTensor>
Variable< T > & get_input_variable ( const VariableName & name)
inline

Get an input variable

◆ get_input_variable() [2/2]

template<typename T = BatchTensor>
const Variable< T > & get_input_variable ( const VariableName & name) const
inline

◆ get_output_variable() [1/2]

template<typename T = BatchTensor>
const Variable< T > & get_output_variable ( const VariableName & name)
inline

Get an output variable

◆ get_output_variable() [2/2]

template<typename T = BatchTensor>
const Variable< T > & get_output_variable ( const VariableName & name) const
inline

◆ input_axis() [1/2]

LabeledAxis & input_axis ( )
inline

Definition of the input variables

◆ input_axis() [2/2]

const LabeledAxis & input_axis ( ) const
inline

◆ input_storage() [1/2]

LabeledVector & input_storage ( )
inline

Input storage

◆ input_storage() [2/2]

const LabeledVector & input_storage ( ) const
inline

◆ input_view()

VariableBase * input_view ( const VariableName & name)

Get the view of an input variable.

◆ input_views() [1/2]

Storage< VariableName, VariableBase > & input_views ( )
inline

Input variable views

◆ input_views() [2/2]

const Storage< VariableName, VariableBase > & input_views ( ) const
inline

◆ output_axis() [1/2]

LabeledAxis & output_axis ( )
inline

Which variables this object defines as output

◆ output_axis() [2/2]

const LabeledAxis & output_axis ( ) const
inline

◆ output_storage() [1/2]

LabeledVector & output_storage ( )
inline

Output storage

◆ output_storage() [2/2]

const LabeledVector & output_storage ( ) const
inline

◆ output_view()

VariableBase * output_view ( const VariableName & name)

Get the view of an output variable.

◆ output_views() [1/2]

Storage< VariableName, VariableBase > & output_views ( )
inline

Output variable views

◆ output_views() [2/2]

const Storage< VariableName, VariableBase > & output_views ( ) const
inline

◆ reinit_input_views()

void reinit_input_views ( )
protected virtual

Create the views for input variables.

Reimplemented in Model.

◆ reinit_output_views()

void reinit_output_views ( bool out,
bool dout_din = true,
bool d2out_din2 = true )
protected virtual

Create the views for output variables, and optionally for the derivative and second derivatives.

Reimplemented in Model.

◆ second_derivative_storage() [1/2]

LabeledTensor3D & second_derivative_storage ( )
inline

Second derivative storage

◆ second_derivative_storage() [2/2]

const LabeledTensor3D & second_derivative_storage ( ) const
inline

◆ setup_input_views()

void setup_input_views ( )
protected virtual

Tell each input variable view which tensor storage(s) to view into.

Reimplemented in Model.

◆ setup_layout()

void setup_layout ( )
virtual

Setup the layouts of all the registered axes.

◆ setup_output_views()

void setup_output_views ( )
protected virtual

Tell each output variable view which tensor storage(s) to view into.

Reimplemented in Model.