NEML2 1.4.0
LabeledTensor< Derived, D > Class Template Reference

The primary data structure in NEML2 for working with labeled tensor views.

Detailed Description

template<class Derived, TorchSize D>
class neml2::LabeledTensor< Derived, D >

The primary data structure in NEML2 for working with labeled tensor views.

Each LabeledTensor consists of one BatchTensor and one or more LabeledAxis objects. The class is templated on the derived type (Derived) and on the number of base dimensions \(D\). LabeledTensor handles the creation, modification, and access of labeled tensors.

Template Parameters
D  The number of base dimensions
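
To illustrate how a class can be templated on both its derived type and the number of labeled dimensions, here is a minimal stand-alone sketch. It uses only the C++ standard library. `ToyAxis`, `ToyLabeled`, `ToyVector`, and `ToyMatrix` are hypothetical names invented for this illustration and are not part of NEML2; a flat `std::vector<double>` stands in for the BatchTensor storage.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a labeled axis: maps a name to a [begin, end) range.
struct ToyAxis
{
  std::map<std::string, std::pair<std::size_t, std::size_t>> ranges;
  std::size_t size = 0;

  // Append a named variable occupying `n` consecutive entries.
  void add(const std::string & name, std::size_t n)
  {
    ranges[name] = std::pair<std::size_t, std::size_t>(size, size + n);
    size += n;
  }
};

// CRTP base templated on the derived type and the number of labeled dimensions D.
template <class Derived, std::size_t D>
class ToyLabeled
{
public:
  explicit ToyLabeled(const std::array<const ToyAxis *, D> & axes) : _axes(axes)
  {
    // Storage size is the product of the axis sizes.
    std::size_t n = 1;
    for (const ToyAxis * a : _axes)
    {
      assert(a != nullptr);
      n *= a->size;
    }
    _data.assign(n, 0.0);
  }

  std::size_t base_dim() const { return D; }
  const ToyAxis & axis(std::size_t i = 0) const { return *_axes[i]; }
  std::size_t storage() const { return _data.size(); }

  // Because the base knows Derived, it can return the concrete type.
  Derived clone() const { return static_cast<const Derived &>(*this); }

protected:
  std::vector<double> _data;
  std::array<const ToyAxis *, D> _axes;
};

// One labeled dimension.
struct ToyVector : ToyLabeled<ToyVector, 1>
{
  using Base = ToyLabeled<ToyVector, 1>;
  using Base::Base;
};

// Two labeled dimensions.
struct ToyMatrix : ToyLabeled<ToyMatrix, 2>
{
  using Base = ToyLabeled<ToyMatrix, 2>;
  using Base::Base;
};
```

In this sketch, an axis holding a 6-entry "strain" and a 1-entry "temperature" gives a `ToyVector` with 7 stored values and a `ToyMatrix` (same axis twice) with 49. The axes are held by pointer, so they must outlive the objects that reference them.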

#include <LabeledTensor.h>

Classes

struct  variable_type
 Template setup for appropriate variable types.
 

Public Member Functions

 LabeledTensor ()=default
 Default constructor.
 
 LabeledTensor (const torch::Tensor &tensor, TorchSize batch_dim, const std::vector< const LabeledAxis * > &axes)
 Construct from a torch::Tensor, the number of batch dimensions, and a vector of LabeledAxis pointers.
 
 LabeledTensor (const BatchTensor &tensor, const std::vector< const LabeledAxis * > &axes)
 Construct from a BatchTensor and a vector of LabeledAxis pointers.
 
 LabeledTensor (const Derived &other)
 Copy constructor.
 
void operator= (const Derived &other)
 Assignment operator.
 
 operator BatchTensor () const
 A potentially dangerous implicit conversion.
 
 operator torch::Tensor () const
 A potentially dangerous implicit conversion.
 
Derived clone (torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
 Clone this LabeledTensor.
 
template<typename T >
void copy_ (const T &other)
 Copy the value from another tensor.
 
Derived detach () const
 Return a copy without gradient graphs.
 
void detach_ ()
 Detach from gradient graphs.
 
void zero_ ()
 Zero out this tensor.
 
torch::TensorOptions options () const
 Get the tensor options.
 
TorchSize batch_dim () const
 Return the number of batch dimensions.
 
TorchSize base_dim () const
 Return the number of base dimensions.
 
TorchShapeRef batch_sizes () const
 Return the batch sizes (the shape of the batch dimensions).
 
TorchShapeRef base_sizes () const
 Return the base sizes (the shape of the base dimensions).
 
const std::vector< const LabeledAxis * > & axes () const
 Get all the labeled axes.
 
const LabeledAxis & axis (TorchSize i=0) const
 Get a specific labeled axis.
 
template<typename... S>
TorchSlice slice_indices (S &&... names) const
 How to slice the tensor given the names on each axis.
 
TorchShapeRef storage_size () const
 The shape of the entire LabeledTensor.
 
template<typename... S>
TorchShape storage_size (S &&... names) const
 The shape of a sub-block specified by the names on each dimension.
 
template<typename... S>
BatchTensor operator() (S &&... names) const
 Return a labeled view into the tensor. No reshaping is done.
 
Derived slice (TorchSize i, const std::string &name) const
 Slice the tensor on the given dimension by a single variable or sub-axis.
 
template<typename... S>
Derived block (S &&... names) const
 Get the sub-block labeled by the given sub-axis names.
 
Derived batch_index (TorchSlice indices) const
 Index the batch dimensions and return the selected batch.
 
void batch_index_put (TorchSlice indices, const torch::Tensor &other)
 Set the entries selected by indexing the batch dimensions to a value.
 
BatchTensor base_index (TorchSlice indices) const
 Return the sub-tensor obtained by indexing the base dimensions.
 
void base_index_put (TorchSlice indices, const torch::Tensor &other)
 Set the entries selected by indexing the base dimensions to a value.
 
template<typename T , typename... S>
variable_type< T >::type get (S &&... names) const
 Get and interpret the view as an object.
 
template<typename T , typename... S>
variable_type< T >::type get_list (S &&... names) const
 Get and interpret the view as a list of objects.
 
template<typename T , typename... S>
void set (const BatchTensorBase< T > &value, S &&... names)
 Set and interpret the input as an object.
 
template<typename T , typename... S>
void set_list (const BatchTensorBase< T > &value, S &&... names)
 Set and interpret the input as a list of objects.
 
Derived operator- () const
 Negation.
 
Derived to (const torch::TensorOptions &options) const
 Change tensor options.
 
const BatchTensor & tensor () const
 Get the underlying tensor.
 
BatchTensor & tensor ()
 Get the underlying tensor.
 

Static Public Member Functions

static Derived empty (TorchShapeRef batch_shape, const std::vector< const LabeledAxis * > &axes, const torch::TensorOptions &options=default_tensor_options())
 Set up new empty storage.
 
static Derived empty_like (const Derived &other)
 Set up new empty storage like another LabeledTensor.
 
static Derived zeros (TorchShapeRef batch_shape, const std::vector< const LabeledAxis * > &axes, const torch::TensorOptions &options=default_tensor_options())
 Set up new storage filled with zeros.
 
static Derived zeros_like (const Derived &other)
 Set up new storage filled with zeros, like another LabeledTensor.
 

Protected Attributes

BatchTensor _tensor
 The tensor.
 
std::vector< const LabeledAxis * > _axes
 The labeled axes of this tensor.
 

Constructor & Destructor Documentation

◆ LabeledTensor() [1/4]

template<class Derived , TorchSize D>
LabeledTensor ( )
default

Default constructor.

◆ LabeledTensor() [2/4]

template<class Derived , TorchSize D>
LabeledTensor ( const torch::Tensor & tensor,
TorchSize batch_dim,
const std::vector< const LabeledAxis * > & axes )

Construct from a torch::Tensor, the number of batch dimensions, and a vector of LabeledAxis pointers.

◆ LabeledTensor() [3/4]

template<class Derived , TorchSize D>
LabeledTensor ( const BatchTensor & tensor,
const std::vector< const LabeledAxis * > & axes )

Construct from a BatchTensor and a vector of LabeledAxis pointers.

◆ LabeledTensor() [4/4]

template<class Derived , TorchSize D>
LabeledTensor ( const Derived & other)

Copy constructor.

Member Function Documentation

◆ axes()

template<class Derived , TorchSize D>
const std::vector< const LabeledAxis * > & axes ( ) const
inline

Get all the labeled axes.

◆ axis()

template<class Derived , TorchSize D>
const LabeledAxis & axis ( TorchSize i = 0) const
inline

Get a specific labeled axis.

◆ base_dim()

template<class Derived , TorchSize D>
TorchSize base_dim ( ) const

Return the number of base dimensions.

◆ base_index()

template<class Derived , TorchSize D>
BatchTensor base_index ( TorchSlice indices) const

Return the sub-tensor obtained by indexing the base dimensions.

◆ base_index_put()

template<class Derived , TorchSize D>
void base_index_put ( TorchSlice indices,
const torch::Tensor & other )

Set the entries selected by indexing the base dimensions to a value.

◆ base_sizes()

template<class Derived , TorchSize D>
TorchShapeRef base_sizes ( ) const

Return the base sizes (the shape of the base dimensions).

◆ batch_dim()

template<class Derived , TorchSize D>
TorchSize batch_dim ( ) const

Return the number of batch dimensions.

◆ batch_index()

template<class Derived , TorchSize D>
Derived batch_index ( TorchSlice indices) const

Index the batch dimensions and return the selected batch.

◆ batch_index_put()

template<class Derived , TorchSize D>
void batch_index_put ( TorchSlice indices,
const torch::Tensor & other )

Set the entries selected by indexing the batch dimensions to a value.

◆ batch_sizes()

template<class Derived , TorchSize D>
TorchShapeRef batch_sizes ( ) const

Return the batch sizes (the shape of the batch dimensions).
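
The relationship between batch_dim(), base_dim(), batch_sizes(), and base_sizes() can be sketched with a plain shape vector. This is a stand-alone illustration that assumes the batch dimensions come first in the full shape, followed by the base dimensions. `ToyShape` and `ToyBatchedShape` are hypothetical names, not NEML2 types.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical shape type: one extent per dimension.
using ToyShape = std::vector<long>;

// A full shape plus the number of leading dimensions treated as batch dimensions.
struct ToyBatchedShape
{
  ToyShape shape;
  std::size_t n_batch = 0;

  std::size_t batch_dim() const { return n_batch; }

  std::size_t base_dim() const
  {
    assert(n_batch <= shape.size());
    return shape.size() - n_batch;
  }

  // Leading n_batch extents.
  ToyShape batch_sizes() const
  {
    assert(n_batch <= shape.size());
    return ToyShape(shape.begin(), shape.begin() + static_cast<std::ptrdiff_t>(n_batch));
  }

  // Remaining extents after the batch dimensions.
  ToyShape base_sizes() const
  {
    assert(n_batch <= shape.size());
    return ToyShape(shape.begin() + static_cast<std::ptrdiff_t>(n_batch), shape.end());
  }
};
```

Under that assumption, a full shape of (10, 5, 6, 6) with two batch dimensions splits into batch sizes (10, 5) and base sizes (6, 6).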

◆ block()

template<class Derived , TorchSize D>
template<typename... S>
Derived block ( S &&... names) const

Get the sub-block labeled by the given sub-axis names.

◆ clone()

template<class Derived , TorchSize D>
Derived clone ( torch::MemoryFormat memory_format = torch::MemoryFormat::Contiguous) const

Clone this LabeledTensor.

◆ copy_()

template<class Derived , TorchSize D>
template<typename T >
void copy_ ( const T & other)

Copy the value from another tensor.

◆ detach()

template<class Derived , TorchSize D>
Derived detach ( ) const

Return a copy without gradient graphs.

◆ detach_()

template<class Derived , TorchSize D>
void detach_ ( )

Detach from gradient graphs.

◆ empty()

template<class Derived , TorchSize D>
Derived empty ( TorchShapeRef batch_shape,
const std::vector< const LabeledAxis * > & axes,
const torch::TensorOptions & options = default_tensor_options() )
static

Set up new empty storage.

◆ empty_like()

template<class Derived , TorchSize D>
Derived empty_like ( const Derived & other)
static

Set up new empty storage like another LabeledTensor.

◆ get()

template<class Derived , TorchSize D>
template<typename T , typename... S>
variable_type< T >::type get ( S &&... names) const
inline

Get and interpret the view as an object.

◆ get_list()

template<class Derived , TorchSize D>
template<typename T , typename... S>
variable_type< T >::type get_list ( S &&... names) const
inline

Get and interpret the view as a list of objects.

◆ operator BatchTensor()

template<class Derived , TorchSize D>
operator BatchTensor ( ) const

A potentially dangerous implicit conversion.

◆ operator torch::Tensor()

template<class Derived , TorchSize D>
operator torch::Tensor ( ) const

A potentially dangerous implicit conversion.
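
The documentation does not say why these conversions are considered dangerous. One plausible reading is that an implicit conversion to the raw tensor type silently drops the labels. The stand-alone sketch below shows that general C++ hazard with hypothetical types (`ToyLabeledValues`, `toy_sum`); it does not reproduce NEML2's implementation.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical labeled container: values plus one label per value.
struct ToyLabeledValues
{
  std::vector<std::string> labels;
  std::vector<double> values;

  // Implicit conversion to raw storage: convenient, but the labels are lost
  // without any visible cast at the call site.
  operator std::vector<double>() const { return values; }
};

// Accepts raw storage and knows nothing about labels.
inline double toy_sum(const std::vector<double> & raw)
{
  double total = 0.0;
  for (double v : raw)
    total += v;
  return total;
}
```

A call such as `toy_sum(labeled)` compiles without complaint, and nothing at the call site shows that label information has been discarded. Marking the operator `explicit` would force callers to spell out the conversion.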

◆ operator()()

template<class Derived , TorchSize D>
template<typename... S>
BatchTensor operator() ( S &&... names) const

Return a labeled view into the tensor. No reshaping is done.

◆ operator-()

template<class Derived , TorchSize D>
Derived operator- ( ) const

Negation.

◆ operator=()

template<class Derived , TorchSize D>
void operator= ( const Derived & other)

Assignment operator.

◆ options()

template<class Derived , TorchSize D>
torch::TensorOptions options ( ) const
inline

Get the tensor options.

◆ set()

template<class Derived , TorchSize D>
template<typename T , typename... S>
void set ( const BatchTensorBase< T > & value,
S &&... names )
inline

Set and interpret the input as an object.

◆ set_list()

template<class Derived , TorchSize D>
template<typename T , typename... S>
void set_list ( const BatchTensorBase< T > & value,
S &&... names )
inline

Set and interpret the input as a list of objects.

◆ slice()

template<class Derived , TorchSize D>
Derived slice ( TorchSize i,
const std::string & name ) const

Slice the tensor on the given dimension by a single variable or sub-axis.
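
As a rough illustration of slicing one dimension by name, the following stand-alone sketch models a two-dimensional labeled array with the standard library only. `ToyGrid` and `ToyRange` are hypothetical. Unlike the class documented here, which is described as working with views, this sketch copies the selected entries for simplicity.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical [begin, end) range along one dimension.
using ToyRange = std::pair<std::size_t, std::size_t>;

// Hypothetical dense row-major array with a name -> range layout per dimension.
struct ToyGrid
{
  std::size_t rows = 0;
  std::size_t cols = 0;
  std::vector<double> data;                  // row-major, size rows * cols
  std::map<std::string, ToyRange> layout[2]; // [0]: row labels, [1]: column labels

  double at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }

  // Keep only the part of dimension `i` (0 = rows, 1 = columns) labeled `name`.
  ToyGrid slice(std::size_t i, const std::string & name) const
  {
    assert(i < 2);
    const ToyRange range = layout[i].at(name); // throws std::out_of_range if unknown
    const std::size_t n = range.second - range.first;
    const std::size_t r0 = (i == 0) ? range.first : 0;
    const std::size_t c0 = (i == 1) ? range.first : 0;

    ToyGrid out;
    out.rows = (i == 0) ? n : rows;
    out.cols = (i == 1) ? n : cols;
    out.layout[1 - i] = layout[1 - i];    // the other dimension keeps its labels
    out.layout[i][name] = ToyRange(0, n); // the sliced dimension is re-based at zero
    for (std::size_t r = 0; r < out.rows; ++r)
      for (std::size_t c = 0; c < out.cols; ++c)
        out.data.push_back(at(r0 + r, c0 + c));
    return out;
  }
};
```

For a 2 x 3 grid whose columns are labeled "x" (columns 0 and 1) and "y" (column 2), `slice(1, "x")` yields a 2 x 2 grid containing only the "x" columns, while the row labels are carried over unchanged.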

◆ slice_indices()

template<class Derived , TorchSize D>
template<typename... S>
TorchSlice slice_indices ( S &&... names) const

How to slice the tensor given the names on each axis.
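
The idea of turning one name per axis into one index range per axis can be sketched with a variadic template, mirroring the `S &&... names` signature above. This is a stand-alone illustration under assumed data structures: `ToyRange`, `ToyAxisLayout`, and `toy_slice_indices` are hypothetical, and the real method returns a TorchSlice rather than a vector of pairs.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical [begin, end) range along one dimension.
using ToyRange = std::pair<std::size_t, std::size_t>;
// Hypothetical axis layout: variable name -> range on that axis.
using ToyAxisLayout = std::map<std::string, ToyRange>;

// Takes one name per axis, in axis order, and returns one range per axis.
template <typename... S>
std::vector<ToyRange>
toy_slice_indices(const std::vector<const ToyAxisLayout *> & axes, S &&... names)
{
  if (sizeof...(names) != axes.size())
    throw std::invalid_argument("expected one name per labeled axis");

  // Expand the parameter pack into a list of labels.
  const std::vector<std::string> labels{std::string(std::forward<S>(names))...};

  std::vector<ToyRange> out;
  for (std::size_t i = 0; i < labels.size(); ++i)
    out.push_back(axes[i]->at(labels[i])); // throws std::out_of_range on an unknown name
  return out;
}
```

With a row axis containing "stress" at [0, 6) and a column axis containing "time" at [6, 7), calling `toy_slice_indices(axes, "stress", "time")` returns those two ranges in axis order.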

◆ storage_size() [1/2]

template<class Derived , TorchSize D>
TorchShapeRef storage_size ( ) const

The shape of the entire LabeledTensor.

◆ storage_size() [2/2]

template<class Derived , TorchSize D>
template<typename... S>
TorchShape storage_size ( S &&... names) const

The shape of a sub-block specified by the names on each dimension.

◆ tensor() [1/2]

template<class Derived , TorchSize D>
BatchTensor & tensor ( )
inline

Get the underlying tensor.

◆ tensor() [2/2]

template<class Derived , TorchSize D>
const BatchTensor & tensor ( ) const
inline

Get the underlying tensor.

◆ to()

template<class Derived , TorchSize D>
Derived to ( const torch::TensorOptions & options) const

Change tensor options.

◆ zero_()

template<class Derived , TorchSize D>
void zero_ ( )

Zero out this tensor.

◆ zeros()

template<class Derived , TorchSize D>
Derived zeros ( TorchShapeRef batch_shape,
const std::vector< const LabeledAxis * > & axes,
const torch::TensorOptions & options = default_tensor_options() )
static

Set up new storage filled with zeros.

◆ zeros_like()

template<class Derived , TorchSize D>
Derived zeros_like ( const Derived & other)
static

Set up new storage filled with zeros, like another LabeledTensor.

Member Data Documentation

◆ _axes

template<class Derived , TorchSize D>
std::vector<const LabeledAxis *> _axes
protected

The labeled axes of this tensor.

◆ _tensor

template<class Derived , TorchSize D>
BatchTensor _tensor
protected

The tensor.