NEML2 1.4.0
Loading...
Searching...
No Matches
LabeledTensor< Derived, D > Class Template Reference

The primary data structure in NEML2 for working with labeled tensor views. More...

Detailed Description

template<class Derived, Size D>
class neml2::LabeledTensor< Derived, D >

The primary data structure in NEML2 for working with labeled tensor views.

Each LabeledTensor consists of one Tensor and one or more LabeledAxis objects, one per base dimension. LabeledTensor<Derived, D> is templated on the derived type Derived and on the number of base dimensions \(D\). LabeledTensor handles the creation, modification, and access of labeled tensors.

Template Parameters
Derived — The derived tensor type returned by operations such as clone(), detach(), and to()
D — The number of base dimensions

#include <LabeledTensor.h>

Public Member Functions

 LabeledTensor ()=default
 Default constructor.
 
 LabeledTensor (const torch::Tensor &tensor, const std::array< const LabeledAxis *, D > &axes)
 Construct from a torch::Tensor and array of LabeledAxis
 
 LabeledTensor (const Tensor &tensor, const std::array< const LabeledAxis *, D > &axes)
 Construct from a Tensor with array of LabeledAxis
 
 LabeledTensor (const Derived &other)
 Copy constructor.
 
LabeledTensor< Derived, D > & operator= (const Derived &other)
 Assignment operator.
 
template<typename T , typename = std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
T reinterpret (indexing::TensorLabelsRef indices) const
 Get a tensor by slicing on the base dimensions AND reinterpret it as a primitive tensor.
 
const std::array< const LabeledAxis *, D > & axes () const
 Get all the labeled axes.
 
const LabeledAxis & axis (Size i=0) const
 Get a specific labeled axis.
 
 operator Tensor () const
 Implicit conversion to Tensor.
 
 operator torch::Tensor () const
 Implicit conversion to torch::Tensor.
 
const Tensor & tensor () const
 Get the underlying tensor.
 
Tensor & tensor ()
 Get the underlying tensor (mutable reference).
 
