NEML2 1.4.0
Loading...
Searching...
No Matches
VariableBase Class Reference (abstract)

#include <Variable.h>

Inheritance diagram for VariableBase:

Public Member Functions

 VariableBase (const VariableName &name_in)
 
virtual ~VariableBase ()=default
 
virtual void cache (TorchShapeRef batch_shape)
 Cache the variable's batch shape.
 
void setup_views (const LabeledVector *value, const LabeledMatrix *deriv=nullptr, const LabeledTensor3D *secderiv=nullptr)
 Setup the variable's views into blocks of the storage.
 
void setup_views (const VariableBase *other)
 Setup the variable's views into blocks of the storage.
 
virtual void reinit_views (bool out, bool dout_din, bool d2out_din2)
 Reinitialize variable views.
 
virtual void requires_grad_ (bool req=true)=0
 Set requires_grad for the underlying storage.
 
const std::vector< VariableName > & args () const
 Arguments.
 
void add_arg (const VariableBase &arg)
 Add an argument.
 
void clear_args ()
 Clear arguments.
 
Derivative d (const VariableBase &x)
 Create a wrapper representing the derivative dy/dx.
 
Derivative d (const VariableBase &x1, const VariableBase &x2)
 Create a wrapper representing the second derivative d2y/(dx1 dx2).
 
const BatchTensor & raw_value () const
 Raw flattened variable value.
 
virtual const BatchTensor tensor () const =0
 Variable value of the logical shape.
 
const VariableName & name () const
 Name of this variable.
 
TorchShapeRef batch_sizes () const
 Batch shape.
 
virtual TorchShapeRef base_sizes () const =0
 Base shape.
 
TorchSize batch_dim () const
 Batch dimension.
 
TorchSize base_dim () const
 Base dimension.
 
TorchSize base_storage () const
 Base storage.
 
virtual TorchShapeRef sizes () const =0
 Total shape.
 
const LabeledVector & value_storage () const
 Accessors for storage.
 
const LabeledMatrix & derivative_storage () const
 
const LabeledTensor3D & second_derivative_storage () const
 

Protected Attributes

const VariableName _name
 Name of the variable.
 
TorchShape _batch_sizes
 Batch shape of this variable.
 
std::vector< VariableName > _args
 Names of the variables that this variable depends on.
 
BatchTensor _raw_value
 The raw (flattened) variable value.
 
std::map< VariableName, BatchTensor > _dvalue_d
 The derivative of this variable w.r.t. arguments.
 
std::map< VariableName, std::map< VariableName, BatchTensor > > _d2value_d
 The second derivative of this variable w.r.t. arguments.
 
const LabeledVector * _value_storage
 The value storage that this variable is viewing into.
 
const LabeledMatrix * _derivative_storage
 The derivative storage that this variable is viewing into.
 
const LabeledTensor3D * _second_derivative_storage
 The second derivative storage that this variable is viewing into.
 

Constructor & Destructor Documentation

◆ VariableBase()

VariableBase ( const VariableName & name_in)
inline

◆ ~VariableBase()

virtual ~VariableBase ( )
virtual, default

Member Function Documentation

◆ add_arg()

void add_arg ( const VariableBase & arg)
inline

Add an argument.

◆ args()

const std::vector< VariableName > & args ( ) const
inline

Arguments.

◆ base_dim()

TorchSize base_dim ( ) const
inline

Base dimension.

◆ base_sizes()

virtual TorchShapeRef base_sizes ( ) const
pure virtual

Base shape.

◆ base_storage()

TorchSize base_storage ( ) const
inline

Base storage.

◆ batch_dim()

TorchSize batch_dim ( ) const
inline

Batch dimension.

◆ batch_sizes()

TorchShapeRef batch_sizes ( ) const
inline

Batch shape.

◆ cache()

virtual void cache ( TorchShapeRef batch_shape)
virtual

Cache the variable's batch shape.

◆ clear_args()

void clear_args ( )
inline

Clear arguments.

◆ d() [1/2]

Derivative d ( const VariableBase & x)

Create a wrapper representing the derivative dy/dx.

◆ d() [2/2]

Derivative d ( const VariableBase & x1,
const VariableBase & x2 )

Create a wrapper representing the second derivative d2y/(dx1 dx2).

◆ derivative_storage()

const LabeledMatrix & derivative_storage ( ) const

◆ name()

const VariableName & name ( ) const
inline

Name of this variable.

◆ raw_value()

const BatchTensor & raw_value ( ) const
inline

Raw flattened variable value.

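◆ reinit_views()

virtual void reinit_views ( bool out,
bool dout_din,
bool d2out_din2 )
virtual

Reinitialize variable views.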

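◆ requires_grad_()

virtual void requires_grad_ ( bool req = true)
pure virtual

Set requires_grad for the underlying storage.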

◆ second_derivative_storage()

const LabeledTensor3D & second_derivative_storage ( ) const

◆ setup_views() [1/2]

void setup_views ( const LabeledVector * value,
const LabeledMatrix * deriv = nullptr,
const LabeledTensor3D * secderiv = nullptr )

Setup the variable's views into blocks of the storage.

◆ setup_views() [2/2]

void setup_views ( const VariableBase * other)

Setup the variable's views into blocks of the storage.

◆ sizes()

virtual TorchShapeRef sizes ( ) const
pure virtual

Total shape.

◆ tensor()

virtual const BatchTensor tensor ( ) const
pure virtual

Variable value of the logical shape.

◆ value_storage()

const LabeledVector & value_storage ( ) const

Accessors for storage.

Member Data Documentation

◆ _args

std::vector<VariableName> _args
protected

Names of the variables that this variable depends on.

◆ _batch_sizes

TorchShape _batch_sizes
protected

Batch shape of this variable.

◆ _d2value_d

std::map<VariableName, std::map<VariableName, BatchTensor> > _d2value_d
protected

The second derivative of this variable w.r.t. arguments.

◆ _derivative_storage

const LabeledMatrix* _derivative_storage
protected

The derivative storage that this variable is viewing into.

◆ _dvalue_d

std::map<VariableName, BatchTensor> _dvalue_d
protected

The derivative of this variable w.r.t. arguments.

◆ _name

const VariableName _name
protected

Name of the variable.

◆ _raw_value

BatchTensor _raw_value
protected

The raw (flattened) variable value.

◆ _second_derivative_storage

const LabeledTensor3D* _second_derivative_storage
protected

The second derivative storage that this variable is viewing into.

◆ _value_storage

const LabeledVector* _value_storage
protected

The value storage that this variable is viewing into.