NEML2 1.4.0
TensorBase< Derived > Class Template Reference

NEML2's enhanced tensor type.

Detailed Description

template<class Derived>
class neml2::TensorBase< Derived >

NEML2's enhanced tensor type.

neml2::TensorBase derives from torch::Tensor and clearly distinguishes "batched" dimensions from the other dimensions. The shape of the batched dimensions is called the batch size, and the shape of the remaining dimensions is called the base size.
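To illustrate the batch/base split, the sketch below works on plain shapes rather than on NEML2 tensors. It assumes that the batch dimensions are the leading dimensions of the underlying torch::Tensor, so a tensor constructed with batch_dim = 2 from a tensor of shape (3, 2, 5, 5) has batch size (3, 2) and base size (5, 5). The Size and Shape aliases, the SplitShape struct, and the split_shape helper are hypothetical stand-ins, not part of the NEML2 API.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for NEML2's Size and tensor shape types (assumption).
using Size = std::int64_t;
using Shape = std::vector<Size>;

// Result of splitting a full tensor shape into its batch and base parts.
struct SplitShape {
  Shape batch_sizes;
  Shape base_sizes;
};

// The leading `batch_dim` axes are treated as batch axes; the remaining
// trailing axes form the base shape.
SplitShape split_shape(const Shape & full, Size batch_dim) {
  assert(batch_dim >= 0 && batch_dim <= static_cast<Size>(full.size()));
  return {Shape(full.begin(), full.begin() + batch_dim),
          Shape(full.begin() + batch_dim, full.end())};
}
```

With batch_dim = 0 every axis belongs to the base shape, which corresponds to an unbatched tensor.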

#include <TensorBase.h>


Public Member Functions

 TensorBase ()=default
 Default constructor.
 
 TensorBase (const torch::Tensor &tensor, Size batch_dim)
 Construct from another torch::Tensor.
 
 TensorBase (const Derived &tensor)
 Copy constructor.
 
 TensorBase (Real)=delete
 
Meta operations
Derived clone (torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
 Clone (take ownership).
 
Derived detach () const
 Discard function graph.
 
Derived to (const torch::TensorOptions &options) const
 Change tensor options.
 
Derived operator- () const
 Negation.
 
Tensor information
bool batched () const
 Whether the tensor is batched.
 
Size batch_dim () const
 Return the number of batch dimensions.
 
Size base_dim () const
 Return the number of base dimensions.
 
TensorShapeRef batch_sizes () const
 Return the batch size.
 
Size batch_size (Size index) const
 Return the size of a batch axis.
 
TensorShapeRef base_sizes () const
 Return the base size.
 
Size base_size (Size index) const
 Return the size of a base axis.
 
Size base_storage () const
 Return the flattened storage (number of entries) needed just for the base indices.
 
Getter and setter
Derived batch_index (indexing::TensorIndicesRef indices) const
 Get a tensor by slicing on the batch dimensions.
 
neml2::Tensor base_index (indexing::TensorIndicesRef indices) const
 Get a tensor by slicing on the base dimensions.
 
void batch_index_put_ (indexing::TensorIndicesRef indices, const torch::Tensor &other)
 Set values by slicing on the batch dimensions.
 
void base_index_put_ (indexing::TensorIndicesRef indices, const torch::Tensor &other)
 Set values by slicing on the base dimensions.
 
Modifiers
Derived batch_expand (TensorShapeRef batch_size) const
 Return a new view of the tensor with values broadcast along the batch dimensions.
 
neml2::Tensor base_expand (TensorShapeRef base_size) const
 Return a new view of the tensor with values broadcast along the base dimensions.
 
template<class Derived2 >
Derived batch_expand_as (const Derived2 &other) const
 Expand the batch to have the same shape as another tensor.
 
template<class Derived2 >
Derived2 base_expand_as (const Derived2 &other) const
 Expand the base to have the same shape as another tensor.
 
Derived batch_expand_copy (TensorShapeRef batch_size) const
 Return a new tensor with values broadcast along the batch dimensions.
 
neml2::Tensor base_expand_copy (TensorShapeRef base_size) const
 Return a new tensor with values broadcast along the base dimensions.
 
Derived batch_reshape (TensorShapeRef batch_shape) const
 Reshape batch dimensions.
 
neml2::Tensor base_reshape (TensorShapeRef base_shape) const
 Reshape base dimensions.
 
Derived batch_unsqueeze (Size d) const
 Unsqueeze a batch dimension.
 
neml2::Tensor base_unsqueeze (Size d) const
 Unsqueeze a base dimension.
 
Derived batch_transpose (Size d1, Size d2) const
 Transpose two batch dimensions.
 
neml2::Tensor base_transpose (Size d1, Size d2) const
 Transpose two base dimensions.
 

Static Public Member Functions

static Derived empty_like (const Derived &other)
 Empty tensor like another, i.e. same batch and base shapes, same tensor options, etc.
 
static Derived zeros_like (const Derived &other)
 Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
 
static Derived ones_like (const Derived &other)
 Tensor filled with ones, like another, i.e. same batch and base shapes, same tensor options, etc.
 
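static Derived full_like (const Derived &other, Real init)
 Full tensor like another, i.e. same batch and base shapes, same tensor options, etc., but filled with the value init.
 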
static Derived linspace (const Derived &start, const Derived &end, Size nstep, Size dim=0, Size batch_dim=-1)
 Create a new tensor by adding a new batch dimension with linear spacing between start and end.
 
static Derived logspace (const Derived &start, const Derived &end, Size nstep, Size dim=0, Size batch_dim=-1, Real base=10)
 Log-space equivalent of the linspace named constructor.
 

Constructor & Destructor Documentation

◆ TensorBase() [1/4]

template<class Derived >
TensorBase ( )
default

Default constructor.

◆ TensorBase() [2/4]

template<class Derived >
TensorBase ( const torch::Tensor & tensor,
Size batch_dim )

Construct from another torch::Tensor.

◆ TensorBase() [3/4]

template<class Derived >
TensorBase ( const Derived & tensor)

Copy constructor.

◆ TensorBase() [4/4]

template<class Derived >
TensorBase ( Real )
delete

Member Function Documentation

◆ base_dim()

template<class Derived >
Size base_dim ( ) const

Return the number of base dimensions.

◆ base_expand()

template<class Derived >
neml2::Tensor base_expand ( TensorShapeRef base_size) const

Return a new view of the tensor with values broadcast along the base dimensions.

◆ base_expand_as()

template<class Derived >
template<class Derived2 >
Derived2 base_expand_as ( const Derived2 & other) const

Expand the base to have the same shape as another tensor.

◆ base_expand_copy()

template<class Derived >
neml2::Tensor base_expand_copy ( TensorShapeRef base_size) const

Return a new tensor with values broadcast along the base dimensions.

◆ base_index()

template<class Derived >
neml2::Tensor base_index ( indexing::TensorIndicesRef indices) const

Get a tensor by slicing on the base dimensions.

◆ base_index_put_()

template<class Derived >
void base_index_put_ ( indexing::TensorIndicesRef indices,
const torch::Tensor & other )

Set values by slicing on the base dimensions.

◆ base_reshape()

template<class Derived >
neml2::Tensor base_reshape ( TensorShapeRef base_shape) const

Reshape base dimensions.
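The shape bookkeeping behind base_reshape can be sketched on plain shapes: the batch sizes are kept, only the base sizes are replaced, and the number of base entries must be preserved. This sketch assumes base_reshape follows torch::Tensor::reshape semantics restricted to the trailing base dimensions; inferred sizes (-1) are not handled, and numel and base_reshaped_shape are hypothetical helpers.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <vector>

using Size = std::int64_t;          // assumed to mirror NEML2's Size
using Shape = std::vector<Size>;

// Number of entries described by a shape (1 for an empty shape).
Size numel(const Shape & s) {
  return std::accumulate(s.begin(), s.end(), Size{1}, std::multiplies<Size>());
}

// Full shape after reshaping only the base part: batch sizes are kept and the
// base sizes are replaced. Like any reshape, the element count must not change.
Shape base_reshaped_shape(const Shape & batch_sizes, const Shape & base_sizes, const Shape & new_base) {
  if (numel(base_sizes) != numel(new_base))
    throw std::invalid_argument("base_reshape: element count mismatch");
  Shape full = batch_sizes;
  full.insert(full.end(), new_base.begin(), new_base.end());
  return full;
}
```

For instance, flattening a (3, 3) base of a tensor with batch size (10) gives the full shape (10, 9); batch_reshape would do the same on the batch side while keeping the base sizes.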

◆ base_size()

template<class Derived >
Size base_size ( Size index) const

Return the size of a base axis.

◆ base_sizes()

template<class Derived >
TensorShapeRef base_sizes ( ) const

Return the base size.

◆ base_storage()

template<class Derived >
Size base_storage ( ) const

Return the flattened storage (number of entries) needed just for the base indices.
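A minimal sketch of what base_storage computes, assuming it is the number of scalar entries in the base shape, i.e. the product of the base sizes (a base size of (3, 3) needs 9 entries, and a scalar base needs 1). The free function here is a hypothetical helper operating on a std::vector shape, not the member function itself.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

using Size = std::int64_t;          // assumed to mirror NEML2's Size
using Shape = std::vector<Size>;

// Product of the base sizes; an empty base shape (scalar base) yields 1.
Size base_storage(const Shape & base_sizes) {
  return std::accumulate(base_sizes.begin(), base_sizes.end(), Size{1},
                         std::multiplies<Size>());
}
```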

◆ base_transpose()

template<class Derived >
neml2::Tensor base_transpose ( Size d1,
Size d2 ) const

Transpose two base dimensions.
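To show which axes of the underlying tensor base_transpose touches, this sketch maps base axis indices to full-tensor axes by offsetting them by the number of batch dimensions. Accepting negative base indices (counted from the end of the base shape) is an assumption borrowed from PyTorch conventions; base_axis_to_full and base_transposed_shape are hypothetical helpers. batch_transpose is the simpler case, since batch axes already coincide with the leading axes of the full tensor.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>

using Size = std::int64_t;          // assumed to mirror NEML2's Size
using Shape = std::vector<Size>;

// Map base axis d (negative values count from the end of the base shape)
// to an axis of the full tensor, whose first `batch_dim` axes are batch axes.
Size base_axis_to_full(Size d, Size batch_dim, Size base_dim) {
  if (d < -base_dim || d >= base_dim)
    throw std::out_of_range("base axis out of range");
  return batch_dim + (d >= 0 ? d : d + base_dim);
}

// Full shape after swapping base axes d1 and d2; batch axes are untouched.
Shape base_transposed_shape(Shape full, Size batch_dim, Size d1, Size d2) {
  const Size base_dim = static_cast<Size>(full.size()) - batch_dim;
  const auto a = static_cast<std::size_t>(base_axis_to_full(d1, batch_dim, base_dim));
  const auto b = static_cast<std::size_t>(base_axis_to_full(d2, batch_dim, base_dim));
  std::swap(full[a], full[b]);
  return full;
}
```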

◆ base_unsqueeze()

template<class Derived >
neml2::Tensor base_unsqueeze ( Size d) const

Unsqueeze a base dimension.
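A sketch of the shape change produced by base_unsqueeze, applied to the base shape alone; batch_unsqueeze would apply the same rule to the batch shape. It assumes the dimension convention of torch::Tensor::unsqueeze, where d ranges over [-(n + 1), n] for a shape with n axes and d = -1 appends a new trailing axis. unsqueeze_shape is a hypothetical helper.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

using Size = std::int64_t;          // assumed to mirror NEML2's Size
using Shape = std::vector<Size>;

// Insert a size-1 axis at position d of shape s (negative d counts from the end).
Shape unsqueeze_shape(Shape s, Size d) {
  const Size n = static_cast<Size>(s.size());
  if (d < -(n + 1) || d > n)
    throw std::out_of_range("unsqueeze: dim out of range");
  if (d < 0)
    d += n + 1;
  s.insert(s.begin() + d, Size{1});
  return s;
}
```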

◆ batch_dim()

template<class Derived >
Size batch_dim ( ) const

Return the number of batch dimensions.

◆ batch_expand()

template<class Derived >
Derived batch_expand ( TensorShapeRef batch_size) const

Return a new view of the tensor with values broadcast along the batch dimensions.

◆ batch_expand_as()

template<class Derived >
template<class Derived2 >
Derived batch_expand_as ( const Derived2 & other) const

Expand the batch to have the same shape as another tensor.

◆ batch_expand_copy()

template<class Derived >
Derived batch_expand_copy ( TensorShapeRef batch_size) const

Return a new tensor with values broadcast along the batch dimensions.
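batch_expand and batch_expand_copy differ in whether the broadcast values share memory with the original (a view) or are materialized (a new tensor); the resulting shape is the same. The sketch computes that shape, assuming the expansion follows torch::Tensor::expand rules applied to the batch sizes followed by the unchanged base sizes: shapes are aligned from the right, each existing batch size must equal the target size or be 1, and new leading batch axes may be added. The -1 ("keep this size") shorthand of torch::Tensor::expand is not modeled, and batch_expanded_shape is a hypothetical helper.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

using Size = std::int64_t;          // assumed to mirror NEML2's Size
using Shape = std::vector<Size>;

// Full shape after expanding the batch sizes to target_batch (base unchanged).
Shape batch_expanded_shape(const Shape & batch, const Shape & base, const Shape & target_batch) {
  if (target_batch.size() < batch.size())
    throw std::invalid_argument("cannot drop batch axes");
  // Right-align the existing batch axes with the target batch axes.
  const std::size_t offset = target_batch.size() - batch.size();
  for (std::size_t i = 0; i < batch.size(); ++i)
    if (batch[i] != 1 && batch[i] != target_batch[offset + i])
      throw std::invalid_argument("batch sizes are not broadcastable");
  Shape full = target_batch;
  full.insert(full.end(), base.begin(), base.end());
  return full;
}
```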

◆ batch_index()

template<class Derived >
Derived batch_index ( indexing::TensorIndicesRef indices) const

Get a tensor by slicing on the batch dimensions.

◆ batch_index_put_()

template<class Derived >
void batch_index_put_ ( indexing::TensorIndicesRef indices,
const torch::Tensor & other )

Set values by slicing on the batch dimensions.

◆ batch_reshape()

template<class Derived >
Derived batch_reshape ( TensorShapeRef batch_shape) const

Reshape batch dimensions.

◆ batch_size()

template<class Derived >
Size batch_size ( Size index) const

Return the size of a batch axis.

◆ batch_sizes()

template<class Derived >
TensorShapeRef batch_sizes ( ) const

Return the batch size.

◆ batch_transpose()

template<class Derived >
Derived batch_transpose ( Size d1,
Size d2 ) const

Transpose two batch dimensions.

◆ batch_unsqueeze()

template<class Derived >
Derived batch_unsqueeze ( Size d) const

Unsqueeze a batch dimension.

◆ batched()

template<class Derived >
bool batched ( ) const

Whether the tensor is batched.

◆ clone()

template<class Derived >
Derived clone ( torch::MemoryFormat memory_format = torch::MemoryFormat::Contiguous) const

Clone (take ownership).

◆ detach()

template<class Derived >
Derived detach ( ) const

Discard function graph.

◆ empty_like()

template<class Derived >
Derived empty_like ( const Derived & other)
static

Empty tensor like another, i.e. same batch and base shapes, same tensor options, etc.

◆ full_like()

template<class Derived >
Derived full_like ( const Derived & other,
Real init )
static

Full tensor like another, i.e. same batch and base shapes, same tensor options, etc., but filled with the value init.
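The *_like named constructors can be pictured with a toy tensor type that only stores a shape, the number of batch dimensions, and a flat buffer. ToyTensor and these free functions are hypothetical illustrations, not NEML2 code, and tensor options such as dtype and device are not modeled. In this picture zeros_like and ones_like are simply full_like with init equal to 0 and 1.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Size = std::int64_t;          // assumed to mirror NEML2's Size
using Shape = std::vector<Size>;

// Toy stand-in for a batched tensor (hypothetical).
struct ToyTensor {
  Shape shape;                      // full shape: batch sizes, then base sizes
  Size batch_dim;                   // number of leading batch axes
  std::vector<double> data;         // flat storage
};

// Same shape and batch/base split as `other`, every entry set to `init`.
ToyTensor full_like(const ToyTensor & other, double init) {
  return {other.shape, other.batch_dim, std::vector<double>(other.data.size(), init)};
}

ToyTensor zeros_like(const ToyTensor & other) { return full_like(other, 0.0); }
ToyTensor ones_like(const ToyTensor & other) { return full_like(other, 1.0); }
```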

◆ linspace()

template<class Derived >
Derived linspace ( const Derived & start,
const Derived & end,
Size nstep,
Size dim = 0,
Size batch_dim = -1 )
static

Create a new tensor by adding a new batch dimension with linear spacing between start and end.

start and end must be broadcastable. The new batch dimension will be added at the user-specified dimension dim, which defaults to 0.

For example, if start has shape (3, 2; 5, 5) and end has shape (3, 1; 5, 5), where the semicolon separates the batch sizes from the base sizes, then

linspace(start, end, 100, 1);

returns a tensor of shape (3, 100, 2; 5, 5). Note the location of the new dimension (batch axis 1, as given by dim) and the broadcasting of the second batch dimension from 1 to 2.

Parameters
start: The starting tensor
end: The ending tensor
nstep: The number of steps with even spacing along the new dimension
dim: Where to insert the new dimension
batch_dim: Batch dimension of the output
Returns
Linearly spaced tensor
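The shape rule in the linspace example, and the values placed along the new axis, can be sketched as follows. The sketch assumes PyTorch broadcasting rules for the batch shapes of start and end, a non-negative dim, and the torch::linspace convention that both endpoints are included; broadcast, linspace_batch_shape, and linspace_values are hypothetical helpers that work on shapes and scalars only.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

using Size = std::int64_t;          // assumed to mirror NEML2's Size
using Shape = std::vector<Size>;

// Broadcast two shapes: right-aligned, sizes must match or one must be 1.
Shape broadcast(const Shape & a, const Shape & b) {
  const std::size_t n = std::max(a.size(), b.size());
  Shape out(n);
  for (std::size_t i = 0; i < n; ++i) {
    // Read from the right; missing leading axes behave like size 1.
    const Size ai = i < a.size() ? a[a.size() - 1 - i] : 1;
    const Size bi = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (ai != bi && ai != 1 && bi != 1)
      throw std::invalid_argument("shapes are not broadcastable");
    out[n - 1 - i] = ai == 1 ? bi : ai;
  }
  return out;
}

// Output batch shape: broadcast the batch shapes, then insert nstep at dim.
Shape linspace_batch_shape(const Shape & start, const Shape & end, Size nstep, Size dim) {
  Shape out = broadcast(start, end);
  out.insert(out.begin() + dim, nstep);
  return out;
}

// Scalar values along the new axis, both endpoints included.
std::vector<double> linspace_values(double start, double end, Size nstep) {
  std::vector<double> v;
  for (Size i = 0; i < nstep; ++i)
    v.push_back(nstep == 1 ? start
                           : start + (end - start) * double(i) / double(nstep - 1));
  return v;
}
```

Here linspace_batch_shape({3, 2}, {3, 1}, 100, 1) yields the batch shape (3, 100, 2) from the example, while the base shape (5, 5) is carried through unchanged.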

◆ logspace()

template<class Derived >
Derived logspace ( const Derived & start,
const Derived & end,
Size nstep,
Size dim = 0,
Size batch_dim = -1,
Real base = 10 )
static

Log-space equivalent of the linspace named constructor. The additional base argument (default 10) is the base of the logarithm.
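Assuming logspace follows the torch::logspace convention, each value is base raised to a linearly spaced exponent between start and end. The scalar sketch below shows that relation; logspace_values is a hypothetical helper, and its default base of 10 mirrors the documented default.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// base^x for nstep exponents x evenly spaced from start to end (inclusive).
std::vector<double> logspace_values(double start, double end, std::int64_t nstep, double base = 10) {
  std::vector<double> v;
  for (std::int64_t i = 0; i < nstep; ++i) {
    const double x = nstep == 1 ? start
                                : start + (end - start) * double(i) / double(nstep - 1);
    v.push_back(std::pow(base, x));
  }
  return v;
}
```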

◆ ones_like()

template<class Derived >
Derived ones_like ( const Derived & other)
static

Tensor filled with ones, like another, i.e. same batch and base shapes, same tensor options, etc.

◆ operator-()

template<class Derived >
Derived operator- ( ) const

Negation.

◆ to()

template<class Derived >
Derived to ( const torch::TensorOptions & options) const

Change tensor options.

◆ zeros_like()

template<class Derived >
Derived zeros_like ( const Derived & other)
static

Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.