NEML2 1.4.0
The primary data structure in NEML2 for working with labeled tensor views.
Each LabeledTensor consists of one Tensor and \(D\) LabeledAxis objects, one per base dimension. LabeledTensor<Derived, D> is templated on the derived type Derived and on the number of base dimensions \(D\). LabeledTensor handles the creation, modification, and access of labeled tensors.
Template Parameters
D | The number of base dimensions.
#include <LabeledTensor.h>
Public Member Functions
LabeledTensor () = default
Default constructor.
LabeledTensor (const torch::Tensor &tensor, const std::array<const LabeledAxis *, D> &axes)
Construct from a torch::Tensor and an array of LabeledAxis.
LabeledTensor (const Tensor &tensor, const std::array<const LabeledAxis *, D> &axes)
Construct from a Tensor and an array of LabeledAxis.
LabeledTensor (const Derived &other)
Copy constructor.
LabeledTensor<Derived, D> & | operator= (const Derived &other)
Assignment operator.
template<typename T, typename = std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T | reinterpret (indexing::TensorLabelsRef indices) const
Get a tensor by slicing on the base dimensions AND reinterpret it as a primitive tensor.
const std::array<const LabeledAxis *, D> & | axes () const
Get all the labeled axes.
const LabeledAxis & | axis (Size i = 0) const
Get the labeled axis of base dimension i.
operator Tensor () const
Convert to a Tensor.
operator torch::Tensor () const
Convert to a torch::Tensor.
const Tensor & | tensor () const
Tensor & | tensor ()
Access the underlying tensor.
Meta operations
Derived | clone (torch::MemoryFormat memory_format = torch::MemoryFormat::Contiguous) const
Clone this LabeledTensor.
Derived | detach () const
Return a copy detached from the gradient graph.
void | detach_ ()
Detach from the gradient graph in place.
Derived | to (const torch::TensorOptions &options) const
Return a copy with the given tensor options (for example, dtype or device).
void | copy_ (const torch::Tensor &other)
Copy the values of another tensor into this one, in place.
void | zero_ ()
Set all entries to zero.
bool | requires_grad () const
Get the requires_grad property.
void | requires_grad_ (bool req = true)
Set the requires_grad property.
Derived | operator- () const
Negation.
Tensor information
torch::TensorOptions | options () const
Tensor options.
torch::Dtype | scalar_type () const
Tensor data type.
torch::Device | device () const
Tensor device.
Size | dim () const
Number of tensor dimensions.
TensorShapeRef | sizes () const
Tensor shape.
Size | size (Size dim) const
Length of dimension dim.
bool | batched () const
Whether the tensor is batched.
Size | batch_dim () const
Return the number of batch dimensions.
constexpr Size | base_dim () const
Return the number of base dimensions.
TensorShapeRef | batch_sizes () const
Return the batch shape.
Size | batch_size (Size d) const
Return the length of batch dimension d.
TensorShapeRef | base_sizes () const
Return the base shape.
Size | base_size (Size d) const
Return the length of base dimension d.
Size | base_storage () const
Return the flattened storage size needed for the base dimensions alone.
Getter and setter
Derived | batch_index (indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the batch dimensions.
Tensor | base_index (indexing::TensorLabelsRef labels) const
Get a tensor by slicing on the base dimensions using axis labels.
void | batch_index_put_ (indexing::TensorIndicesRef indices, const torch::Tensor &other)
Set values by slicing on the batch dimensions.
void | base_index_put_ (indexing::TensorLabelsRef labels, const Tensor &other)
Set values by slicing on the base dimensions using axis labels.
Static Public Member Functions
static Derived | empty (TensorShapeRef batch_shape, const std::array<const LabeledAxis *, D> &axes, const torch::TensorOptions &options = default_tensor_options())
Set up new, uninitialized storage.
static Derived | zeros (TensorShapeRef batch_shape, const std::array<const LabeledAxis *, D> &axes, const torch::TensorOptions &options = default_tensor_options())
Set up new storage filled with zeros.
Protected Attributes
Tensor | _tensor
The underlying tensor.
std::array<const LabeledAxis *, D> | _axes
The labeled axes of this tensor.
LabeledTensor () [default]
Default constructor.
LabeledTensor (const torch::Tensor &tensor, const std::array<const LabeledAxis *, D> &axes)
Construct from a torch::Tensor and an array of LabeledAxis.
LabeledTensor (const Tensor &tensor, const std::array<const LabeledAxis *, D> &axes)
Construct from a Tensor and an array of LabeledAxis.
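const std::array<const LabeledAxis *, D> & axes () const
Get all the labeled axes.
const LabeledAxis & axis (Size i = 0) const
Get the labeled axis of base dimension i.
constexpr Size base_dim () const
Return the number of base dimensions.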
Tensor base_index (indexing::TensorLabelsRef labels) const
Get a tensor by slicing on the base dimensions using axis labels.
void base_index_put_ (indexing::TensorLabelsRef labels, const Tensor &other)
Set values by slicing on the base dimensions using axis labels.
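The label-based getter and setter above can be illustrated with a small standalone toy. This is a minimal sketch, not NEML2 code: ToyAxis and ToyLabeledTensor are hypothetical names, and we assume, as a simplification, that each label owns a contiguous [begin, end) range of one flattened base axis (D = 1) with a single batch dimension stored row-major.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy, NOT the NEML2 LabeledAxis: each label owns a
// contiguous [begin, end) range of the flattened base axis.
struct ToyAxis
{
  std::map<std::string, std::pair<std::int64_t, std::int64_t>> slices;

  // Length of the flattened base axis (largest end index of any label)
  std::int64_t storage_size() const
  {
    std::int64_t n = 0;
    for (const auto & entry : slices)
      n = std::max(n, entry.second.second);
    return n;
  }
};

// Hypothetical toy tensor: one batch dimension, one labeled base axis,
// row-major storage of shape (nbatch, axis->storage_size()), zero-filled.
struct ToyLabeledTensor
{
  const ToyAxis * axis;
  std::int64_t nbatch;
  std::vector<double> data;

  ToyLabeledTensor(std::int64_t nbatch_, const ToyAxis & axis_)
    : axis(&axis_),
      nbatch(nbatch_),
      data(static_cast<std::size_t>(nbatch_ * axis_.storage_size()), 0.0)
  {
  }

  // Analogue of base_index: gather the labeled slice of every batch entry
  std::vector<double> base_index(const std::string & label) const
  {
    const auto [b, e] = axis->slices.at(label);
    const std::int64_t n = axis->storage_size();
    std::vector<double> out;
    for (std::int64_t i = 0; i < nbatch; ++i)
      for (std::int64_t j = b; j < e; ++j)
        out.push_back(data[static_cast<std::size_t>(i * n + j)]);
    return out;
  }

  // Analogue of base_index_put_: scatter values of shape (nbatch, e - b)
  void base_index_put_(const std::string & label, const std::vector<double> & other)
  {
    const auto [b, e] = axis->slices.at(label);
    const std::int64_t n = axis->storage_size();
    assert(static_cast<std::int64_t>(other.size()) == nbatch * (e - b));
    std::size_t k = 0;
    for (std::int64_t i = 0; i < nbatch; ++i)
      for (std::int64_t j = b; j < e; ++j)
        data[static_cast<std::size_t>(i * n + j)] = other[k++];
  }
};
```

For example, with labels "stress" at [0, 6) and "temperature" at [6, 7), writing {300, 310} to "temperature" on a two-entry batch touches only the seventh slot of each row; the batch layout is never disturbed.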
Size base_size (Size d) const
Return the length of base dimension d.
TensorShapeRef base_sizes () const
Return the base shape.
Size base_storage () const
Return the flattened storage size needed for the base dimensions alone.
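The shape queries above (dim, batch_dim, base_dim, batch_sizes, base_sizes, base_storage, batched) can be related to one another with the standalone sketch below. ShapeInfo is a hypothetical helper, not part of NEML2, and it rests on two assumptions: the \(D\) base dimensions are the trailing dimensions of the tensor (batch dimensions come first), and a tensor counts as batched when it has at least one batch dimension.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

using Shape = std::vector<std::int64_t>;

// Hypothetical helper, NOT part of NEML2. Assumes the D base dimensions are
// the trailing dimensions of the tensor and all leading ones are batch dims.
template <std::size_t D>
struct ShapeInfo
{
  Shape sizes; // full shape of the underlying tensor, batch dims first

  std::size_t dim() const { return sizes.size(); }
  static constexpr std::size_t base_dim() { return D; }

  std::size_t batch_dim() const
  {
    assert(sizes.size() >= D); // at least D base dimensions must exist
    return sizes.size() - D;
  }
  bool batched() const { return batch_dim() > 0; }

  Shape batch_sizes() const
  {
    const auto split = sizes.begin() + static_cast<std::ptrdiff_t>(batch_dim());
    return Shape(sizes.begin(), split);
  }
  Shape base_sizes() const
  {
    const auto split = sizes.begin() + static_cast<std::ptrdiff_t>(batch_dim());
    return Shape(split, sizes.end());
  }

  // Scalars needed to store one batch entry: the product of the base sizes
  std::int64_t base_storage() const
  {
    const Shape b = base_sizes();
    return std::accumulate(b.begin(), b.end(), std::int64_t{1}, std::multiplies<>());
  }
};
```

Under these assumptions, a tensor of shape (5, 2, 6) with \(D = 1\) has batch shape (5, 2), base shape (6), and a base storage of 6, while a tensor of shape (3, 4) with \(D = 2\) is unbatched with a base storage of 12.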
Derived batch_index (indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the batch dimensions.
void batch_index_put_ (indexing::TensorIndicesRef indices, const torch::Tensor &other)
Set values by slicing on the batch dimensions.
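Because batch_index returns Derived, slicing on the batch dimensions keeps the base layout, and therefore the labeled axes, intact. The sketch below illustrates this with a hypothetical ToyBatched type (not NEML2 code) that has one batch dimension followed by one flattened base axis in row-major storage, and that supports only a [begin, end) slice rather than general TensorIndicesRef indices.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical toy, NOT NEML2 code: one batch dimension of length nbatch
// followed by a flattened base axis of length nbase, stored row-major.
// Batch slicing selects whole rows, so the base layout is never changed.
struct ToyBatched
{
  std::int64_t nbatch;
  std::int64_t nbase;
  std::vector<double> data;

  // Analogue of batch_index for a [begin, end) slice of the batch dimension
  ToyBatched batch_index(std::int64_t begin, std::int64_t end) const
  {
    assert(0 <= begin && begin <= end && end <= nbatch);
    const auto first = data.begin() + static_cast<std::ptrdiff_t>(begin * nbase);
    const auto last = data.begin() + static_cast<std::ptrdiff_t>(end * nbase);
    return ToyBatched{end - begin, nbase, std::vector<double>(first, last)};
  }

  // Analogue of batch_index_put_: overwrite batch entries [begin, end)
  void batch_index_put_(std::int64_t begin, std::int64_t end, const std::vector<double> & other)
  {
    assert(0 <= begin && begin <= end && end <= nbatch);
    assert(static_cast<std::int64_t>(other.size()) == (end - begin) * nbase);
    std::copy(other.begin(), other.end(),
              data.begin() + static_cast<std::ptrdiff_t>(begin * nbase));
  }
};
```

Slicing entries 1 and 2 out of a three-entry batch with base length 2 yields a two-entry result whose base length is still 2.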
Size batch_size (Size d) const
Return the length of batch dimension d.
TensorShapeRef batch_sizes () const
Return the batch shape.
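Derived clone (torch::MemoryFormat memory_format = torch::MemoryFormat::Contiguous) const
Clone this LabeledTensor.
static Derived empty (TensorShapeRef batch_shape, const std::array<const LabeledAxis *, D> &axes, const torch::TensorOptions &options = default_tensor_options()) [static]
Set up new, uninitialized storage.
LabeledTensor<Derived, D> & operator= (const Derived &other)
Assignment operator.
template<typename T, typename = std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T reinterpret (indexing::TensorLabelsRef indices) const
Get a tensor by slicing on the base dimensions AND reinterpret it as a primitive tensor.
void requires_grad_ (bool req = true)
Set the requires_grad property.
TensorShapeRef sizes () const
Tensor shape.
Derived to (const torch::TensorOptions &options) const
Return a copy with the given tensor options (for example, dtype or device).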
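static Derived zeros (TensorShapeRef batch_shape, const std::array<const LabeledAxis *, D> &axes, const torch::TensorOptions &options = default_tensor_options()) [static]
Set up new storage filled with zeros.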
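The static factories take only a batch shape and the labeled axes, so the full shape they allocate has to be derived from those two inputs. The sketch below shows one plausible derivation. labeled_full_shape and zeros_storage are hypothetical helpers, not NEML2 functions, and they assume that each labeled axis contributes one base dimension whose length equals that axis's storage size, appended after the batch shape. empty presumably allocates the same shape without initializing it.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<std::int64_t>;

// Hypothetical helper, NOT the NEML2 implementation: assumes the allocated
// shape is the batch shape followed by one base dimension per labeled axis,
// each as long as that axis's storage size.
Shape labeled_full_shape(const Shape & batch_shape, const Shape & axis_storage_sizes)
{
  Shape full(batch_shape);
  full.insert(full.end(), axis_storage_sizes.begin(), axis_storage_sizes.end());
  return full;
}

// Zero-filled flat buffer with one slot per element of full_shape,
// mirroring what a zeros() factory would allocate
std::vector<double> zeros_storage(const Shape & full_shape)
{
  std::int64_t n = 1;
  for (const std::int64_t s : full_shape)
  {
    assert(s >= 0); // dimension lengths must be non-negative
    n *= s;         // total number of elements
  }
  return std::vector<double>(static_cast<std::size_t>(n), 0.0);
}
```

For a batch shape (4) and a single axis with storage size 7, this gives a (4, 7) shape backed by 28 zeros; for \(D = 2\) with a batch shape (2, 3) and two axes of storage size 5, it gives (2, 3, 5, 5).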
std::array<const LabeledAxis *, D> _axes [protected]
The labeled axes of this tensor.