NEML2 1.4.0
|
The primary data structure in NEML2 for working with labeled tensor views. More...
The primary data structure in NEML2 for working with labeled tensor views.
Each LabeledTensor consists of one BatchTensor and one or more LabeledAxis objects. The LabeledTensor<D>
is templated on the number of base dimensions \(D\). LabeledTensor handles the creation, modification, and accessing of labeled tensors.
D | The number of base dimensions |
#include <LabeledTensor.h>
Classes | |
struct | variable_type |
Template setup for appropriate variable types. More... | |
Public Member Functions | |
LabeledTensor ()=default | |
Default constructor. | |
LabeledTensor (const torch::Tensor &tensor, TorchSize batch_dim, const std::vector< const LabeledAxis * > &axes) | |
Construct from a Tensor with batch dim and vector of LabeledAxis | |
LabeledTensor (const BatchTensor &tensor, const std::vector< const LabeledAxis * > &axes) | |
Construct from a BatchTensor with vector of LabeledAxis | |
LabeledTensor (const Derived &other) | |
Copy constructor. | |
void | operator= (const Derived &other) |
Assignment operator. | |
operator BatchTensor () const | |
A potentially dangerous implicit conversion. | |
operator torch::Tensor () const | |
A potentially dangerous implicit conversion. | |
Derived | clone (torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const |
Clone this LabeledTensor. | |
template<typename T > | |
void | copy_ (const T &other) |
Copy the value from another tensor. | |
Derived | detach () const |
Return a copy without gradient graphs. | |
void | detach_ () |
Detach from gradient graphs. | |
void | zero_ () |
Zero out this tensor. | |
torch::TensorOptions | options () const |
Get the tensor options. | |
TorchSize | batch_dim () const |
Return the number of batch dimensions. | |
TorchSize | base_dim () const |
Return the number of base dimensions. | |
TorchShapeRef | batch_sizes () const |
Return the batch size. | |
TorchShapeRef | base_sizes () const |
Return the base size. | |
const std::vector< const LabeledAxis * > & | axes () const |
Get all the labeled axes. | |
const LabeledAxis & | axis (TorchSize i=0) const |
Get a specific labeled axis. | |
template<typename... S> | |
TorchSlice | slice_indices (S &&... names) const |
How to slice the tensor given the names on each axis. | |
TorchShapeRef | storage_size () const |
The shape of the entire LabeledTensor. | |
template<typename... S> | |
TorchShape | storage_size (S &&... names) const |
The shape of a sub-block specified by the names on each dimension. | |
template<typename... S> | |
BatchTensor | operator() (S &&... names) const |
Return a labeled view into the tensor. No reshaping is done. | |
Derived | slice (TorchSize i, const std::string &name) const |
Slice the tensor on the given dimension by a single variable or sub-axis. | |
template<typename... S> | |
Derived | block (S &&... names) const |
Get the sub-block labeled by the given sub-axis names. | |
Derived | batch_index (TorchSlice indices) const |
Get a batch. | |
void | batch_index_put (TorchSlice indices, const torch::Tensor &other) |
Set an index sliced on the batch dimensions to a value. | |
BatchTensor | base_index (TorchSlice indices) const |
Return an index sliced on the base dimensions. | |
void | base_index_put (TorchSlice indices, const torch::Tensor &other) |
Set an index sliced on the base dimensions to a value. | |
template<typename T , typename... S> | |
variable_type< T >::type | get (S &&... names) const |
Get and interpret the view as an object. | |
template<typename T , typename... S> | |
variable_type< T >::type | get_list (S &&... names) const |
Get and interpret the view as a list of objects. | |
template<typename T , typename... S> | |
void | set (const BatchTensorBase< T > &value, S &&... names) |
Set and interpret the input as an object. | |
template<typename T , typename... S> | |
void | set_list (const BatchTensorBase< T > &value, S &&... names) |
Set and interpret the input as a list of objects. | |
Derived | operator- () const |
Negation. | |
Derived | to (const torch::TensorOptions &options) const |
Change tensor options. | |
const BatchTensor & | tensor () const |
BatchTensor & | tensor () |
Get the underlying tensor. | |
Static Public Member Functions | |
static Derived | empty (TorchShapeRef batch_shape, const std::vector< const LabeledAxis * > &axes, const torch::TensorOptions &options=default_tensor_options()) |
Setup new empty storage. | |
static Derived | empty_like (const Derived &other) |
Setup new empty storage like another LabeledTensor. | |
static Derived | zeros (TorchShapeRef batch_shape, const std::vector< const LabeledAxis * > &axes, const torch::TensorOptions &options=default_tensor_options()) |
Setup new storage with zeros. | |
static Derived | zeros_like (const Derived &other) |
Setup new storage with zeros like another LabeledTensor. | |
Protected Attributes | |
BatchTensor | _tensor |
The tensor. | |
std::vector< const LabeledAxis * > | _axes |
The labeled axes of this tensor. | |
|
default |
Default constructor.
LabeledTensor | ( | const torch::Tensor & | tensor, |
TorchSize | batch_dim, | ||
const std::vector< const LabeledAxis * > & | axes ) |
Construct from a Tensor with batch dim and vector of LabeledAxis
LabeledTensor | ( | const BatchTensor & | tensor, |
const std::vector< const LabeledAxis * > & | axes ) |
Construct from a BatchTensor with vector of LabeledAxis
|
inline |
Get all the labeled axes.
Get a specific labeled axis.
Return the number of base dimensions.
BatchTensor base_index | ( | TorchSlice | indices | ) | const |
Return an index sliced on the base dimensions.
void base_index_put | ( | TorchSlice | indices, |
const torch::Tensor & | other ) |
Set an index sliced on the base dimensions to a value.
TorchShapeRef base_sizes | ( | ) | const |
Return the base size.
Return the number of batch dimensions.
Derived batch_index | ( | TorchSlice | indices | ) | const |
Get a batch.
void batch_index_put | ( | TorchSlice | indices, |
const torch::Tensor & | other ) |
Set an index sliced on the batch dimensions to a value.
TorchShapeRef batch_sizes | ( | ) | const |
Return the batch size.
Get the sub-block labeled by the given sub-axis names.
Derived clone | ( | torch::MemoryFormat | memory_format = torch::MemoryFormat::Contiguous | ) | const |
Clone this LabeledTensor.
Copy the value from another tensor.
Return a copy without gradient graphs.
|
static |
Setup new empty storage.
Setup new empty storage like another LabeledTensor.
|
inline |
Get and interpret the view as an object.
|
inline |
Get and interpret the view as a list of objects.
operator BatchTensor | ( | ) | const |
A potentially dangerous implicit conversion.
A potentially dangerous implicit conversion.
BatchTensor operator() | ( | S &&... | names | ) | const |
Return a labeled view into the tensor. No reshaping is done.
Get the tensor options.
|
inline |
Set and interpret the input as an object.
|
inline |
Set and interpret the input as a list of objects.
Slice the tensor on the given dimension by a single variable or sub-axis.
TorchSlice slice_indices | ( | S &&... | names | ) | const |
How to slice the tensor given the names on each axis.
TorchShapeRef storage_size | ( | ) | const |
The shape of the entire LabeledTensor.
TorchShape storage_size | ( | S &&... | names | ) | const |
The shape of a sub-block specified by the names on each dimension.
|
inline |
|
inline |
Get the underlying tensor.
Change tensor options.
|
static |
Setup new storage with zeros.
Setup new storage with zeros like another LabeledTensor.
|
protected |
The labeled axes of this tensor.
|
protected |
The tensor.