25#include "neml2/tensors/LabeledTensor.h"
27#include "neml2/tensors/LabeledVector.h"
28#include "neml2/tensors/LabeledMatrix.h"
29#include "neml2/tensors/LabeledTensor3D.h"
33template <
class Derived, TorchSize D>
36 const std::vector<const LabeledAxis *> & axes)
37 : _tensor(tensor, batch_dim),
46template <
class Derived, TorchSize D>
48 const std::vector<const LabeledAxis *> & axes)
55 neml_assert_dbg(base_sizes() == storage_size(),
"LabeledTensor does not have the right size");
58template <
class Derived, TorchSize D>
65template <
class Derived, TorchSize D>
73template <
class Derived, TorchSize D>
79template <
class Derived, TorchSize D>
85template <
class Derived, TorchSize D>
88 const std::vector<const LabeledAxis *> & axes,
89 const torch::TensorOptions & options)
92 s.reserve(axes.size());
93 std::transform(axes.begin(),
95 std::back_inserter(
s),
96 [](
const LabeledAxis * axis) { return axis->storage_size(); });
100template <
class Derived, TorchSize D>
107template <
class Derived, TorchSize D>
110 const std::vector<const LabeledAxis *> & axes,
111 const torch::TensorOptions & options)
114 s.reserve(axes.size());
115 std::transform(axes.begin(),
117 std::back_inserter(
s),
122template <
class Derived, TorchSize D>
129template <
class Derived, TorchSize D>
136template <
class Derived, TorchSize D>
140 return Derived(_tensor.detach(), _axes);
143template <
class Derived, TorchSize D>
150template <
class Derived, TorchSize D>
157template <
class Derived, TorchSize D>
161 return _tensor.batch_dim();
164template <
class Derived, TorchSize D>
171template <
class Derived, TorchSize D>
175 return _tensor.batch_sizes();
178template <
class Derived, TorchSize D>
182 return _tensor.base_sizes();
185template <
class Derived, TorchSize D>
192template <
class Derived, TorchSize D>
197 idx[
i] = _axes[
i]->indices(name);
205template <
class Derived, TorchSize D>
209 return Derived(_tensor.batch_index(indices), _axes);
212template <
class Derived, TorchSize D>
216 _tensor.batch_index_put(indices,
other);
219template <
class Derived, TorchSize D>
226template <
class Derived, TorchSize D>
230 _tensor.base_index_put(indices,
other);
233template <
class Derived, TorchSize D>
237 return Derived(-_tensor, _axes);
240template <
class Derived, TorchSize D>
244 return Derived(_tensor.to(options), _axes);
BatchTensor base_index(const TorchSlice &indices) const
Return an index sliced on the base dimensions.
Definition BatchTensorBase.cxx:193
static BatchTensor zeros_like(const BatchTensor &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition BatchTensorBase.cxx:59
static BatchTensor empty_like(const BatchTensor &other)
Empty tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition BatchTensorBase.cxx:52
Definition BatchTensor.h:32
static BatchTensor zeros(const TorchShapeRef &base_shape, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor filled with zeros given base shape.
Definition BatchTensor.cxx:45
static BatchTensor empty(const TorchShapeRef &base_shape, const torch::TensorOptions &options=default_tensor_options())
Unbatched empty tensor given base shape.
Definition BatchTensor.cxx:30
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
The primary data structure in NEML2 for working with labeled tensor views.
Definition LabeledTensor.h:44
Derived clone(torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
Clone this LabeledTensor.
Definition LabeledTensor.cxx:131
void operator=(const Derived &other)
Assignment operator.
Definition LabeledTensor.cxx:67
LabeledTensor()=default
Default constructor.
void zero_()
Zero out this tensor.
Definition LabeledTensor.cxx:152
Derived slice(TorchSize i, const std::string &name) const
Slice the tensor on the given dimension by a single variable or sub-axis.
Definition LabeledTensor.cxx:194
TorchShapeRef batch_sizes() const
Return the batch sizes (shape of the batch dimensions).
Definition LabeledTensor.cxx:173
Derived detach() const
Return a copy without gradient graphs.
Definition LabeledTensor.cxx:138
static Derived zeros(TorchShapeRef batch_shape, const std::vector< const LabeledAxis * > &axes, const torch::TensorOptions &options=default_tensor_options())
Set up new storage filled with zeros.
Definition LabeledTensor.cxx:109
Derived operator-() const
Negation.
Definition LabeledTensor.cxx:235
TorchSize batch_dim() const
Return the number of batch dimensions.
Definition LabeledTensor.cxx:159
Derived batch_index(TorchSlice indices) const
Get a batch.
Definition LabeledTensor.cxx:207
TorchShapeRef base_sizes() const
Return the base sizes (shape of the base dimensions).
Definition LabeledTensor.cxx:180
BatchTensor base_index(TorchSlice indices) const
Return an index sliced on the base dimensions.
Definition LabeledTensor.cxx:221
Derived to(const torch::TensorOptions &options) const
Change tensor options.
Definition LabeledTensor.cxx:242
void base_index_put(TorchSlice indices, const torch::Tensor &other)
Set an index sliced on the base dimensions to a value.
Definition LabeledTensor.cxx:228
static Derived zeros_like(const Derived &other)
Set up new storage filled with zeros, like another LabeledTensor.
Definition LabeledTensor.cxx:124
static Derived empty(TorchShapeRef batch_shape, const std::vector< const LabeledAxis * > &axes, const torch::TensorOptions &options=default_tensor_options())
Set up new empty storage.
Definition LabeledTensor.cxx:87
static Derived empty_like(const Derived &other)
Set up new empty storage like another LabeledTensor.
Definition LabeledTensor.cxx:102
TorchSize base_dim() const
Return the number of base dimensions.
Definition LabeledTensor.cxx:166
void detach_()
Detach from gradient graphs.
Definition LabeledTensor.cxx:145
TorchShapeRef storage_size() const
The shape of the entire LabeledTensor.
Definition LabeledTensor.cxx:187
void batch_index_put(TorchSlice indices, const torch::Tensor &other)
Set an index sliced on the batch dimensions to a value.
Definition LabeledTensor.cxx:214
const std::vector< const LabeledAxis * > & axes() const
Get all the labeled axes.
Definition LabeledTensor.h:127
Definition CrossRef.cxx:32
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:85
int64_t TorchSize
Definition types.h:35
std::vector< TorchSize > TorchShape
Definition types.h:36
torch::IntArrayRef TorchShapeRef
Definition types.h:37
std::vector< at::indexing::TensorIndex > TorchSlice
Definition types.h:39