NEML2 1.4.0
Loading...
Searching...
No Matches
Variable< T > Class Template Reference

Concrete definition of a variable. More...

Detailed Description

template<typename T>
class neml2::Variable< T >

Concrete definition of a variable.

#include <Variable.h>

Inheritance diagram for Variable< T >:

Public Member Functions

template<typename T2 = T, typename = typename std::enable_if_t<!std::is_same_v<BatchTensor, T2>>>
 Variable (const VariableName &name_in)
 
template<typename T2 = T, typename = typename std::enable_if_t<std::is_same_v<BatchTensor, T2>>>
 Variable (const VariableName &name_in, TorchShapeRef base_shape)
 
virtual void reinit_views (bool out, bool dout_din, bool d2out_din2) override
 Reinitialize variable views.
 
virtual void requires_grad_ (bool req=true) override
 Set requires_grad for the underlying storage.
 
virtual TorchShapeRef base_sizes () const override
 Base shape.
 
virtual TorchShapeRef sizes () const override
 Total shape.
 
 Variable (const Variable< T > &)
 Suppressed constructor to prevent accidental dereferencing.
 
void operator= (const Variable< T > &)
 Suppressed assignment operator to prevent accidental dereferencing.
 
void operator= (const BatchTensor &val)
 Set the raw value to store val.
 
const T & value () const
 Variable value of the logical shape.
 
virtual const BatchTensor tensor () const override
 Variable value of the logical shape.
 
T operator- () const
 Negation.
 
 operator T () const
 
virtual void cache (TorchShapeRef batch_shape) override
 Set the batch shape and base shape according to val.
 
template<typename T2 = T, typename = typename std::enable_if_t<!std::is_same_v<T2, BatchTensor>>>
 operator BatchTensor () const
 
- Public Member Functions inherited from VariableBase
 VariableBase (const VariableName &name_in)
 
virtual ~VariableBase ()=default
 
void setup_views (const LabeledVector *value, const LabeledMatrix *deriv=nullptr, const LabeledTensor3D *secderiv=nullptr)
 Setup the variable's views into blocks of the storage.
 
void setup_views (const VariableBase *other)
 Setup the variable's views into blocks of the storage.
 
const std::vector< VariableName > & args () const
 Arguments.
 
void add_arg (const VariableBase &arg)
 Add an argument.
 
void clear_args ()
 Clear arguments.
 
Derivative d (const VariableBase &x)
 Create a wrapper representing the derivative dy/dx.
 
Derivative d (const VariableBase &x1, const VariableBase &x2)
 Create a wrapper representing the second derivative d2y/dx2.
 
const BatchTensor & raw_value () const
 Raw flattened variable value.
 
const VariableName & name () const
 Name of this variable.
 
TorchShapeRef batch_sizes () const
 Batch shape.
 
TorchSize batch_dim () const
 Batch dimension.
 
TorchSize base_dim () const
 Base dimension.
 
TorchSize base_storage () const
 Base storage.
 
const LabeledVector & value_storage () const
 Accessors for storage.
 
const LabeledMatrix & derivative_storage () const
 
const LabeledTensor3D & second_derivative_storage () const
 

Protected Attributes

const TorchShape _base_sizes
 Base shape of this variable.
 
TorchShape _sizes
 Shape of this variable.
 
T _value
 Variable value of the logical shape.
 
- Protected Attributes inherited from VariableBase
const VariableName _name
 Name of the variable.
 
TorchShape _batch_sizes
 Batch shape of this variable.
 
std::vector< VariableName > _args
 Names of the variables that this variable depends on.
 
BatchTensor _raw_value
 The raw (flattened) variable value.
 
std::map< VariableName, BatchTensor > _dvalue_d
 The derivative of this variable w.r.t. arguments.
 
std::map< VariableName, std::map< VariableName, BatchTensor > > _d2value_d
 The second derivative of this variable w.r.t. arguments.
 
const LabeledVector * _value_storage
 The value storage that this variable is viewing into.
 
const LabeledMatrix * _derivative_storage
 The derivative storage that this variable is viewing into.
 
const LabeledTensor3D * _second_derivative_storage
 The second derivative storage that this variable is viewing into.
 

Constructor & Destructor Documentation

◆ Variable() [1/3]

template<typename T >
template<typename T2 = T, typename = typename std::enable_if_t<!std::is_same_v<BatchTensor, T2>>>
Variable ( const VariableName & name_in)
inline

◆ Variable() [2/3]

template<typename T >
template<typename T2 = T, typename = typename std::enable_if_t<std::is_same_v<BatchTensor, T2>>>
Variable ( const VariableName & name_in,
TorchShapeRef base_shape )
inline

◆ Variable() [3/3]

template<typename T >
Variable ( const Variable< T > & )
inline

Suppressed constructor to prevent accidental dereferencing.

Member Function Documentation

◆ base_sizes()

template<typename T >
virtual TorchShapeRef base_sizes ( ) const
inline override virtual

Base shape.

Implements VariableBase.

◆ cache()

template<typename T >
virtual void cache ( TorchShapeRef batch_shape)
inline override virtual

Set the batch shape and base shape according to val.

Reimplemented from VariableBase.

◆ operator BatchTensor()

template<typename T >
template<typename T2 = T, typename = typename std::enable_if_t<!std::is_same_v<T2, BatchTensor>>>
operator BatchTensor ( ) const
inline

◆ operator T()

template<typename T >
operator T ( ) const
inline

◆ operator-()

template<typename T >
T operator- ( ) const
inline

Negation.

◆ operator=() [1/2]

template<typename T >
void operator= ( const BatchTensor & val)
inline

Set the raw value to store val.

Note that this is an in-place operation, and so we must reshape (flatten base dimensions of) val and modify raw_value.

◆ operator=() [2/2]

template<typename T >
void operator= ( const Variable< T > & )
inline

Suppressed assignment operator to prevent accidental dereferencing.

◆ reinit_views()

template<typename T >
virtual void reinit_views ( bool out,
bool dout_din,
bool d2out_din2 )
inline override virtual

Reinitialize variable views.

Reimplemented from VariableBase.

◆ requires_grad_()

template<typename T >
virtual void requires_grad_ ( bool req = true)
inline override virtual

Set requires_grad for the underlying storage.

Implements VariableBase.

◆ sizes()

template<typename T >
virtual TorchShapeRef sizes ( ) const
inline override virtual

Total shape.

Implements VariableBase.

◆ tensor()

template<typename T >
virtual const BatchTensor tensor ( ) const
inline override virtual

Variable value of the logical shape.

Implements VariableBase.

◆ value()

template<typename T >
const T & value ( ) const
inline

Variable value of the logical shape.

Member Data Documentation

◆ _base_sizes

template<typename T >
const TorchShape _base_sizes
protected

Base shape of this variable.

◆ _sizes

template<typename T >
TorchShape _sizes
protected

Shape of this variable.

◆ _value

template<typename T >
T _value
protected

Variable value of the logical shape.