NEML2 1.4.0
|
NEML2's enhanced tensor type. More...
NEML2's enhanced tensor type.
BatchTensorBase derives from torch::Tensor and clearly distinguishes "batched" dimensions from the other dimensions. The shape of the "batched" dimensions is called the batch size, and the shape of the remaining dimensions is called the base size.
#include <BatchTensorBase.h>
Public Member Functions | |
BatchTensorBase ()=default | |
Default constructor. | |
BatchTensorBase (const torch::Tensor &tensor, TorchSize batch_dim) | |
Construct from another torch::Tensor. | |
BatchTensorBase (const Derived &tensor) | |
Copy constructor. | |
BatchTensorBase (Real)=delete | |
bool | batched () const |
Whether the tensor is batched. | |
TorchSize | batch_dim () const |
Return the number of batch dimensions. | |
TorchSize & | batch_dim () |
Return a writable reference to the batch dimension. | |
TorchSize | base_dim () const |
Return the number of base dimensions. | |
TorchShapeRef | batch_sizes () const |
Return the batch size. | |
TorchSize | batch_size (TorchSize index) const |
Return the length of some batch axis. | |
TorchShapeRef | base_sizes () const |
Return the base size. | |
TorchSize | base_size (TorchSize index) const |
Return the length of some base axis. | |
TorchSize | base_storage () const |
Return the flattened storage needed just for the base indices. | |
Derived | batch_index (TorchSlice indices) const |
Get a batch. | |
BatchTensor | base_index (const TorchSlice &indices) const |
Return an index sliced on the base dimensions. | |
void | batch_index_put (TorchSlice indices, const torch::Tensor &other) |
Set an index sliced on the batch dimensions to a value. | |
void | base_index_put (const TorchSlice &indices, const torch::Tensor &other) |
Set an index sliced on the base dimensions to a value. | |
Derived | batch_expand (TorchShapeRef batch_size) const |
Return a new view of the tensor with values broadcast along the batch dimensions. | |
BatchTensor | base_expand (TorchShapeRef base_size) const |
Return a new view of the tensor with values broadcast along the base dimensions. | |
template<class Derived2 > | |
Derived | batch_expand_as (const Derived2 &other) const |
Expand the batch to have the same shape as another tensor. | |
template<class Derived2 > | |
Derived2 | base_expand_as (const Derived2 &other) const |
Expand the base to have the same shape as another tensor. | |
Derived | batch_expand_copy (TorchShapeRef batch_size) const |
Return a new tensor with values broadcast along the batch dimensions. | |
BatchTensor | base_expand_copy (TorchShapeRef base_size) const |
Return a new tensor with values broadcast along the base dimensions. | |
Derived | batch_reshape (TorchShapeRef batch_shape) const |
Reshape batch dimensions. | |
BatchTensor | base_reshape (TorchShapeRef base_shape) const |
Reshape base dimensions. | |
Derived | batch_unsqueeze (TorchSize d) const |
Unsqueeze a batch dimension. | |
Derived | list_unsqueeze () const |
Unsqueeze on the special list batch dimension. | |
BatchTensor | base_unsqueeze (TorchSize d) const |
Unsqueeze a base dimension. | |
Derived | batch_transpose (TorchSize d1, TorchSize d2) const |
Transpose two batch dimensions. | |
BatchTensor | base_transpose (TorchSize d1, TorchSize d2) const |
Transpose two base dimensions. | |
BatchTensor | base_movedim (TorchSize d1, TorchSize d2) const |
Move two base dimensions. | |
Derived | clone (torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const |
Clone (take ownership) | |
Derived | detach () const |
Discard function graph. | |
Derived | to (const torch::TensorOptions &options) const |
Send to options. | |
Derived | operator- () const |
Negation. | |
Derived | batch_sum (TorchSize d) const |
Sum on a batch index. | |
Derived | list_sum () const |
Sum on the list index (TODO: replace with class) | |
Static Public Member Functions | |
static Derived | empty_like (const Derived &other) |
Empty tensor like another, i.e. same batch and base shapes, same tensor options, etc. | |
static Derived | zeros_like (const Derived &other) |
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc. | |
static Derived | ones_like (const Derived &other) |
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc. | |
static Derived | full_like (const Derived &other, Real init) |
static Derived | linspace (const Derived &start, const Derived &end, TorchSize nstep, TorchSize dim=0, TorchSize batch_dim=-1) |
Create a new tensor by adding a new batch dimension with linear spacing between start and end. | |
static Derived | logspace (const Derived &start, const Derived &end, TorchSize nstep, TorchSize dim=0, TorchSize batch_dim=-1, Real base=10) |
Log-space equivalent of the linspace named constructor. | |
BatchTensorBase | ( | ) | = default |
Default constructor.
BatchTensorBase | ( | const torch::Tensor & | tensor, |
TorchSize | batch_dim ) |
Construct from another torch::Tensor.
BatchTensorBase | ( | const Derived & | tensor | ) |
Copy constructor.
BatchTensorBase | ( | Real | ) | = delete |
BatchTensor base_expand | ( | TorchShapeRef | base_size | ) | const |
Return a new view of the tensor with values broadcast along the base dimensions.
Expand the base to have the same shape as another tensor.
BatchTensor base_expand_copy | ( | TorchShapeRef | base_size | ) | const |
Return a new tensor with values broadcast along the base dimensions.
BatchTensor base_index | ( | const TorchSlice & | indices | ) | const |
Return an index sliced on the base dimensions.
void base_index_put | ( | const TorchSlice & | indices, |
const torch::Tensor & | other ) |
Set an index sliced on the base dimensions to a value.
BatchTensor base_movedim | ( | TorchSize | d1, |
TorchSize | d2 ) const |
Move two base dimensions.
BatchTensor base_reshape | ( | TorchShapeRef | base_shape | ) | const |
Reshape base dimensions.
Return the length of some base axis.
TorchShapeRef base_sizes | ( | ) | const |
Return the base size.
Return the flattened storage needed just for the base indices.
BatchTensor base_transpose | ( | TorchSize | d1, |
TorchSize | d2 ) const |
Transpose two base dimensions.
BatchTensor base_unsqueeze | ( | TorchSize | d | ) | const |
Unsqueeze a base dimension.
Return a writable reference to the batch dimension.
Derived batch_expand | ( | TorchShapeRef | batch_size | ) | const |
Return a new view of the tensor with values broadcast along the batch dimensions.
Expand the batch to have the same shape as another tensor.
Derived batch_expand_copy | ( | TorchShapeRef | batch_size | ) | const |
Return a new tensor with values broadcast along the batch dimensions.
Derived batch_index | ( | TorchSlice | indices | ) | const |
Get a batch.
void batch_index_put | ( | TorchSlice | indices, |
const torch::Tensor & | other ) |
Set an index sliced on the batch dimensions to a value.
Derived batch_reshape | ( | TorchShapeRef | batch_shape | ) | const |
Reshape batch dimensions.
Return the length of some batch axis.
TorchShapeRef batch_sizes | ( | ) | const |
Return the batch size.
Transpose two batch dimensions.
Derived clone | ( | torch::MemoryFormat | memory_format = torch::MemoryFormat::Contiguous | ) | const |
Clone (take ownership)
Empty tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Full tensor like another, i.e. same batch and base shapes, same tensor options, etc., but filled with a different value
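static Derived linspace | ( | const Derived & | start, |
const Derived & | end, |
TorchSize | nstep, |
TorchSize | dim = 0, |
TorchSize | batch_dim = -1 ) |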
Create a new tensor by adding a new batch dimension with linear spacing between start and end.
start and end must be broadcastable. The new batch dimension will be added at the user-specified dimension dim, which defaults to 0.
For example, if start has shape (3, 2; 5, 5) and end has shape (3, 1; 5, 5), then linspace(start, end, 100, 1) will have shape (3, 100, 2; 5, 5); note the location of the new dimension and the broadcasting.
start | The starting tensor |
end | The ending tensor |
nstep | The number of steps with even spacing along the new dimension |
dim | Where to insert the new dimension |
batch_dim | Batch dimension of the output |
Sum on the list index (TODO: replace with class)
Unsqueeze on the special list batch dimension.
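static Derived logspace | ( | const Derived & | start, |
const Derived & | end, |
TorchSize | nstep, |
TorchSize | dim = 0, |
TorchSize | batch_dim = -1, |
Real | base = 10 ) |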
Log-space equivalent of the linspace named constructor.
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.