27#include "neml2/misc/types.h"
28#include "neml2/tensors/LabeledAxis.h"
29#include "neml2/tensors/BatchTensor.h"
42template <
class Derived, TorchSize D>
52 const std::vector<const LabeledAxis *> &
axes);
69 operator torch::Tensor()
const;
74 const std::vector<const LabeledAxis *> &
axes,
83 const std::vector<const LabeledAxis *> &
axes,
127 const std::vector<const LabeledAxis *> &
axes()
const {
return _axes; }
133 template <
typename...
S>
140 template <
typename...
S>
145 template <
typename...
S>
152 template <
typename...
S>
168 template <
typename T>
175 template <
typename T,
typename...
S>
183 template <
typename T,
typename...
S>
186 return T(((*
this)(
names...))
188 this->batch_dim() +
sizeof...(names));
192 template <
typename T,
typename...
S>
195 (*this)(
names...).index_put_(
196 {torch::indexing::None},
201 template <
typename T,
typename...
S>
219 std::vector<const LabeledAxis *>
_axes;
222 template <std::size_t...
I,
typename...
S>
223 TorchSlice slice_indices_impl(std::index_sequence<I...>,
S &&...
names)
const;
225 template <std::size_t...
I,
typename...
S>
226 TorchShape storage_size_impl(std::index_sequence<I...>,
S &&...
names)
const;
228 template <std::size_t...
I,
typename...
S>
229 Derived block_impl(std::index_sequence<I...>,
S &&...
names)
const;
232template <
class Derived, TorchSize D>
237 _tensor.copy_(
other);
240template <
class Derived, TorchSize D>
241template <
typename...
S>
245 static_assert(
sizeof...(names) ==
D,
"Wrong labaled dimesion in LabeledTensor::slice_indices");
246 return slice_indices_impl(std::make_index_sequence<
sizeof...(
names)>(),
247 std::forward<S>(
names)...);
250template <
class Derived, TorchSize D>
251template <std::size_t...
I,
typename...
S>
255 return {_axes[
I]->indices(
names)...};
258template <
class Derived, TorchSize D>
259template <
typename... S>
263 static_assert(
sizeof...(names) ==
D,
"Wrong labaled dimesion in LabeledTensor::storage_size");
264 return storage_size_impl(std::make_index_sequence<D>(), std::forward<S>(
names)...);
267template <
class Derived, TorchSize D>
268template <std::size_t...
I,
typename...
S>
275template <
class Derived, TorchSize D>
276template <
typename... S>
280 static_assert(
sizeof...(names) ==
D,
"Wrong labeled dimension in LabeledTensor::operator()");
281 return base_index(slice_indices(
names...));
284template <
class Derived, TorchSize D>
285template <
typename...
S>
289 return block_impl(std::make_index_sequence<
sizeof...(
names)>(), std::forward<S>(
names)...);
292template <
class Derived, TorchSize D>
293template <std::size_t...
I,
typename...
S>
298 std::vector<const LabeledAxis *>
new_axes = {&_axes[
I]->subaxis(
names)...};
Definition BatchTensor.h:32
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
The primary data structure in NEML2 for working with labeled tensor views.
Definition LabeledTensor.h:44
Derived clone(torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
Clone this LabeledTensor.
Definition LabeledTensor.cxx:131
const LabeledAxis & axis(TorchSize i=0) const
Get a specific labeled axis.
Definition LabeledTensor.h:130
void operator=(const Derived &other)
Assignment operator.
Definition LabeledTensor.cxx:67
TorchSlice slice_indices(S &&... names) const
How to slice the tensor given the names on each axis.
Definition LabeledTensor.h:243
LabeledTensor()=default
Default constructor.
void zero_()
Zero out this tensor.
Definition LabeledTensor.cxx:152
Derived slice(TorchSize i, const std::string &name) const
Slice the tensor on the given dimension by a single variable or sub-axis.
Definition LabeledTensor.cxx:194
BatchTensor operator()(S &&... names) const
Definition LabeledTensor.h:278
TorchShapeRef batch_sizes() const
Return the batch size.
Definition LabeledTensor.cxx:173
Derived detach() const
Return a copy without gradient graphs.
Definition LabeledTensor.cxx:138
static Derived zeros(TorchShapeRef batch_shape, const std::vector< const LabeledAxis * > &axes, const torch::TensorOptions &options=default_tensor_options())
Setup new storage with zeros.
Definition LabeledTensor.cxx:109
Derived operator-() const
Negation.
Definition LabeledTensor.cxx:235
TorchSize batch_dim() const
Return the number of batch dimensions.
Definition LabeledTensor.cxx:159
Derived batch_index(TorchSlice indices) const
Get a batch.
Definition LabeledTensor.cxx:207
Derived block(S &&... names) const
Get the sub-block labeled by the given sub-axis names.
Definition LabeledTensor.h:287
torch::TensorOptions options() const
Get the tensor options.
Definition LabeledTensor.h:112
TorchShapeRef base_sizes() const
Return the base size.
Definition LabeledTensor.cxx:180
BatchTensor base_index(TorchSlice indices) const
Return an index sliced on the base dimensions.
Definition LabeledTensor.cxx:221
Derived to(const torch::TensorOptions &options) const
Change tensor options.
Definition LabeledTensor.cxx:242
void base_index_put(TorchSlice indices, const torch::Tensor &other)
Set an index sliced on the base dimensions to a value.
Definition LabeledTensor.cxx:228
std::vector< const LabeledAxis * > _axes
The labeled axes of this tensor.
Definition LabeledTensor.h:219
static Derived zeros_like(const Derived &other)
Setup new storage with zeros like another LabeledTensor.
Definition LabeledTensor.cxx:124
const BatchTensor & tensor() const
Definition LabeledTensor.h:107
BatchTensor _tensor
The tensor.
Definition LabeledTensor.h:215
void set(const BatchTensorBase< T > &value, S &&... names)
Set and interpret the input as an object.
Definition LabeledTensor.h:193
static Derived empty(TorchShapeRef batch_shape, const std::vector< const LabeledAxis * > &axes, const torch::TensorOptions &options=default_tensor_options())
Setup new empty storage.
Definition LabeledTensor.cxx:87
static Derived empty_like(const Derived &other)
Setup new empty storage like another LabeledTensor.
Definition LabeledTensor.cxx:102
TorchSize base_dim() const
Return the number of base dimensions.
Definition LabeledTensor.cxx:166
void detach_()
Detach from gradient graphs.
Definition LabeledTensor.cxx:145
TorchShapeRef storage_size() const
The shape of the entire LabeledTensor.
Definition LabeledTensor.cxx:187
variable_type< T >::type get_list(S &&... names) const
Get and interpret the view as a list of objects.
Definition LabeledTensor.h:184
void set_list(const BatchTensorBase< T > &value, S &&... names)
Set and interpret the input as a list of objects.
Definition LabeledTensor.h:202
void batch_index_put(TorchSlice indices, const torch::Tensor &other)
Set an index sliced on the batch dimensions to a value.
Definition LabeledTensor.cxx:214
const std::vector< const LabeledAxis * > & axes() const
Get all the labeled axes.
Definition LabeledTensor.h:127
BatchTensor & tensor()
Definition LabeledTensor.h:108
TorchShape storage_size(S &&... names) const
The shape of a sub-block specified by the names on each dimension.
Definition LabeledTensor.h:261
variable_type< T >::type get(S &&... names) const
Get and interpret the view as an object.
Definition LabeledTensor.h:176
void copy_(const T &other)
Copy the value from another tensor.
Definition LabeledTensor.h:235
TorchShape add_shapes(S &&... shape)
Definition utils.h:294
Definition CrossRef.cxx:32
const torch::TensorOptions default_tensor_options()
Definition types.cxx:30
int64_t TorchSize
Definition types.h:35
std::vector< TorchSize > TorchShape
Definition types.h:36
torch::IntArrayRef TorchShapeRef
Definition types.h:37
std::vector< at::indexing::TensorIndex > TorchSlice
Definition types.h:39
Template setup for appropriate variable types.
Definition LabeledTensor.h:170
T type
Definition LabeledTensor.h:171