template<typename T>
class neml2::Variable< T >
Concrete definition of a variable.
|
template<typename T2 = T, typename = typename std::enable_if_t<!std::is_same_v<BatchTensor, T2>>> |
| Variable (const VariableName &name_in) |
|
template<typename T2 = T, typename = typename std::enable_if_t<std::is_same_v<BatchTensor, T2>>> |
| Variable (const VariableName &name_in, TorchShapeRef base_shape) |
|
virtual void | reinit_views (bool out, bool dout_din, bool d2out_din2) override |
| Reinitialize variable views.
|
|
virtual void | requires_grad_ (bool req=true) override |
| Set requires_grad for the underlying storage.
|
|
virtual TorchShapeRef | base_sizes () const override |
| Base shape.
|
|
virtual TorchShapeRef | sizes () const override |
| Total shape.
|
|
| Variable (const Variable< T > &) |
| Suppressed copy constructor to prevent accidental dereferencing.
|
|
void | operator= (const Variable< T > &) |
| Suppressed copy-assignment operator to prevent accidental dereferencing.
|
|
void | operator= (const BatchTensor &val) |
| Set the raw value to store `val`.
|
|
const T & | value () const |
| Variable value of the logical shape.
|
|
virtual const BatchTensor | tensor () const override |
| Variable value of the logical shape.
|
|
T | operator- () const |
| Negation.
|
|
| operator T () const |
|
virtual void | cache (TorchShapeRef batch_shape) override |
| Set the batch shape and base shape according to `val`.
|
|
template<typename T2 = T, typename = typename std::enable_if_t<!std::is_same_v<T2, BatchTensor>>> |
| operator BatchTensor () const |
|
| VariableBase (const VariableName &name_in) |
|
virtual | ~VariableBase ()=default |
|
void | setup_views (const LabeledVector *value, const LabeledMatrix *deriv=nullptr, const LabeledTensor3D *secderiv=nullptr) |
| Setup the variable's views into blocks of the storage.
|
|
void | setup_views (const VariableBase *other) |
| Setup the variable's views into blocks of the storage.
|
|
const std::vector< VariableName > & | args () const |
| Arguments.
|
|
void | add_arg (const VariableBase &arg) |
| Add an argument.
|
|
void | clear_args () |
| Clear arguments.
|
|
Derivative | d (const VariableBase &x) |
| Create a wrapper representing the derivative dy/dx.
|
|
Derivative | d (const VariableBase &x1, const VariableBase &x2) |
| Create a wrapper representing the second derivative d2y/dx2.
|
|
const BatchTensor & | raw_value () const |
| Raw flattened variable value.
|
|
const VariableName & | name () const |
| Name of this variable.
|
|
TorchShapeRef | batch_sizes () const |
| Batch shape.
|
|
TorchSize | batch_dim () const |
| Batch dimension.
|
|
TorchSize | base_dim () const |
| Base dimension.
|
|
TorchSize | base_storage () const |
| Base storage.
|
|
const LabeledVector & | value_storage () const |
| Accessors for storage.
|
|
const LabeledMatrix & | derivative_storage () const |
|
const LabeledTensor3D & | second_derivative_storage () const |
|