Meta operations
Derived clone (torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
 Clone this LabeledTensor.
 
Derived detach () const
 Return a copy without gradient graphs.
 
void detach_ ()
 Detach from gradient graphs.
 
Derived to (const torch::TensorOptions &options) const
 Change tensor options.
 
void copy_ (const torch::Tensor &other)
 Copy another tensor.
 
void zero_ ()
 Set all entries to zero.
 
bool requires_grad () const
 Get the requires_grad property.
 
void requires_grad_ (bool req=true)
 Set the requires_grad property.
 
Derived operator- () const
 Negation.
 
Tensor information
torch::TensorOptions options () const
 Tensor options.
 
torch::Dtype scalar_type () const
 Tensor scalar type.
 
torch::Device device () const
 Tensor device.
 
Size dim () const
 Number of tensor dimensions.
 
TensorShapeRef sizes () const
 Tensor shape.
 
Size size (Size dim) const
 Length of the tensor along dimension dim.
 
bool batched () const
 Whether the tensor is batched.
 
Size batch_dim () const
 Return the number of batch dimensions.
 
constexpr Size base_dim () const
 Return the number of base dimensions.
 
TensorShapeRef batch_sizes () const
 Return the batch shape (sizes of all batch dimensions).
 
Size batch_size (Size d) const
 Return the length of some batch axis.
 
TensorShapeRef base_sizes () const
 Return the base shape (sizes of all base dimensions).
 
Size base_size (Size d) const
 Return the length of some base axis.
 
Size base_storage () const
 Return the flattened storage needed just for the base indices.
 
Getter and setter
Derived batch_index (indexing::TensorIndicesRef indices) const
 Get a tensor by slicing on the batch dimensions.
 
Tensor base_index (indexing::TensorLabelsRef labels) const
 Get a tensor by slicing on the base dimensions.
 
void batch_index_put_ (indexing::TensorIndicesRef indices, const torch::Tensor &other)
 Set values by slicing on the batch dimensions.
 
void base_index_put_ (indexing::TensorLabelsRef labels, const Tensor &other)
 Set values by slicing on the base dimensions.
 

Static Public Member Functions

static Derived empty (TensorShapeRef batch_shape, const std::array< const LabeledAxis *, D > &axes, const torch::TensorOptions &options=default_tensor_options())
 Setup new empty storage.
 
static Derived zeros (TensorShapeRef batch_shape, const std::array< const LabeledAxis *, D > &axes, const torch::TensorOptions &options=default_tensor_options())
 Setup new storage with zeros.
 

Protected Attributes

Tensor _tensor
 The tensor.
 
std::array< const LabeledAxis *, D > _axes
 The labeled axes of this tensor.
 

Constructor & Destructor Documentation

◆ LabeledTensor() [1/4]

template<class Derived , Size D>
LabeledTensor ( )
default

Default constructor.

◆ LabeledTensor() [2/4]

template<class Derived , Size D>
LabeledTensor ( const torch::Tensor & tensor,
const std::array< const LabeledAxis *, D > & axes )

Construct from a torch::Tensor and array of LabeledAxis

◆ LabeledTensor() [3/4]

template<class Derived , Size D>
LabeledTensor ( const Tensor & tensor,
const std::array< const LabeledAxis *, D > & axes )

Construct from a Tensor with array of LabeledAxis

◆ LabeledTensor() [4/4]

template<class Derived , Size D>
LabeledTensor ( const Derived & other)

Copy constructor.

Member Function Documentation

◆ axes()

template<class Derived , Size D>
const std::array< const LabeledAxis *, D > & axes ( ) const
inline

Get all the labeled axes.

◆ axis()

template<class Derived , Size D>
const LabeledAxis & axis ( Size i = 0) const
inline

Get a specific labeled axis.

◆ base_dim()

template<class Derived , Size D>
constexpr Size base_dim ( ) const
inline constexpr

Return the number of base dimensions.

◆ base_index()

template<class Derived , Size D>
Tensor base_index ( indexing::TensorLabelsRef labels) const

Get a tensor by slicing on the base dimensions.

◆ base_index_put_()

template<class Derived , Size D>
void base_index_put_ ( indexing::TensorLabelsRef labels,
const Tensor & other )

Set values by slicing on the base dimensions.

◆ base_size()

template<class Derived , Size D>
Size base_size ( Size d) const

Return the length of some base axis.

◆ base_sizes()

template<class Derived , Size D>
TensorShapeRef base_sizes ( ) const

Return the base shape (sizes of all base dimensions).

◆ base_storage()

template<class Derived , Size D>
Size base_storage ( ) const

Return the flattened storage needed just for the base indices.

◆ batch_dim()

template<class Derived , Size D>
Size batch_dim ( ) const

Return the number of batch dimensions.

◆ batch_index()

template<class Derived , Size D>
Derived batch_index ( indexing::TensorIndicesRef indices) const

Get a tensor by slicing on the batch dimensions.

◆ batch_index_put_()

template<class Derived , Size D>
void batch_index_put_ ( indexing::TensorIndicesRef indices,
const torch::Tensor & other )

Set values by slicing on the batch dimensions.

◆ batch_size()

template<class Derived , Size D>
Size batch_size ( Size d) const

Return the length of some batch axis.

◆ batch_sizes()

template<class Derived , Size D>
TensorShapeRef batch_sizes ( ) const

Return the batch shape (sizes of all batch dimensions).

◆ batched()

template<class Derived , Size D>
bool batched ( ) const

Whether the tensor is batched.

◆ clone()

template<class Derived , Size D>
Derived clone ( torch::MemoryFormat memory_format = torch::MemoryFormat::Contiguous) const

Clone this LabeledTensor.

◆ copy_()

template<class Derived , Size D>
void copy_ ( const torch::Tensor & other)

Copy another tensor.

◆ detach()

template<class Derived , Size D>
Derived detach ( ) const

Return a copy without gradient graphs.

◆ detach_()

template<class Derived , Size D>
void detach_ ( )

Detach from gradient graphs.

◆ device()

template<class Derived , Size D>
torch::Device device ( ) const

Tensor device.

◆ dim()

template<class Derived , Size D>
Size dim ( ) const

Number of tensor dimensions.

◆ empty()

template<class Derived , Size D>
Derived empty ( TensorShapeRef batch_shape,
const std::array< const LabeledAxis *, D > & axes,
const torch::TensorOptions & options = default_tensor_options() )
static

Setup new empty storage.

◆ operator Tensor()

template<class Derived , Size D>
operator Tensor ( ) const

Implicit conversion to Tensor.

◆ operator torch::Tensor()

template<class Derived , Size D>
operator torch::Tensor ( ) const

Implicit conversion to torch::Tensor.

◆ operator-()

template<class Derived , Size D>
Derived operator- ( ) const

Negation.

◆ operator=()

template<class Derived , Size D>
LabeledTensor< Derived, D > & operator= ( const Derived & other)

Assignment operator.

◆ options()

template<class Derived , Size D>
torch::TensorOptions options ( ) const

Tensor options.

◆ reinterpret()

template<class Derived , Size D>
template<typename T , typename >
T reinterpret ( indexing::TensorLabelsRef indices) const

Get a tensor by slicing on the base dimensions AND reinterpret it as a primitive tensor.

◆ requires_grad()

template<class Derived , Size D>
bool requires_grad ( ) const

Get the requires_grad property.

◆ requires_grad_()

template<class Derived , Size D>
void requires_grad_ ( bool req = true)

Set the requires_grad property.

◆ scalar_type()

template<class Derived , Size D>
torch::Dtype scalar_type ( ) const

Tensor scalar type.

◆ size()

template<class Derived , Size D>
Size size ( Size dim) const

Length of the tensor along dimension dim.

◆ sizes()

template<class Derived , Size D>
TensorShapeRef sizes ( ) const

Tensor shape.

◆ tensor() [1/2]

template<class Derived , Size D>
Tensor & tensor ( )
inline

Get the underlying tensor (mutable reference).

◆ tensor() [2/2]

template<class Derived , Size D>
const Tensor & tensor ( ) const
inline

Get the underlying tensor.

◆ to()

template<class Derived , Size D>
Derived to ( const torch::TensorOptions & options) const

Change tensor options.

◆ zero_()

template<class Derived , Size D>
void zero_ ( )

Set all entries to zero.

◆ zeros()

template<class Derived , Size D>
Derived zeros ( TensorShapeRef batch_shape,
const std::array< const LabeledAxis *, D > & axes,
const torch::TensorOptions & options = default_tensor_options() )
static

Setup new storage with zeros.

Member Data Documentation

◆ _axes

template<class Derived , Size D>
std::array<const LabeledAxis *, D> _axes
protected

The labeled axes of this tensor.

◆ _tensor

template<class Derived , Size D>
Tensor _tensor
protected

The tensor.