NEML2 1.4.0
BatchTensorBase< Derived > Class Template Reference

NEML2's enhanced tensor type.

Detailed Description

template<class Derived>
class neml2::BatchTensorBase< Derived >

NEML2's enhanced tensor type.

BatchTensorBase derives from torch::Tensor and clearly distinguishes "batched" dimensions from the other dimensions. The shape of the batched dimensions is called the batch size, and the shape of the remaining dimensions is called the base size.
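The batch/base split can be pictured with a small stdlib-only sketch. This is not NEML2 code: `Size`, `Shape`, `BatchShape` and `make_batch_shape` are hypothetical names. The sketch assumes that the batch axes are always the leading axes, as in the `(3, 2; 5, 5)` notation used on this page, and that `batched()` means "at least one batch axis".

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-ins for NEML2's TorchSize and TorchShape aliases.
using Size = std::int64_t;
using Shape = std::vector<Size>;

// Illustrative model (not NEML2 code): the first `batch_dim` axes of `sizes`
// are batch axes; the remaining axes are base axes.
struct BatchShape
{
  Shape sizes;    // full shape, batch axes first
  Size batch_dim; // number of leading batch axes

  // Assumed meaning of batched(): at least one batch axis is present.
  bool batched() const { return batch_dim > 0; }

  // Number of base axes = total axes - batch axes.
  Size base_dim() const { return static_cast<Size>(sizes.size()) - batch_dim; }

  // Leading part of the shape.
  Shape batch_sizes() const { return Shape(sizes.begin(), sizes.begin() + batch_dim); }

  // Trailing part of the shape.
  Shape base_sizes() const { return Shape(sizes.begin() + batch_dim, sizes.end()); }
};

BatchShape make_batch_shape(Shape sizes, Size batch_dim)
{
  // Debug-build check: batch_dim cannot exceed the total number of axes.
  assert(batch_dim >= 0 && batch_dim <= static_cast<Size>(sizes.size()));
  return BatchShape{std::move(sizes), batch_dim};
}
```

For a shape written `(3, 2; 5, 5)`, `make_batch_shape({3, 2, 5, 5}, 2)` yields batch sizes `{3, 2}`, base sizes `{5, 5}`, and `base_dim() == 2`.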

#include <BatchTensorBase.h>


Public Member Functions

 BatchTensorBase ()=default
 Default constructor.
 
 BatchTensorBase (const torch::Tensor &tensor, TorchSize batch_dim)
 Construct from another torch::Tensor.
 
 BatchTensorBase (const Derived &tensor)
 Copy constructor.
 
 BatchTensorBase (Real)=delete
 Construction from a plain Real scalar is disabled.
 
bool batched () const
 Whether the tensor is batched.
 
TorchSize batch_dim () const
 Return the number of batch dimensions.
 
TorchSize & batch_dim ()
 Return a writable reference to the number of batch dimensions.
 
TorchSize base_dim () const
 Return the number of base dimensions.
 
TorchShapeRef batch_sizes () const
 Return the batch size.
 
TorchSize batch_size (TorchSize index) const
 Return the length of the batch axis at position index.
 
TorchShapeRef base_sizes () const
 Return the base size.
 
TorchSize base_size (TorchSize index) const
 Return the length of the base axis at position index.
 
TorchSize base_storage () const
 Return the flattened storage needed just for the base indices.
 
Derived batch_index (TorchSlice indices) const
 Return the slice selected by indices on the batch dimensions.
 
BatchTensor base_index (const TorchSlice &indices) const
 Return the slice selected by indices on the base dimensions.
 
void batch_index_put (TorchSlice indices, const torch::Tensor &other)
 Set the entries selected by indices on the batch dimensions to other.
 
void base_index_put (const TorchSlice &indices, const torch::Tensor &other)
 Set the entries selected by indices on the base dimensions to other.
 
Derived batch_expand (TorchShapeRef batch_size) const
 Return a new view of the tensor with values broadcast along the batch dimensions.
 
BatchTensor base_expand (TorchShapeRef base_size) const
 Return a new view of the tensor with values broadcast along the base dimensions.
 
template<class Derived2 >
Derived batch_expand_as (const Derived2 &other) const
 Expand the batch to have the same shape as another tensor.
 
template<class Derived2 >
Derived2 base_expand_as (const Derived2 &other) const
 Expand the base to have the same shape as another tensor.
 
Derived batch_expand_copy (TorchShapeRef batch_size) const
 Return a new tensor with values broadcast along the batch dimensions.
 
BatchTensor base_expand_copy (TorchShapeRef base_size) const
 Return a new tensor with values broadcast along the base dimensions.
 
Derived batch_reshape (TorchShapeRef batch_shape) const
 Reshape batch dimensions.
 
BatchTensor base_reshape (TorchShapeRef base_shape) const
 Reshape base dimensions.
 
Derived batch_unsqueeze (TorchSize d) const
 Unsqueeze a batch dimension.
 
Derived list_unsqueeze () const
 Unsqueeze on the special list batch dimension.
 
BatchTensor base_unsqueeze (TorchSize d) const
 Unsqueeze a base dimension.
 
Derived batch_transpose (TorchSize d1, TorchSize d2) const
 Transpose two batch dimensions.
 
BatchTensor base_transpose (TorchSize d1, TorchSize d2) const
 Transpose two base dimensions.
 
BatchTensor base_movedim (TorchSize d1, TorchSize d2) const
 Move base dimension d1 to position d2.
 
Derived clone (torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
 Clone, i.e. return a deep copy that owns its own storage.
 
Derived detach () const
 Return a tensor detached from the autograd function graph.
 
Derived to (const torch::TensorOptions &options) const
 Convert to the given tensor options (e.g. dtype or device).
 
Derived operator- () const
 Negation.
 
Derived batch_sum (TorchSize d) const
 Sum over a batch dimension.
 
Derived list_sum () const
 Sum over the list index (TODO: replace with class)
 

Static Public Member Functions

static Derived empty_like (const Derived &other)
 Empty tensor like another, i.e. same batch and base shapes, same tensor options, etc.
 
static Derived zeros_like (const Derived &other)
 Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
 
static Derived ones_like (const Derived &other)
 Tensor of ones like another, i.e. same batch and base shapes, same tensor options, etc.
 
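static Derived full_like (const Derived &other, Real init)
 Full tensor like another, i.e. same batch and base shapes, same tensor options, etc., but filled with the value init.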
 
static Derived linspace (const Derived &start, const Derived &end, TorchSize nstep, TorchSize dim=0, TorchSize batch_dim=-1)
 Create a new tensor by adding a new batch dimension with linear spacing between start and end.
 
static Derived logspace (const Derived &start, const Derived &end, TorchSize nstep, TorchSize dim=0, TorchSize batch_dim=-1, Real base=10)
 Log-space equivalent of the linspace named constructor.
 

Constructor & Destructor Documentation

◆ BatchTensorBase() [1/4]

template<class Derived >
BatchTensorBase ( )
default

Default constructor.

◆ BatchTensorBase() [2/4]

template<class Derived >
BatchTensorBase ( const torch::Tensor & tensor,
TorchSize batch_dim )

Construct from another torch::Tensor.

◆ BatchTensorBase() [3/4]

template<class Derived >
BatchTensorBase ( const Derived & tensor)

Copy constructor.

◆ BatchTensorBase() [4/4]

template<class Derived >
BatchTensorBase ( Real )
delete

Construction from a plain Real scalar is disabled.

Member Function Documentation

◆ base_dim()

template<class Derived >
TorchSize base_dim ( ) const

Return the number of base dimensions.

◆ base_expand()

template<class Derived >
BatchTensor base_expand ( TorchShapeRef base_size) const

Return a new view of the tensor with values broadcast along the base dimensions.

◆ base_expand_as()

template<class Derived >
template<class Derived2 >
Derived2 base_expand_as ( const Derived2 & other) const

Expand the base to have the same shape as another tensor.

◆ base_expand_copy()

template<class Derived >
BatchTensor base_expand_copy ( TorchShapeRef base_size) const

Return a new tensor with values broadcast along the base dimensions.

◆ base_index()

template<class Derived >
BatchTensor base_index ( const TorchSlice & indices) const

Return the slice selected by indices on the base dimensions.

◆ base_index_put()

template<class Derived >
void base_index_put ( const TorchSlice & indices,
const torch::Tensor & other )

Set the entries selected by indices on the base dimensions to other.

◆ base_movedim()

template<class Derived >
BatchTensor base_movedim ( TorchSize d1,
TorchSize d2 ) const

Move base dimension d1 to position d2.

◆ base_reshape()

template<class Derived >
BatchTensor base_reshape ( TorchShapeRef base_shape) const

Reshape base dimensions.

◆ base_size()

template<class Derived >
TorchSize base_size ( TorchSize index) const

Return the length of the base axis at position index.

◆ base_sizes()

template<class Derived >
TorchShapeRef base_sizes ( ) const

Return the base size.

◆ base_storage()

template<class Derived >
TorchSize base_storage ( ) const

Return the flattened storage needed just for the base indices.
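One reading of "flattened storage needed just for the base indices" is the number of scalars held by a single batch entry, i.e. the product of the base sizes. The stdlib-only sketch below follows that interpretation. It is not NEML2's implementation, and `Size`, `Shape` and this free-standing `base_storage` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical stand-ins for NEML2's TorchSize and TorchShape aliases.
using Size = std::int64_t;
using Shape = std::vector<Size>;

// Product of the base sizes (the axes after the first `batch_dim` ones).
// An empty base, i.e. a batched scalar, gives the empty product, 1.
Size base_storage(const Shape & sizes, Size batch_dim)
{
  // Debug-build check on the batch/base split.
  assert(batch_dim >= 0 && batch_dim <= static_cast<Size>(sizes.size()));
  return std::accumulate(sizes.begin() + batch_dim, sizes.end(), Size{1}, std::multiplies<Size>());
}
```

For example, a tensor of shape `(3, 2; 5, 5)` needs 25 scalars per batch entry.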

◆ base_transpose()

template<class Derived >
BatchTensor base_transpose ( TorchSize d1,
TorchSize d2 ) const

Transpose two base dimensions.

◆ base_unsqueeze()

template<class Derived >
BatchTensor base_unsqueeze ( TorchSize d) const

Unsqueeze a base dimension.

◆ batch_dim() [1/2]

template<class Derived >
TorchSize & batch_dim ( )

Return a writable reference to the number of batch dimensions.

◆ batch_dim() [2/2]

template<class Derived >
TorchSize batch_dim ( ) const

Return the number of batch dimensions.

◆ batch_expand()

template<class Derived >
Derived batch_expand ( TorchShapeRef batch_size) const

Return a new view of the tensor with values broadcast along the batch dimensions.
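The shape produced by batch_expand can be sketched without torch. This sketch assumes that the batch axes follow torch::Tensor::expand rules: axes are aligned from the right, an existing batch axis must be 1 or already equal to its target, extra leading batch axes may be added, and the base sizes are left untouched. `batch_expand_shape` is a hypothetical helper, and `-1` placeholders are not modeled. Only the shape is computed. batch_expand would return a view with this shape, while batch_expand_copy would return the same values in fresh storage.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for NEML2's TorchSize and TorchShape aliases.
using Size = std::int64_t;
using Shape = std::vector<Size>;

// Full shape after expanding the batch part of `sizes` to `new_batch`.
Shape batch_expand_shape(const Shape & sizes, Size batch_dim, const Shape & new_batch)
{
  const Size n_new = static_cast<Size>(new_batch.size());
  assert(batch_dim >= 0 && batch_dim <= static_cast<Size>(sizes.size()));
  // Expansion can add leading batch axes but never remove them.
  assert(n_new >= batch_dim);

  // Right-align the old batch axes against the new ones; each must be 1 or match.
  const Size offset = n_new - batch_dim;
  for (Size i = 0; i < batch_dim; ++i)
    assert(sizes[i] == 1 || sizes[i] == new_batch[offset + i]);

  // New batch sizes followed by the unchanged base sizes.
  Shape out = new_batch;
  out.insert(out.end(), sizes.begin() + batch_dim, sizes.end());
  return out;
}
```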

◆ batch_expand_as()

template<class Derived >
template<class Derived2 >
Derived batch_expand_as ( const Derived2 & other) const

Expand the batch to have the same shape as another tensor.

◆ batch_expand_copy()

template<class Derived >
Derived batch_expand_copy ( TorchShapeRef batch_size) const

Return a new tensor with values broadcast along the batch dimensions.

◆ batch_index()

template<class Derived >
Derived batch_index ( TorchSlice indices) const

Return the slice selected by indices on the batch dimensions.

◆ batch_index_put()

template<class Derived >
void batch_index_put ( TorchSlice indices,
const torch::Tensor & other )

Set the entries selected by indices on the batch dimensions to other.

◆ batch_reshape()

template<class Derived >
Derived batch_reshape ( TorchShapeRef batch_shape) const

Reshape batch dimensions.
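Reshaping only the batch dimensions can be sketched at the shape level. The sketch assumes that the total number of batch entries must be preserved and that the base sizes are carried over unchanged. `batch_reshape_shape` is a hypothetical helper, not the NEML2 member, and `-1` size inference is not modeled.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical stand-ins for NEML2's TorchSize and TorchShape aliases.
using Size = std::int64_t;
using Shape = std::vector<Size>;

// Full shape after reshaping the batch axes of `sizes` to `new_batch`.
Shape batch_reshape_shape(const Shape & sizes, Size batch_dim, const Shape & new_batch)
{
  assert(batch_dim >= 0 && batch_dim <= static_cast<Size>(sizes.size()));

  // The number of batch entries before and after must agree.
  const Size old_count =
      std::accumulate(sizes.begin(), sizes.begin() + batch_dim, Size{1}, std::multiplies<Size>());
  const Size new_count =
      std::accumulate(new_batch.begin(), new_batch.end(), Size{1}, std::multiplies<Size>());
  assert(old_count == new_count);

  // New batch sizes followed by the unchanged base sizes.
  Shape out = new_batch;
  out.insert(out.end(), sizes.begin() + batch_dim, sizes.end());
  return out;
}
```

For example, flattening a `(6, 4; 3, 3)` batch into a single axis gives `(24; 3, 3)`.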

◆ batch_size()

template<class Derived >
TorchSize batch_size ( TorchSize index) const

Return the length of the batch axis at position index.

◆ batch_sizes()

template<class Derived >
TorchShapeRef batch_sizes ( ) const

Return the batch size.

◆ batch_sum()

template<class Derived >
Derived batch_sum ( TorchSize d) const

Sum over a batch dimension.
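Summing over a batch dimension removes that axis, so the result has one fewer batch dimension and the same base shape. That is the assumed behavior of batch_sum(d) for a non-negative d that counts batch axes from 0. The sketch below applies it to a plain row-major array. `Batched` and this `batch_sum` are hypothetical stand-ins, not the NEML2 types.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for NEML2's TorchSize and TorchShape aliases.
using Size = std::int64_t;
using Shape = std::vector<Size>;

// A row-major array together with its batch/base split.
struct Batched
{
  std::vector<double> data;
  Shape sizes;
  Size batch_dim;
};

// Sum over batch axis d (0 <= d < batch_dim), dropping that axis.
Batched batch_sum(const Batched & x, Size d)
{
  assert(d >= 0 && d < x.batch_dim);

  // Split the shape into the axes before d, axis d itself, and the axes after d.
  Size outer = 1, inner = 1;
  for (Size i = 0; i < d; ++i)
    outer *= x.sizes[i];
  for (Size i = d + 1; i < static_cast<Size>(x.sizes.size()); ++i)
    inner *= x.sizes[i];
  const Size n = x.sizes[d];

  Batched out{std::vector<double>(outer * inner, 0.0), x.sizes, x.batch_dim - 1};
  out.sizes.erase(out.sizes.begin() + d);

  // Row-major layout: element (o, k, i) sits at (o * n + k) * inner + i.
  for (Size o = 0; o < outer; ++o)
    for (Size k = 0; k < n; ++k)
      for (Size i = 0; i < inner; ++i)
        out.data[o * inner + i] += x.data[(o * n + k) * inner + i];
  return out;
}
```

With batch shape `(2)` and base shape `(3)`, summing over batch axis 0 adds the two rows and leaves an unbatched tensor of base shape `(3)`.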

◆ batch_transpose()

template<class Derived >
Derived batch_transpose ( TorchSize d1,
TorchSize d2 ) const

Transpose two batch dimensions.

◆ batch_unsqueeze()

template<class Derived >
Derived batch_unsqueeze ( TorchSize d) const

Unsqueeze a batch dimension.

◆ batched()

template<class Derived >
bool batched ( ) const

Whether the tensor is batched.

◆ clone()

template<class Derived >
Derived clone ( torch::MemoryFormat memory_format = torch::MemoryFormat::Contiguous) const

Clone, i.e. return a deep copy that owns its own storage.

◆ detach()

template<class Derived >
Derived detach ( ) const

Return a tensor detached from the autograd function graph.

◆ empty_like()

template<class Derived >
Derived empty_like ( const Derived & other)
static

Empty tensor like another, i.e. same batch and base shapes, same tensor options, etc.

◆ full_like()

template<class Derived >
Derived full_like ( const Derived & other,
Real init )
static

Full tensor like another, i.e. same batch and base shapes, same tensor options, etc., but filled with the value init.

◆ linspace()

template<class Derived >
Derived linspace ( const Derived & start,
const Derived & end,
TorchSize nstep,
TorchSize dim = 0,
TorchSize batch_dim = -1 )
static

Create a new tensor by adding a new batch dimension with linear spacing between start and end.

start and end must be broadcastable. The new batch dimension is inserted at the user-specified dimension dim, which defaults to 0.

For example, if start has shape (3, 2; 5, 5) and end has shape (3, 1; 5, 5), where the semicolon separates the batch sizes from the base sizes, then

linspace(start, end, 100, 1);

will have shape (3, 100, 2; 5, 5). Note the location of the new dimension and the broadcasting of the second batch axis from 1 to 2.

Parameters
start: The starting tensor
end: The ending tensor
nstep: The number of steps with even spacing along the new dimension
dim: Where to insert the new dimension
batch_dim: Batch dimension of the output
Returns
BatchTensor Linearly spaced tensor
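The shape rule in the example above can be checked with a small stdlib-only sketch. Full shapes are broadcast right-aligned, as in torch and NumPy, and then a new axis of length nstep is inserted at position dim. The sketch also computes scalar values along the new axis, assuming that both endpoints are included as in torch::linspace. The helper names are hypothetical, and the batch_dim argument of the real constructor is not modeled.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for NEML2's TorchSize and TorchShape aliases.
using Size = std::int64_t;
using Shape = std::vector<Size>;

// Right-aligned broadcasting of two full shapes (same rule as torch/NumPy).
Shape broadcast_shapes(const Shape & a, const Shape & b)
{
  const std::size_t n = std::max(a.size(), b.size());
  Shape out(n, 1);
  for (std::size_t i = 0; i < n; ++i)
  {
    // Walk both shapes from the right; missing leading axes count as size 1.
    const Size ai = i < a.size() ? a[a.size() - 1 - i] : 1;
    const Size bi = i < b.size() ? b[b.size() - 1 - i] : 1;
    assert(ai == bi || ai == 1 || bi == 1);
    out[n - 1 - i] = (ai == 1) ? bi : ai;
  }
  return out;
}

// Output shape: broadcast start and end, then insert nstep at position dim.
Shape linspace_shape(const Shape & start, const Shape & end, Size nstep, Size dim)
{
  Shape out = broadcast_shapes(start, end);
  assert(dim >= 0 && dim <= static_cast<Size>(out.size()));
  out.insert(out.begin() + dim, nstep);
  return out;
}

// Scalar values along the new axis: start + (end - start) * i / (nstep - 1).
std::vector<double> linspace_values(double start, double end, Size nstep)
{
  assert(nstep >= 2);
  std::vector<double> v;
  for (Size i = 0; i < nstep; ++i)
    v.push_back(start + (end - start) * static_cast<double>(i) / static_cast<double>(nstep - 1));
  return v;
}
```

`linspace_shape({3, 2, 5, 5}, {3, 1, 5, 5}, 100, 1)` reproduces the documented result `(3, 100, 2; 5, 5)`.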

◆ list_sum()

template<class Derived >
Derived list_sum ( ) const

Sum over the list index (TODO: replace with class)

◆ list_unsqueeze()

template<class Derived >
Derived list_unsqueeze ( ) const

Unsqueeze on the special list batch dimension.

◆ logspace()

template<class Derived >
Derived logspace ( const Derived & start,
const Derived & end,
TorchSize nstep,
TorchSize dim = 0,
TorchSize batch_dim = -1,
Real base = 10 )
static

Log-space equivalent of the linspace named constructor.
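Here is a scalar sketch of what the log-space variant is assumed to compute. It raises base to each of nstep linearly spaced exponents between start and end, with both endpoints included, as torch::logspace does. This reading of the NEML2 constructor is an assumption, and `logspace_values` is a hypothetical helper.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for NEML2's TorchSize alias.
using Size = std::int64_t;

// base^e for each of nstep evenly spaced exponents e in [start, end].
std::vector<double> logspace_values(double start, double end, Size nstep, double base = 10.0)
{
  assert(nstep >= 2);
  std::vector<double> v;
  for (Size i = 0; i < nstep; ++i)
  {
    // Same linear spacing as linspace, applied to the exponent.
    const double exponent =
        start + (end - start) * static_cast<double>(i) / static_cast<double>(nstep - 1);
    v.push_back(std::pow(base, exponent));
  }
  return v;
}
```

With the default base of 10, the exponents 0 to 2 in three steps give approximately 1, 10 and 100.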

◆ ones_like()

template<class Derived >
Derived ones_like ( const Derived & other)
static

Tensor of ones like another, i.e. same batch and base shapes, same tensor options, etc.

◆ operator-()

template<class Derived >
Derived operator- ( ) const

Negation.

◆ to()

template<class Derived >
Derived to ( const torch::TensorOptions & options) const

Convert to the given tensor options (e.g. dtype or device).

◆ zeros_like()

template<class Derived >
Derived zeros_like ( const Derived & other)
static

Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.