NEML2 1.4.0
LabeledMatrix Class Reference

A single-batched, logically 2D LabeledTensor.

Detailed Description

#include <LabeledMatrix.h>
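What distinguishes a LabeledMatrix from a plain 2D tensor is that each of its two base axes is described by a LabeledAxis, so a sub-block can be addressed by a pair of labels instead of raw index ranges. The sketch below is a self-contained toy in plain C++ (no libtorch, no NEML2). ToyAxis, ToyLabeledMatrix and block() are hypothetical names, and the assumption that every label owns one contiguous index range is made for illustration only.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy axis (not NEML2's LabeledAxis): each label is assumed
// to own one contiguous, half-open range of indices along the axis.
class ToyAxis
{
public:
  // Append a label that occupies `size` consecutive indices.
  void add(const std::string & label, std::size_t size)
  {
    assert(_ranges.count(label) == 0 && "labels must be unique");
    _ranges[label] = {_storage, _storage + size};
    _storage += size;
  }

  // Index range [first, second) owned by `label`; throws if absent.
  std::pair<std::size_t, std::size_t> range(const std::string & label) const
  {
    return _ranges.at(label);
  }

  // Total number of indices along this axis.
  std::size_t storage_size() const { return _storage; }

private:
  std::map<std::string, std::pair<std::size_t, std::size_t>> _ranges;
  std::size_t _storage = 0;
};

// Hypothetical unbatched matrix whose rows and columns both carry labels.
struct ToyLabeledMatrix
{
  const ToyAxis * rows;
  const ToyAxis * cols;
  std::vector<double> data; // row-major, rows->storage_size() x cols->storage_size()

  // Copy out the block addressed by (row label, column label), i.e. the
  // idea behind slicing the base dimensions by labels.
  std::vector<double> block(const std::string & r, const std::string & c) const
  {
    const auto [r0, r1] = rows->range(r);
    const auto [c0, c1] = cols->range(c);
    std::vector<double> out;
    for (std::size_t i = r0; i < r1; ++i)
      for (std::size_t j = c0; j < c1; ++j)
        out.push_back(data[i * cols->storage_size() + j]);
    return out;
  }
};
```

With an axis holding "a" (2 entries) followed by "b" (1 entry), block("a", "b") of a 3 x 3 row-major matrix returns the last column of the first two rows.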

Inheritance: LabeledMatrix derives from LabeledTensor< LabeledMatrix, 2 >.

Public Member Functions

void fill (const LabeledMatrix &other, bool recursive=true)
 Fill another matrix into this matrix.
 
LabeledMatrix chain (const LabeledMatrix &other) const
 Chain rule product of two derivatives.
 
- Public Member Functions inherited from LabeledTensor< LabeledMatrix, 2 >
 LabeledTensor ()=default
 Default constructor.
 
 LabeledTensor (const torch::Tensor &tensor, const std::array< const LabeledAxis *, D > &axes)
 Construct from a torch::Tensor and an array of LabeledAxis.
 
 LabeledTensor (const Tensor &tensor, const std::array< const LabeledAxis *, D > &axes)
 Construct from a Tensor and an array of LabeledAxis.
 
 LabeledTensor (const LabeledMatrix &other)
 Copy constructor.
 
LabeledTensor< LabeledMatrix, D > & operator= (const LabeledMatrix &other)
 Assignment operator.
 
reinterpret (indexing::TensorLabelsRef indices) const
 Get a tensor by slicing on the base dimensions AND reinterpret it as a primitive tensor.
 
const std::array< const LabeledAxis *, D > & axes () const
 Get all the labeled axes.
 
const LabeledAxis & axis (Size i=0) const
 Get a specific labeled axis.
 
 operator Tensor () const
 
 operator torch::Tensor () const
 
const Tensor & tensor () const
 Get the underlying tensor.
 
Tensor & tensor ()
 Get the underlying tensor (mutable access).
 
LabeledMatrix clone (torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
 
LabeledMatrix detach () const
 Return a copy without gradient graphs.
 
void detach_ ()
 Detach from gradient graphs.
 
LabeledMatrix to (const torch::TensorOptions &options) const
 Change tensor options.
 
void copy_ (const torch::Tensor &other)
 Copy another tensor.
 
void zero_ ()
 Set all entries to zero.
 
bool requires_grad () const
 Get the requires_grad property.
 
void requires_grad_ (bool req=true)
 Set the requires_grad property.
 
LabeledMatrix operator- () const
 Negation.
 
torch::TensorOptions options () const
 Tensor options.
 
torch::Dtype scalar_type () const
 Scalar type (dtype) of the tensor.
 
torch::Device device () const
 Device on which the tensor is stored.
 
Size dim () const
 Number of tensor dimensions.
 
TensorShapeRef sizes () const
 Tensor shape.
 
Size size (Size dim) const
 Length of the tensor along dimension dim.
 
bool batched () const
 Whether the tensor is batched.
 
Size batch_dim () const
 Return the number of batch dimensions.
 
constexpr Size base_dim () const
 Return the number of base dimensions.
 
TensorShapeRef batch_sizes () const
 Return the batch shape.
 
Size batch_size (Size d) const
 Return the length of some batch axis.
 
TensorShapeRef base_sizes () const
 Return the base shape.
 
Size base_size (Size d) const
 Return the length of some base axis.
 
Size base_storage () const
 Return the flattened storage needed just for the base indices.
 
LabeledMatrix batch_index (indexing::TensorIndicesRef indices) const
 Get a LabeledMatrix by slicing on the batch dimensions.
 
Tensor base_index (indexing::TensorLabelsRef labels) const
 Get a tensor by slicing on the base dimensions.
 
void batch_index_put_ (indexing::TensorIndicesRef indices, const torch::Tensor &other)
 Set values by slicing on the batch dimensions.
 
void base_index_put_ (indexing::TensorLabelsRef labels, const Tensor &other)
 Set values by slicing on the base dimensions.
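
The shape queries above split the tensor shape into leading batch dimensions and trailing base dimensions. For LabeledMatrix, base_dim() is the compile-time constant 2 (the D of LabeledTensor< LabeledMatrix, 2 >). Assuming batch_dim() equals dim() - base_dim() and base_storage() is the product of the base sizes, the bookkeeping can be mimicked as below. ShapeInfo is a hypothetical helper, not part of NEML2, and modeling batched() as batch_dim() > 0 is an assumption.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

using Size = std::int64_t;       // assumed stand-in for neml2::Size
using Shape = std::vector<Size>; // stand-in for TensorShapeRef

// Hypothetical helper mimicking the shape queries of a LabeledMatrix:
// the last two dimensions are base (row, column); the rest are batch.
struct ShapeInfo
{
  static constexpr Size base_dim = 2; // D in LabeledTensor< LabeledMatrix, 2 >
  Shape sizes;                        // full shape: batch dims, then base dims

  Size dim() const { return static_cast<Size>(sizes.size()); }

  Size batch_dim() const
  {
    assert(dim() >= base_dim && "need at least the two base dimensions");
    return dim() - base_dim;
  }

  bool batched() const { return batch_dim() > 0; }

  Shape batch_sizes() const { return Shape(sizes.begin(), sizes.begin() + batch_dim()); }
  Shape base_sizes() const { return Shape(sizes.begin() + batch_dim(), sizes.end()); }

  // Flattened storage of one base block: the product of the base sizes.
  Size base_storage() const
  {
    const Shape b = base_sizes();
    return std::accumulate(b.begin(), b.end(), Size{1}, std::multiplies<Size>());
  }
};
```

With sizes {8, 3, 4} this reports one batch dimension of length 8 and a base block of 3 x 4 = 12 entries.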
 

Static Public Member Functions

static LabeledMatrix identity (TensorShapeRef batch_size, const LabeledAxis &axis, const torch::TensorOptions &options=default_tensor_options())
 Create a labeled identity tensor.
 
- Static Public Member Functions inherited from LabeledTensor< LabeledMatrix, 2 >
static LabeledMatrix empty (TensorShapeRef batch_shape, const std::array< const LabeledAxis *, D > &axes, const torch::TensorOptions &options=default_tensor_options())
 Setup new empty storage.
 
static LabeledMatrix zeros (TensorShapeRef batch_shape, const std::array< const LabeledAxis *, D > &axes, const torch::TensorOptions &options=default_tensor_options())
 Setup new storage with zeros.
 

Additional Inherited Members

- Protected Attributes inherited from LabeledTensor< LabeledMatrix, 2 >
Tensor _tensor
 The tensor.
 
std::array< const LabeledAxis *, D > _axes
 The labeled axes of this tensor.
 

Member Function Documentation

◆ chain()

LabeledMatrix chain ( const LabeledMatrix & other) const

Chain rule product of two derivatives.
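For derivatives stored as matrices, the chain rule composes Jacobians: if A holds dy/dx and B holds dx/dz, then dy/dz = A B, which requires A's column axis to match B's row axis. The sketch below assumes that a.chain(b) computes a times b in that order and leaves out batching and labels; Jacobian and chain() here are toy stand-ins operating on plain row-major std::vector<double> storage, not the NEML2 implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy Jacobian: `rows` outputs by `cols` inputs, row-major.
struct Jacobian
{
  std::size_t rows, cols;
  std::vector<double> data;

  double operator()(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// Chain rule: with A = dy/dx and B = dx/dz, dy/dz = A * B.
// The inner dimensions must agree; in the labeled setting this would be
// A's column axis matching B's row axis.
Jacobian chain(const Jacobian & A, const Jacobian & B)
{
  assert(A.cols == B.rows && "inner axes must match");
  Jacobian C{A.rows, B.cols, std::vector<double>(A.rows * B.cols, 0.0)};
  for (std::size_t i = 0; i < A.rows; ++i)
    for (std::size_t k = 0; k < A.cols; ++k)
      for (std::size_t j = 0; j < B.cols; ++j)
        C.data[i * C.cols + j] += A(i, k) * B(k, j);
  return C;
}
```

For example, chaining the 2 x 2 Jacobian [[1, 2], [3, 4]] with the 2 x 1 Jacobian [[5], [6]] gives [[17], [39]].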

◆ fill()

void fill ( const LabeledMatrix & other,
bool recursive = true )

Fill another matrix into this matrix. The item set of the other matrix must be a subset of this matrix's item set.
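Filling can be pictured block by block: each (row label, column label) block of the source is copied into the block carrying the same labels in the destination, which is why the source's labels must be a subset of the destination's. The toy below assumes each label owns a contiguous index range of the same size on both sides, and it does not model the recursive flag, whose effect is not described on this page. ToyAxis, ToyMatrix and offset() are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy axis: ordered (label, size) items; each label owns a
// contiguous index range starting right after the previous label's range.
struct ToyAxis
{
  std::vector<std::pair<std::string, std::size_t>> items;

  std::size_t storage_size() const
  {
    std::size_t n = 0;
    for (const auto & item : items)
      n += item.second;
    return n;
  }

  // Write the starting index of `label` into `off`; return false if absent.
  bool offset(const std::string & label, std::size_t & off) const
  {
    off = 0;
    for (const auto & item : items)
    {
      if (item.first == label)
        return true;
      off += item.second;
    }
    return false;
  }
};

// Hypothetical unbatched labeled matrix with row-major storage.
struct ToyMatrix
{
  ToyAxis rows, cols;
  std::vector<double> data;

  double & at(std::size_t i, std::size_t j) { return data[i * cols.storage_size() + j]; }
  double at(std::size_t i, std::size_t j) const { return data[i * cols.storage_size() + j]; }

  // Copy every (row label, column label) block of `other` into the block of
  // *this with the same labels. Assumes other's labels are a subset of ours
  // and that shared labels have the same size in both matrices.
  void fill(const ToyMatrix & other)
  {
    std::size_t src_r = 0;
    for (const auto & [rlabel, rsize] : other.rows.items)
    {
      std::size_t dst_r = 0;
      [[maybe_unused]] const bool has_row = rows.offset(rlabel, dst_r);
      assert(has_row && "row label not found in destination");
      std::size_t src_c = 0;
      for (const auto & [clabel, csize] : other.cols.items)
      {
        std::size_t dst_c = 0;
        [[maybe_unused]] const bool has_col = cols.offset(clabel, dst_c);
        assert(has_col && "column label not found in destination");
        for (std::size_t i = 0; i < rsize; ++i)
          for (std::size_t j = 0; j < csize; ++j)
            at(dst_r + i, dst_c + j) = other.at(src_r + i, src_c + j);
        src_c += csize;
      }
      src_r += rsize;
    }
  }
};
```

Entries of the destination whose labels do not appear in the source are left untouched.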

◆ identity()

LabeledMatrix identity ( TensorShapeRef batch_size,
const LabeledAxis & axis,
const torch::TensorOptions & options = default_tensor_options() )
static

Create a labeled identity tensor.
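The signature takes a single LabeledAxis, so the identity is square: rows and columns share the same labels. The sketch below shows the resulting values under the assumption that the identity is replicated over every batch entry. It uses a hypothetical batched_identity() helper that works on the axis's total storage size n instead of a real LabeledAxis, and returns plain row-major values with shape batch_size followed by (n, n).

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical helper: identity of size n x n, repeated for every entry of
// the batch shape. An empty batch shape yields a single (unbatched) identity.
std::vector<double> batched_identity(const std::vector<std::size_t> & batch_shape, std::size_t n)
{
  const std::size_t nbatch = std::accumulate(
      batch_shape.begin(), batch_shape.end(), std::size_t{1}, std::multiplies<std::size_t>());
  std::vector<double> out(nbatch * n * n, 0.0);
  for (std::size_t b = 0; b < nbatch; ++b)
    for (std::size_t i = 0; i < n; ++i)
      out[b * n * n + i * n + i] = 1.0; // diagonal entry (i, i) of batch item b
  return out;
}
```

For a batch shape {2} and an axis of total size 3, the result holds two stacked 3 x 3 identities (18 values).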