NEML2 1.4.0
LabeledTensor3D Class Reference

Detailed Description

A single-batched, logically 3D LabeledTensor.

#include <LabeledTensor3D.h>

Public Member Functions

void fill (const LabeledTensor3D &other, bool recursive=true)
 
LabeledTensor3D chain (const LabeledTensor3D &other, const LabeledMatrix &dself, const LabeledMatrix &dother) const
 Second-order chain rule product of two derivatives.
 
- Public Member Functions inherited from LabeledTensor< LabeledTensor3D, 3 >
 LabeledTensor ()=default
 Default constructor.
 
 LabeledTensor (const torch::Tensor &tensor, const std::array< const LabeledAxis *, D > &axes)
 Construct from a torch::Tensor and an array of LabeledAxis pointers.
 
 LabeledTensor (const Tensor &tensor, const std::array< const LabeledAxis *, D > &axes)
 Construct from a Tensor and an array of LabeledAxis pointers.
 
 LabeledTensor (const LabeledTensor3D &other)
 Copy constructor.
 
LabeledTensor< LabeledTensor3D, D > & operator= (const LabeledTensor3D &other)
 Assignment operator.
 
reinterpret (indexing::TensorLabelsRef indices) const
 Get a tensor by slicing on the base dimensions AND reinterpret it as a primitive tensor.
 
const std::array< const LabeledAxis *, D > & axes () const
 Get all the labeled axes.
 
const LabeledAxis & axis (Size i=0) const
 Get a specific labeled axis.
 
 operator Tensor () const
 
 operator torch::Tensor () const
 
const Tensor & tensor () const
 
Tensor & tensor ()
 
LabeledTensor3D clone (torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
 
LabeledTensor3D detach () const
 Return a copy without gradient graphs.
 
void detach_ ()
 Detach from gradient graphs.
 
LabeledTensor3D to (const torch::TensorOptions &options) const
 Change tensor options.
 
void copy_ (const torch::Tensor &other)
 Copy another tensor.
 
void zero_ ()
 Set all entries to zero.
 
bool requires_grad () const
 Get the requires_grad property.
 
void requires_grad_ (bool req=true)
 Set the requires_grad property.
 
LabeledTensor3D operator- () const
 Negation.
 
torch::TensorOptions options () const
 
torch::Dtype scalar_type () const
 Scalar type (dtype) of the underlying tensor.
 
torch::Device device () const
 Device of the underlying tensor.
 
Size dim () const
 Number of tensor dimensions.
 
TensorShapeRef sizes () const
 Tensor shape.
 
Size size (Size dim) const
 Length of the given tensor dimension.
 
bool batched () const
 Whether the tensor is batched.
 
Size batch_dim () const
 Return the number of batch dimensions.
 
constexpr Size base_dim () const
 Return the number of base dimensions.
 
TensorShapeRef batch_sizes () const
 Return the batch shape.
 
Size batch_size (Size d) const
 Return the length of some batch axis.
 
TensorShapeRef base_sizes () const
 Return the base shape.
 
Size base_size (Size d) const
 Return the length of some base axis.
 
Size base_storage () const
 Return the flattened storage needed just for the base indices.
 
LabeledTensor3D batch_index (indexing::TensorIndicesRef indices) const
 
Tensor base_index (indexing::TensorLabelsRef labels) const
 Get a tensor by slicing on the base dimensions.
 
void batch_index_put_ (indexing::TensorIndicesRef indices, const torch::Tensor &other)
 Set values by slicing on the batch dimensions.
 
void base_index_put_ (indexing::TensorLabelsRef labels, const Tensor &other)
 Set values by slicing on the base dimensions.
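For orientation, the batch/base shape queries above can be sketched as plain bookkeeping on a size vector. This is a minimal sketch, not NEML2 code: it assumes the layout (batch dimensions..., base dimensions...) with three trailing base dimensions for a LabeledTensor3D, and the helper names (`batch_dim_of`, `base_storage_of`, etc.) are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helpers (not NEML2 API) mirroring the shape queries listed above.
// Assumed layout: (batch dims..., base dims...) with 3 trailing base dims.
using Size = std::int64_t;
using Shape = std::vector<Size>;

constexpr Size kBaseDim = 3; // base_dim() of a logically 3D tensor

// batch_dim(): all leading dimensions that are not base dimensions
Size batch_dim_of(const Shape & sizes)
{
  assert(static_cast<Size>(sizes.size()) >= kBaseDim);
  return static_cast<Size>(sizes.size()) - kBaseDim;
}

// batch_sizes(): the leading batch_dim() extents
Shape batch_sizes_of(const Shape & sizes)
{
  return Shape(sizes.begin(), sizes.begin() + batch_dim_of(sizes));
}

// base_sizes(): the trailing kBaseDim extents
Shape base_sizes_of(const Shape & sizes)
{
  return Shape(sizes.begin() + batch_dim_of(sizes), sizes.end());
}

// base_storage(): entries per batch item, i.e. the product of the base sizes
Size base_storage_of(const Shape & sizes)
{
  Size n = 1;
  for (Size s : base_sizes_of(sizes))
    n *= s;
  return n;
}

// batched(): assumed here to mean "at least one batch dimension"
bool is_batched(const Shape & sizes) { return batch_dim_of(sizes) > 0; }
```

Under these assumptions, a tensor of shape (10, 4, 6, 6) has one batch dimension of length 10, base sizes (4, 6, 6), and a base storage of 4 × 6 × 6 = 144 entries per batch item.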
 

Additional Inherited Members

- Static Public Member Functions inherited from LabeledTensor< LabeledTensor3D, 3 >
static LabeledTensor3D empty (TensorShapeRef batch_shape, const std::array< const LabeledAxis *, D > &axes, const torch::TensorOptions &options=default_tensor_options())
 Set up new empty storage.
 
static LabeledTensor3D zeros (TensorShapeRef batch_shape, const std::array< const LabeledAxis *, D > &axes, const torch::TensorOptions &options=default_tensor_options())
 Set up new storage filled with zeros.
 
- Protected Attributes inherited from LabeledTensor< LabeledTensor3D, 3 >
Tensor _tensor
 The tensor.
 
std::array< const LabeledAxis *, D > _axes
 The labeled axes of this tensor.
 

Member Function Documentation

◆ chain()

LabeledTensor3D chain ( const LabeledTensor3D & other,
const LabeledMatrix & dself,
const LabeledMatrix & dother ) const

Second-order chain rule product of two derivatives.
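As an illustration, assume `self` holds ∂²f/∂y², `other` holds ∂²y/∂x², `dself` holds ∂f/∂y, and `dother` holds ∂y/∂x (these roles are inferred from the argument names, not stated on this page). The second-order chain rule then gives ∂²f/∂x² as R_ijk = Σ_pq self_ipq dother_pj dother_qk + Σ_p dself_ip other_pjk. The sketch below evaluates that identity on unbatched, unlabeled dense arrays; `Mat`, `Ten3`, and the free function `chain` are hypothetical stand-ins, not the library's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical dense containers (not NEML2 types): row-major and unbatched.
struct Mat // m(i, j)
{
  std::size_t n0, n1;
  std::vector<double> v;
  double operator()(std::size_t i, std::size_t j) const { return v[i * n1 + j]; }
};

struct Ten3 // t(i, j, k)
{
  std::size_t n0, n1, n2;
  std::vector<double> v;
  double & operator()(std::size_t i, std::size_t j, std::size_t k) { return v[(i * n1 + j) * n2 + k]; }
  double operator()(std::size_t i, std::size_t j, std::size_t k) const { return v[(i * n1 + j) * n2 + k]; }
};

// Second-order chain rule under the assumed argument roles:
//   self  = d2f/dy2 (nf x ny x ny),  other  = d2y/dx2 (ny x nx x nx),
//   dself = df/dy   (nf x ny),       dother = dy/dx   (ny x nx).
// Returns d2f/dx2 (nf x nx x nx):
//   R_ijk = sum_pq self_ipq dother_pj dother_qk + sum_p dself_ip other_pjk
Ten3 chain(const Ten3 & self, const Ten3 & other, const Mat & dself, const Mat & dother)
{
  const std::size_t nf = self.n0, ny = self.n1, nx = dother.n1;
  assert(self.n2 == ny && dself.n0 == nf && dself.n1 == ny);
  assert(dother.n0 == ny && other.n0 == ny && other.n1 == nx && other.n2 == nx);

  Ten3 r{nf, nx, nx, std::vector<double>(nf * nx * nx, 0.0)};
  for (std::size_t i = 0; i < nf; ++i)
    for (std::size_t j = 0; j < nx; ++j)
      for (std::size_t k = 0; k < nx; ++k)
      {
        double s = 0.0;
        // curvature of f pushed through the first derivative of y (twice)
        for (std::size_t p = 0; p < ny; ++p)
          for (std::size_t q = 0; q < ny; ++q)
            s += self(i, p, q) * dother(p, j) * dother(q, k);
        // first derivative of f times the curvature of y
        for (std::size_t p = 0; p < ny; ++p)
          s += dself(i, p) * other(p, j, k);
        r(i, j, k) = s;
      }
  return r;
}
```

As a scalar check, f(y) = y² and y(x) = x³ at x = 2 give self = 2, dother = 12, dself = 16, and other = 12, so the result is 2·12·12 + 16·12 = 480, which matches d²(x⁶)/dx² = 30x⁴ at x = 2.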

◆ fill()

void fill ( const LabeledTensor3D & other,
bool recursive = true )

Fill another tensor into this tensor. The item set of the other tensor must be a subset of this tensor's item set.
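The subset requirement of fill() can be illustrated with a simplified model in which each axis maps item labels to contiguous slices, and every labeled block of the source is copied into the block with the same labels in the destination. This is a hypothetical sketch (`Axis` and `Labeled3D` are not NEML2 types): it ignores batching, requires matching item sizes, and has no counterpart for the `recursive` flag.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a labeled axis (not NEML2's LabeledAxis):
// each item label owns a contiguous slice [offset, offset + size).
struct Axis
{
  std::map<std::string, std::pair<std::size_t, std::size_t>> items; // label -> (offset, size)
  std::size_t storage = 0;

  void add(const std::string & label, std::size_t size)
  {
    items[label] = {storage, size};
    storage += size;
  }
};

// Hypothetical unbatched labeled 3D tensor, stored row-major.
struct Labeled3D
{
  Axis a0, a1, a2;
  std::vector<double> v;

  Labeled3D(Axis x0, Axis x1, Axis x2)
    : a0(std::move(x0)), a1(std::move(x1)), a2(std::move(x2)),
      v(a0.storage * a1.storage * a2.storage, 0.0)
  {
  }

  std::size_t flat(std::size_t i, std::size_t j, std::size_t k) const
  {
    return (i * a1.storage + j) * a2.storage + k;
  }

  // Copy every labeled block of `other` into the block with the same labels here.
  // items.at() throws std::out_of_range if `other` has a label this tensor lacks,
  // which enforces the subset requirement.
  void fill(const Labeled3D & other)
  {
    for (const auto & [l0, s0] : other.a0.items)
      for (const auto & [l1, s1] : other.a1.items)
        for (const auto & [l2, s2] : other.a2.items)
        {
          const auto & d0 = a0.items.at(l0);
          const auto & d1 = a1.items.at(l1);
          const auto & d2 = a2.items.at(l2);
          assert(d0.second == s0.second && d1.second == s1.second && d2.second == s2.second);
          for (std::size_t i = 0; i < s0.second; ++i)
            for (std::size_t j = 0; j < s1.second; ++j)
              for (std::size_t k = 0; k < s2.second; ++k)
                v[flat(d0.first + i, d1.first + j, d2.first + k)] =
                    other.v[other.flat(s0.first + i, s1.first + j, s2.first + k)];
        }
  }
};
```

In this model, filling a tensor whose axes carry items "stress" and "strain" from one whose axes carry only "strain" writes the strain-strain-strain block and leaves every other entry untouched.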