25#include "neml2/tensors/LabeledMatrix.h"
26#include "neml2/tensors/LabeledVector.h"
27#include "neml2/misc/math.h"
29using namespace torch::indexing;
36 const torch::TensorOptions & options)
68 "LabeledMatrix batch sizes are not the same");
76LabeledMatrix::inverse()
const
79 "Can only invert square derivatives");
81 return LabeledMatrix(math::linalg::inv(tensor()), {&axis(1), &axis(0)});
BatchTensor base_index(const TorchSlice &indices) const
Return an index sliced on the base dimensions.
Definition BatchTensorBase.cxx:193
static BatchTensor identity(TorchSize n, const torch::TensorOptions &options=default_tensor_options())
Unbatched identity tensor.
Definition BatchTensor.cxx:91
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
A single-batched, logically 2D LabeledTensor.
Definition LabeledMatrix.h:38
static LabeledMatrix identity(TorchShapeRef batch_size, const LabeledAxis &axis, const torch::TensorOptions &options=default_tensor_options())
Create a labeled identity tensor.
Definition LabeledMatrix.cxx:34
void accumulate(const LabeledMatrix &other, bool recursive=true)
Definition LabeledMatrix.cxx:43
const LabeledAxis & axis(TorchSize i=0) const
Get a specific labeled axis.
Definition LabeledTensor.h:130
torch::TensorOptions options() const
Get the tensor options.
Definition LabeledTensor.h:112
BatchTensor _tensor
The tensor.
Definition LabeledTensor.h:215
Definition CrossRef.cxx:32
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:85
torch::IntArrayRef TorchShapeRef
Definition types.h:37