27#include "neml2/tensors/LabeledAxisAccessor.h"
28#include "neml2/tensors/tensors.h"
29#include "neml2/tensors/LabeledVector.h"
30#include "neml2/tensors/LabeledMatrix.h"
31#include "neml2/tensors/LabeledTensor3D.h"
137 std::map<VariableName, std::map<VariableName, Tensor>>
_d2value_d;
163 template <
typename T2 = T,
typename =
typename std::enable_if_t<!std::is_same_v<Tensor, T2>>>
173 template <
typename T2 = T,
typename =
typename std::enable_if_t<std::is_same_v<Tensor, T2>>>
212 [[
deprecated(
"Variable<T> must be assigned to references -- missing &")]]
void
225 _value.index_put_({torch::indexing::Slice()},
249 template <
typename T2 = T,
typename =
typename std::enable_if_t<!std::is_same_v<T2, Tensor>>>
297#define FWD_VARIABLE_BINARY_OP(op) \
298 template <typename T1, \
300 typename = typename std::enable_if_t<std::is_base_of_v<VariableBase, T1> || \
301 std::is_base_of_v<VariableBase, T2>>> \
302 auto op(const T1 & a, const T2 & b) \
304 if constexpr (std::is_base_of_v<VariableBase, T1> && std::is_base_of_v<VariableBase, T2>) \
305 return op(a.value(), b.value()); \
307 if constexpr (std::is_base_of_v<VariableBase, T1> && !std::is_base_of_v<VariableBase, T2>) \
308 return op(a.value(), b); \
310 if constexpr (!std::is_base_of_v<VariableBase, T1> && std::is_base_of_v<VariableBase, T2>) \
311 return op(a, b.value()); \
314FWD_VARIABLE_BINARY_OP(
operator+);
315FWD_VARIABLE_BINARY_OP(
operator-);
316FWD_VARIABLE_BINARY_OP(
operator*);
317FWD_VARIABLE_BINARY_OP(
operator/);
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:56
Definition Variable.h:270
Derivative & operator=(const Tensor &val)
Definition Variable.cxx:140
Derivative(Tensor &val)
Definition Variable.h:272
const Tensor & value() const
Definition Variable.h:277
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:47
A single-batched, logically 2D LabeledTensor.
Definition LabeledMatrix.h:38
A single-batched, logically 3D LabeledTensor.
Definition LabeledTensor3D.h:38
A single-batched, logically 1D LabeledTensor.
Definition LabeledVector.h:38
The base class for all constitutive models.
Definition Model.h:55
Derivative d(const VariableBase &x)
Create a wrapper representing the derivative dy/dx.
Definition Variable.cxx:91
virtual ~VariableBase()=default
bool is_old_force() const
Definition Variable.h:109
const Model * _owner
The model which declared this variable.
Definition Variable.h:125
bool is_parameter() const
Definition Variable.h:111
bool is_state() const
Definition Variable.h:106
const Model & owner() const
The owner of this variable.
Definition Variable.h:78
const bool _is_residual
Definition Variable.h:148
const bool _is_state
Definition Variable.h:144
Size base_storage() const
Base storage.
Definition Variable.h:96
const VariableBase * src() const
The source variable.
Definition Variable.h:81
TensorShape _batch_sizes
Batch shape of this variable.
Definition Variable.h:128
Size batch_dim() const
Batch dimension.
Definition Variable.h:90
const Tensor & raw_value() const
Raw flattened variable value.
Definition Variable.h:69
TensorShapeRef batch_sizes() const
Batch shape.
Definition Variable.h:84
Tensor _raw_value
The raw (flattened) variable value.
Definition Variable.h:131
std::map< VariableName, std::map< VariableName, Tensor > > _d2value_d
The second derivative of this variable w.r.t. arguments.
Definition Variable.h:137
VariableBase(const VariableName &name_in, const Model *owner)
Definition Variable.cxx:30
const bool _is_old_force
Definition Variable.h:147
const bool _is_other
Definition Variable.h:150
bool is_solve_dependent() const
Definition Variable.h:113
const bool _is_force
Definition Variable.h:146
bool is_residual() const
Definition Variable.h:110
virtual void setup_views(const LabeledVector *value, const LabeledMatrix *deriv=nullptr, const LabeledTensor3D *secderiv=nullptr)
Setup the variable's views into blocks of the storage.
Definition Variable.cxx:53
virtual TensorShapeRef sizes() const =0
Total shape.
const bool _is_solve_dependent
Definition Variable.h:151
const VariableName & name() const
Name of this variable.
Definition Variable.h:75
const VariableBase * _src
The source variable this variable follows.
Definition Variable.h:140
const bool _is_parameter
Definition Variable.h:149
virtual const Tensor tensor() const =0
Variable value of the logical shape.
bool is_dependent() const
Check if the derivative with respect to this variable should be evaluated.
Definition Variable.cxx:85
const bool _is_old_state
Definition Variable.h:145
const VariableName _name
Name of the variable.
Definition Variable.h:122
virtual TensorShapeRef base_sizes() const =0
Base shape.
Size base_dim() const
Base dimension.
Definition Variable.h:93
bool is_force() const
Definition Variable.h:108
bool is_other() const
Definition Variable.h:112
virtual TensorType type() const =0
Variable type.
std::map< VariableName, Tensor > _dvalue_d
The derivative of this variable w.r.t. arguments.
Definition Variable.h:134
bool is_old_state() const
Definition Variable.h:107
virtual void cache(TensorShapeRef batch_shape)
Cache the variable's batch shape.
Definition Variable.cxx:47
virtual void requires_grad_(bool req=true)=0
Set requires_grad for the underlying storage.
Concrete definition of a variable.
Definition Variable.h:161
void operator=(const Tensor &val)
Set the raw value to store val.
Definition Variable.h:223
virtual void setup_views(const VariableBase *other) override
Setup the variable's views following another variable.
Definition Variable.h:193
virtual void cache(TensorShapeRef batch_shape) override
Cache the batch shape given by batch_shape (the base shape is a const member and is not modified here).
Definition Variable.h:243
virtual void requires_grad_(bool req=true) override
Set requires_grad for the underlying storage.
Definition Variable.h:199
Variable(const VariableName &name_in, const Model *owner, TensorShapeRef base_shape, TensorType type=TensorType::kTensor)
Definition Variable.h:174
const TensorType _type
Variable tensor type.
Definition Variable.h:257
virtual void setup_views(const LabeledVector *value, const LabeledMatrix *deriv=nullptr, const LabeledTensor3D *secderiv=nullptr) override
Setup the variable's views into blocks of the storage.
Definition Variable.h:184
TensorShape _sizes
Shape of this variable.
Definition Variable.h:263
void operator=(const Variable< T > &)
Suppressed assignment operator to prevent accidental dereferencing.
Definition Variable.h:213
T operator-() const
Negation.
Definition Variable.h:238
const TensorShape _base_sizes
Base shape of this variable.
Definition Variable.h:260
virtual const Tensor tensor() const override
Variable value of the logical shape.
Definition Variable.h:233
virtual TensorShapeRef base_sizes() const override
Base shape.
Definition Variable.h:201
const T & value() const
Variable value of the logical shape.
Definition Variable.h:230
Variable(const VariableName &name_in, const Model *owner, TensorType type=TensorTypeEnum< T2 >::value)
Definition Variable.h:164
virtual TensorShapeRef sizes() const override
Total shape.
Definition Variable.h:203
T _value
Variable value of the logical shape.
Definition Variable.h:266
virtual TensorType type() const override
Variable type.
Definition Variable.h:235
Variable(const Variable< T > &)
Suppressed constructor to prevent accidental dereferencing.
Definition Variable.h:206
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:40
TensorShape add_shapes(S &&... shape)
Definition utils.h:298
Definition CrossRef.cxx:30
torch::SmallVector< Size > TensorShape
Definition types.h:34
LabeledAxisAccessor VariableName
Definition parser_utils.h:33
TensorType
Definition tensors.h:57
torch::IntArrayRef TensorShapeRef
Definition types.h:35