25#include "neml2/tensors/LabeledTensor3D.h"
26#include "neml2/tensors/LabeledMatrix.h"
28using namespace torch::indexing;
45 neml_assert_dbg(axis(1) ==
other.axis(1),
"Can only accumulate 3D tensors with conformal y axes");
46 neml_assert_dbg(axis(2) ==
other.axis(2),
"Can only accumulate 3D tensors with conformal z axes");
BatchTensor base_index(const TorchSlice &indices) const
Return an index sliced on the base dimensions.
Definition BatchTensorBase.cxx:193
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
A single-batched, logically 2D LabeledTensor.
Definition LabeledMatrix.h:38
A single-batched, logically 3D LabeledTensor.
Definition LabeledTensor3D.h:38
void accumulate(const LabeledTensor3D &other, bool recursive=true)
Definition LabeledTensor3D.cxx:33
const LabeledAxis & axis(TorchSize i=0) const
Get a specific labeled axis.
Definition LabeledTensor.h:130
BatchTensor _tensor
The tensor.
Definition LabeledTensor.h:215
Definition CrossRef.cxx:32
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:85
TorchSize broadcast_batch_dim(T &&...)
The batch dimension after broadcasting.