27#include "neml2/misc/types.h"
28#include "neml2/tensors/LabeledAxis.h"
29#include "neml2/tensors/Tensor.h"
42template <
class Derived, Size D>
64 operator torch::Tensor()
const;
70 const std::array<const LabeledAxis *, D> &
axes,
76 const std::array<const LabeledAxis *, D> &
axes,
112 torch::TensorOptions
options()
const;
116 torch::Device
device()
const;
155 template <
typename T,
typename = std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
159 const std::array<const LabeledAxis *, D> &
axes()
const {
return _axes; }
168 std::array<const LabeledAxis *, D>
_axes;
178template <
class Derived, Size D>
179template <
typename T,
typename>
183 return base_index(indices).base_reshape(T::const_base_sizes);
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:56
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
The primary data structure in NEML2 for working with labeled tensor views.
Definition LabeledTensor.h:44
Derived clone(torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
Definition LabeledTensor.cxx:127
LabeledTensor()=default
Default constructor.
void zero_()
Set all entries to zero.
Definition LabeledTensor.cxx:162
bool requires_grad() const
Get the requires_grad property.
Definition LabeledTensor.cxx:169
void requires_grad_(bool req=true)
Set the requires_grad property.
Definition LabeledTensor.cxx:176
Size batch_size(Size d) const
Return the length of some batch axis.
Definition LabeledTensor.cxx:253
Derived batch_index(indexing::TensorIndicesRef indices) const
Definition LabeledTensor.cxx:281
const Tensor & tensor() const
Definition LabeledTensor.h:81
void base_index_put_(indexing::TensorLabelsRef labels, const Tensor &other)
Set values by slicing on the base dimensions.
Definition LabeledTensor.cxx:303
torch::Dtype scalar_type() const
Tensor options.
Definition LabeledTensor.cxx:197
Tensor base_index(indexing::TensorLabelsRef labels) const
Get a tensor by slicing on the base dimensions.
Definition LabeledTensor.cxx:288
Size base_storage() const
Return the flattened storage needed just for the base indices.
Definition LabeledTensor.cxx:274
bool batched() const
Whether the tensor is batched.
Definition LabeledTensor.cxx:232
Size batch_dim() const
Return the number of batch dimensions.
Definition LabeledTensor.cxx:239
Derived detach() const
Return a copy without gradient graphs.
Definition LabeledTensor.cxx:134
Size dim() const
Number of tensor dimensions.
Definition LabeledTensor.cxx:211
TensorShapeRef sizes() const
Tensor shape.
Definition LabeledTensor.cxx:218
T reinterpret(indexing::TensorLabelsRef indices) const
Get a tensor by slicing on the base dimensions AND reinterpret it as a primitive tensor.
Definition LabeledTensor.h:181
operator torch::Tensor() const
Definition LabeledTensor.cxx:90
Derived operator-() const
Negation.
Definition LabeledTensor.cxx:183
Size base_size(Size d) const
Return the length of some base axis.
Definition LabeledTensor.cxx:267
TensorShapeRef batch_sizes() const
Return the batch size.
Definition LabeledTensor.cxx:246
static Derived zeros(TensorShapeRef batch_shape, const std::array< const LabeledAxis *, D > &axes, const torch::TensorOptions &options=default_tensor_options())
Setup new storage with zeros.
Definition LabeledTensor.cxx:112
torch::Device device() const
Tensor options.
Definition LabeledTensor.cxx:204
torch::TensorOptions options() const
Definition LabeledTensor.cxx:190
Derived to(const torch::TensorOptions &options) const
Change tensor options.
Definition LabeledTensor.cxx:148
const LabeledAxis & axis(Size i=0) const
Get a specific labeled axis.
Definition LabeledTensor.h:161
static Derived empty(TensorShapeRef batch_shape, const std::array< const LabeledAxis *, D > &axes, const torch::TensorOptions &options=default_tensor_options())
Setup new empty storage.
Definition LabeledTensor.cxx:97
constexpr Size base_dim() const
Return the number of base dimensions.
Definition LabeledTensor.h:128
void copy_(const torch::Tensor &other)
Copy another tensor.
Definition LabeledTensor.cxx:155
Tensor _tensor
The tensor.
Definition LabeledTensor.h:165
void detach_()
Detach from gradient graphs.
Definition LabeledTensor.cxx:141
void batch_index_put_(indexing::TensorIndicesRef indices, const torch::Tensor &other)
Set values by slicing on the batch dimensions.
Definition LabeledTensor.cxx:295
Size size(Size dim) const
Tensor shape.
Definition LabeledTensor.cxx:225
TensorShapeRef base_sizes() const
Return the base size.
Definition LabeledTensor.cxx:260
Tensor & tensor()
Definition LabeledTensor.h:82
std::array< const LabeledAxis *, D > _axes
The labeled axes of this tensor.
Definition LabeledTensor.h:168
LabeledTensor< Derived, D > & operator=(const Derived &other)
Assignment operator.
Definition LabeledTensor.cxx:76
const std::array< const LabeledAxis *, D > & axes() const
Get all the labeled axes.
Definition LabeledTensor.h:159
c10::ArrayRef< LabeledAxisAccessor > TensorLabelsRef
Definition LabeledAxisAccessor.h:158
torch::SmallVector< TensorIndex > TensorIndices
Definition types.h:41
torch::ArrayRef< TensorIndex > TensorIndicesRef
Definition types.h:42
Definition CrossRef.cxx:30
torch::TensorOptions & default_tensor_options()
Definition types.cxx:30
torch::SmallVector< Size > TensorShape
Definition types.h:34
int64_t Size
Definition types.h:33
torch::IntArrayRef TensorShapeRef
Definition types.h:35