25 #include "neml2/tensors/Variable.h"
105 "Error retrieving first derivative: ",
107 " does not depend on ",
117 "Error retrieving second derivative: ",
119 " does not depend on ",
122 "Error retrieving second derivative: d(",
126 ") does not depend on ",
135 _value.index_put_({torch::indexing::Slice()},
136 val.batch_expand_as(_value).base_reshape(_value.base_sizes()));
TorchShapeRef base_sizes() const
Return the base size.
Definition BatchTensorBase.cxx:163
Definition BatchTensor.h:32
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
Definition Variable.h:242
void operator=(const BatchTensor &val)
Definition Variable.cxx:133
A single-batched, logically 2D LabeledTensor.
Definition LabeledMatrix.h:38
A single-batched, logically 3D LabeledTensor.
Definition LabeledTensor3D.h:38
A single-batched, logically 1D LabeledTensor.
Definition LabeledVector.h:38
BatchTensor _raw_value
The raw (flattened) variable value.
Definition Variable.h:129
const LabeledVector * _value_storage
The value storage that this variable is viewing into.
Definition Variable.h:138
std::map< VariableName, BatchTensor > _dvalue_d
The derivative of this variable w.r.t. arguments.
Definition Variable.h:132
Derivative d(const VariableBase &x)
Create a wrapper representing the derivative dy/dx.
Definition Variable.cxx:102
const std::vector< VariableName > & args() const
Arguments.
Definition Variable.h:71
const LabeledMatrix * _derivative_storage
The derivative storage that this variable is viewing into.
Definition Variable.h:141
const LabeledTensor3D & second_derivative_storage() const
Definition Variable.cxx:95
const LabeledVector & value_storage() const
Accessors for storage.
Definition Variable.cxx:81
TorchShape _batch_sizes
Batch shape of this variable.
Definition Variable.h:123
void setup_views(const LabeledVector *value, const LabeledMatrix *deriv=nullptr, const LabeledTensor3D *secderiv=nullptr)
Setup the variable's views into blocks of the storage.
Definition Variable.cxx:36
const LabeledTensor3D * _second_derivative_storage
The second derivative storage that this variable is viewing into.
Definition Variable.h:144
const LabeledMatrix & derivative_storage() const
Definition Variable.cxx:88
const VariableName & name() const
Name of this variable.
Definition Variable.h:98
virtual void cache(TorchShapeRef batch_shape)
Cache the variable's batch shape.
Definition Variable.cxx:30
virtual void reinit_views(bool out, bool dout_din, bool d2out_din2)
Reinitialize variable views.
Definition Variable.cxx:58
std::map< VariableName, std::map< VariableName, BatchTensor > > _d2value_d
The second derivative of this variable w.r.t. arguments.
Definition Variable.h:135
Definition CrossRef.cxx:32
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:85
torch::IntArrayRef TorchShapeRef
Definition types.h:37
void neml_assert(bool assertion, Args &&... args)
Definition error.h:73