#include "neml2/tensors/LabeledTensor.h"
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:56
A labeled axis used to associate the layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
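To make the idea of a labeled axis concrete, here is a minimal, self-contained C++ sketch. It assumes a labeled axis can be modeled as an ordered assignment of variable names to contiguous index ranges along one tensor dimension. The class SimpleLabeledAxis and its methods add, storage_size and slice are hypothetical illustrations, not the NEML2 LabeledAxis interface.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>

// Hypothetical, simplified stand-in for a labeled axis (NOT the NEML2 API):
// each variable name maps to a contiguous [begin, end) slice of one dimension.
class SimpleLabeledAxis
{
public:
  // Append a variable that occupies `size` consecutive entries along the axis.
  void add(const std::string & name, std::size_t size)
  {
    assert(_slices.count(name) == 0 && "variable names must be unique");
    _slices[name] = std::make_pair(_storage_size, _storage_size + size);
    _storage_size += size;
  }

  // Total number of entries along the axis.
  std::size_t storage_size() const { return _storage_size; }

  // Index range [begin, end) reserved for a named variable.
  std::pair<std::size_t, std::size_t> slice(const std::string & name) const
  {
    return _slices.at(name);
  }

private:
  std::map<std::string, std::pair<std::size_t, std::size_t>> _slices;
  std::size_t _storage_size = 0;
};
```

For example, adding a 6-entry variable followed by a 1-entry variable yields an axis of storage size 7, with the second variable occupying index 6. The sizes here are illustrative only.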
A single-batched, logically 2D LabeledTensor.
Definition LabeledMatrix.h:38
LabeledMatrix chain(const LabeledMatrix &other) const
Chain rule product of two derivatives.
Definition LabeledMatrix.cxx:49
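The chain rule product can be illustrated without any tensor library. The sketch below assumes each derivative is stored as a dense matrix whose rows index the output variable and whose columns index the input variable, so that dy/dz = (dy/dx)(dx/dz) is an ordinary matrix product. The function name chain and the Matrix alias are illustrative only; the actual LabeledMatrix::chain presumably also carries batch dimensions and axis labels, which this sketch ignores.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Dense row-major matrix used as a stand-in for one (unbatched) derivative.
using Matrix = std::vector<std::vector<double>>;

// Chain rule for derivatives stored as matrices:
//   dy/dz = (dy/dx) * (dx/dz)
// `dy_dx` is m x n (rows: y, columns: x); `dx_dz` is n x p (rows: x, columns: z).
Matrix chain(const Matrix & dy_dx, const Matrix & dx_dz)
{
  const std::size_t m = dy_dx.size();
  const std::size_t n = dx_dz.size();
  const std::size_t p = n == 0 ? 0 : dx_dz[0].size();
  Matrix dy_dz(m, std::vector<double>(p, 0.0));
  for (std::size_t i = 0; i < m; ++i)
  {
    // The shared (inner) variable x must have the same length on both sides.
    assert(dy_dx[i].size() == n);
    for (std::size_t k = 0; k < n; ++k)
      for (std::size_t j = 0; j < p; ++j)
        dy_dz[i][j] += dy_dx[i][k] * dx_dz[k][j];
  }
  return dy_dz;
}
```

With dy/dx = [[1, 2], [3, 4]] and dx/dz = [[5], [6]], chain returns [[17], [39]].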
static LabeledMatrix identity(TensorShapeRef batch_size, const LabeledAxis &axis, const torch::TensorOptions &options=default_tensor_options())
Create a labeled identity tensor.
Definition LabeledMatrix.cxx:32
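A labeled identity is presumably an identity matrix whose rows and columns both follow the given axis, repeated for every batch entry. The sketch below assumes exactly that and flattens the batch shape into a single count for simplicity; batched_identity is a hypothetical helper, not part of NEML2, and n stands in for the storage size of the axis.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Hypothetical sketch: build `batch` copies of an n x n identity matrix,
// where n plays the role of the labeled axis' storage size.
std::vector<Matrix> batched_identity(std::size_t batch, std::size_t n)
{
  Matrix eye(n, std::vector<double>(n, 0.0));
  for (std::size_t i = 0; i < n; ++i)
    eye[i][i] = 1.0;
  return std::vector<Matrix>(batch, eye);
}
```

For instance, batched_identity(3, 2) returns three 2-by-2 identity matrices.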
void fill(const LabeledMatrix &other, bool recursive=true)
Definition LabeledMatrix.cxx:40
The primary data structure in NEML2 for working with labeled tensor views.
Definition LabeledTensor.h:44
LabeledTensor()=default
Default constructor.
Size batch_size(Size d) const
Return the length of the batch axis with index d.
Definition LabeledTensor.cxx:253
torch::TensorOptions options() const
Definition LabeledTensor.cxx:190
const LabeledAxis & axis(Size i=0) const
Get the labeled axis with index i (the first axis by default).
Definition LabeledTensor.h:161
Definition CrossRef.cxx:30
torch::TensorOptions & default_tensor_options()
Definition types.cxx:30
torch::IntArrayRef TensorShapeRef
Definition types.h:35