27#include "neml2/tensors/LabeledAxisAccessor.h"
28#include "neml2/tensors/tensors.h"
29#include "neml2/tensors/LabeledVector.h"
30#include "neml2/tensors/LabeledMatrix.h"
31#include "neml2/tensors/LabeledTensor3D.h"
71 const std::vector<VariableName> &
args()
const {
return _args; }
135 std::map<VariableName, std::map<VariableName, BatchTensor>>
_d2value_d;
155 template <
typename T2 = T,
typename =
typename std::enable_if_t<!std::is_same_v<BatchTensor, T2>>>
162 template <
typename T2 = T,
typename =
typename std::enable_if_t<std::is_same_v<BatchTensor, T2>>>
189 [[
deprecated(
"Variable<T> must be assigned to references -- missing &")]]
void
202 _value.index_put_({torch::indexing::Slice()},
224 template <
typename T2 = T,
typename =
typename std::enable_if_t<!std::is_same_v<T2, BatchTensor>>>
269#define FWD_VARIABLE_BINARY_OP(op) \
270 template <typename T1, \
272 typename = typename std::enable_if_t<std::is_base_of_v<VariableBase, T1> || \
273 std::is_base_of_v<VariableBase, T2>>> \
274 auto op(const T1 & a, const T2 & b) \
276 if constexpr (std::is_base_of_v<VariableBase, T1> && std::is_base_of_v<VariableBase, T2>) \
277 return op(a.value(), b.value()); \
279 if constexpr (std::is_base_of_v<VariableBase, T1> && !std::is_base_of_v<VariableBase, T2>) \
280 return op(a.value(), b); \
282 if constexpr (!std::is_base_of_v<VariableBase, T1> && std::is_base_of_v<VariableBase, T2>) \
283 return op(a, b.value()); \
286FWD_VARIABLE_BINARY_OP(
operator+);
287FWD_VARIABLE_BINARY_OP(
operator-);
288FWD_VARIABLE_BINARY_OP(
operator*);
289FWD_VARIABLE_BINARY_OP(
operator/);
Definition BatchTensor.h:32
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
Definition Variable.h:242
void operator=(const BatchTensor &val)
Definition Variable.cxx:133
Derivative(BatchTensor &val)
Definition Variable.h:244
const BatchTensor & value() const
Definition Variable.h:249
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:44
A single-batched, logically 2D LabeledTensor.
Definition LabeledMatrix.h:38
A single-batched, logically 3D LabeledTensor.
Definition LabeledTensor3D.h:38
A single-batched, logically 1D LabeledTensor.
Definition LabeledVector.h:38
BatchTensor _raw_value
The raw (flattened) variable value.
Definition Variable.h:129
VariableBase(const VariableName &name_in)
Definition Variable.h:43
virtual TorchShapeRef base_sizes() const =0
Base shape.
const LabeledVector * _value_storage
The value storage that this variable is viewing into.
Definition Variable.h:138
std::map< VariableName, BatchTensor > _dvalue_d
The derivative of this variable w.r.t. arguments.
Definition Variable.h:132
Derivative d(const VariableBase &x)
Create a wrapper representing the derivative dy/dx.
Definition Variable.cxx:102
const std::vector< VariableName > & args() const
Arguments.
Definition Variable.h:71
virtual ~VariableBase()=default
TorchSize base_storage() const
Base storage.
Definition Variable.h:113
TorchShapeRef batch_sizes() const
Batch shape.
Definition Variable.h:101
const LabeledMatrix * _derivative_storage
The derivative storage that this variable is viewing into.
Definition Variable.h:141
const LabeledTensor3D & second_derivative_storage() const
Definition Variable.cxx:95
const LabeledVector & value_storage() const
Accessors for storage.
Definition Variable.cxx:81
std::vector< VariableName > _args
Names of the variables that this variable depends on.
Definition Variable.h:126
TorchSize batch_dim() const
Batch dimension.
Definition Variable.h:107
void add_arg(const VariableBase &arg)
Add an argument.
Definition Variable.h:74
virtual TorchShapeRef sizes() const =0
Total shape.
TorchShape _batch_sizes
Batch shape of this variable.
Definition Variable.h:123
virtual const BatchTensor tensor() const =0
Variable value of the logical shape.
void setup_views(const LabeledVector *value, const LabeledMatrix *deriv=nullptr, const LabeledTensor3D *secderiv=nullptr)
Setup the variable's views into blocks of the storage.
Definition Variable.cxx:36
const LabeledTensor3D * _second_derivative_storage
The second derivative storage that this variable is viewing into.
Definition Variable.h:144
const LabeledMatrix & derivative_storage() const
Definition Variable.cxx:88
const VariableName & name() const
Name of this variable.
Definition Variable.h:98
TorchSize base_dim() const
Base dimension.
Definition Variable.h:110
const VariableName _name
Name of the variable.
Definition Variable.h:120
const BatchTensor & raw_value() const
Raw flattened variable value.
Definition Variable.h:92
void clear_args()
Clear arguments.
Definition Variable.h:77
virtual void cache(TorchShapeRef batch_shape)
Cache the variable's batch shape.
Definition Variable.cxx:30
virtual void requires_grad_(bool req=true)=0
Set requires_grad for the underlying storage.
virtual void reinit_views(bool out, bool dout_din, bool d2out_din2)
Reinitialize variable views.
Definition Variable.cxx:58
std::map< VariableName, std::map< VariableName, BatchTensor > > _d2value_d
The second derivative of this variable w.r.t. arguments.
Definition Variable.h:135
Concrete definition of a variable.
Definition Variable.h:153
const TorchShape _base_sizes
Base shape of this variable.
Definition Variable.h:232
virtual void requires_grad_(bool req=true) override
Set requires_grad for the underlying storage.
Definition Variable.h:176
virtual const BatchTensor tensor() const override
Variable value of the logical shape.
Definition Variable.h:210
void operator=(const BatchTensor &val)
Set the raw value to store val.
Definition Variable.h:200
virtual void reinit_views(bool out, bool dout_din, bool d2out_din2) override
Reinitialize variable views.
Definition Variable.h:169
void operator=(const Variable< T > &)
Suppressed assignment operator to prevent accidental dereferencing.
Definition Variable.h:190
virtual void cache(TorchShapeRef batch_shape) override
Set the batch shape and base shape according to val.
Definition Variable.h:218
Variable(const VariableName &name_in, TorchShapeRef base_shape)
Definition Variable.h:163
T operator-() const
Negation.
Definition Variable.h:213
virtual TorchShapeRef sizes() const override
Total shape.
Definition Variable.h:180
const T & value() const
Variable value of the logical shape.
Definition Variable.h:207
virtual TorchShapeRef base_sizes() const override
Base shape.
Definition Variable.h:178
Variable(const VariableName &name_in)
Definition Variable.h:156
T _value
Variable value of the logical shape.
Definition Variable.h:238
TorchShape _sizes
Shape of this variable.
Definition Variable.h:235
Variable(const Variable< T > &)
Suppressed constructor to prevent accidental dereferencing.
Definition Variable.h:183
TorchSize storage_size(TorchShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:32
TorchShape add_shapes(S &&... shape)
Definition utils.h:294
Definition CrossRef.cxx:32
std::vector< TorchSize > TorchShape
Definition types.h:36
torch::IntArrayRef TorchShapeRef
Definition types.h:37