27#include "neml2/tensors/Tensor.h"
36template <
class Derived,
Size... S>
101template <
class Derived,
Size... S>
106 "Base shape mismatch: trying to create a tensor with base shape ",
108 " from a tensor with base shape ",
117 "Base shape mismatch: trying to create a tensor with base shape ",
119 " from a tensor with shape ",
126 return Tensor(*
this, this->batch_dim());
133 return Derived(torch::empty(const_base_sizes, options), 0);
139 const torch::TensorOptions & options)
149 return Derived(torch::zeros(const_base_sizes, options), 0);
155 const torch::TensorOptions & options)
165 return Derived(torch::ones(const_base_sizes, options), 0);
171 const torch::TensorOptions & options)
181 return Derived(torch::full(const_base_sizes,
init, options), 0);
188 const torch::TensorOptions & options)
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:56
PrimitiveTensor inherits from TensorBase and additionally templates on the base shape.
Definition PrimitiveTensor.h:38
static Derived full(Real init, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor filled with a given value given base shape.
Definition PrimitiveTensor.h:179
static Derived empty(TensorShapeRef batch_shape, const torch::TensorOptions &options=default_tensor_options())
Empty tensor given batch shape.
Definition PrimitiveTensor.h:138
PrimitiveTensor(const torch::Tensor &tensor)
Construct from another torch::Tensor and infer batch dimension.
Definition PrimitiveTensor.h:113
PrimitiveTensor(const torch::Tensor &tensor, Size batch_dim)
Construct from another torch::Tensor given batch dimension.
Definition PrimitiveTensor.h:102
static Derived empty(const torch::TensorOptions &options=default_tensor_options())
Unbatched empty tensor.
Definition PrimitiveTensor.h:131
static Derived full(TensorShapeRef batch_shape, Real init, const torch::TensorOptions &options=default_tensor_options())
Full tensor given batch shape.
Definition PrimitiveTensor.h:186
static const TensorShape const_base_sizes
The base shape.
Definition PrimitiveTensor.h:41
static Derived ones(TensorShapeRef batch_shape, const torch::TensorOptions &options=default_tensor_options())
Unit tensor given batch shape.
Definition PrimitiveTensor.h:170
static Derived zeros(const torch::TensorOptions &options=default_tensor_options())
Unbatched zero tensor.
Definition PrimitiveTensor.h:147
static constexpr Size const_base_dim
The base dim.
Definition PrimitiveTensor.h:44
static Derived ones(const torch::TensorOptions &options=default_tensor_options())
Unbatched unit tensor.
Definition PrimitiveTensor.h:163
static const Size const_base_storage
The base storage.
Definition PrimitiveTensor.h:47
static Derived zeros(TensorShapeRef batch_shape, const torch::TensorOptions &options=default_tensor_options())
Zero tensor given batch shape.
Definition PrimitiveTensor.h:154
static Tensor identity_map(const torch::TensorOptions &)
Derived tensor classes should define identity_map where appropriate.
Definition PrimitiveTensor.h:91
PrimitiveTensor()=default
Default constructor.
NEML2's enhanced tensor type.
Definition TensorBase.h:46
Size batch_dim() const
Return the number of batch dimensions.
Definition TensorBase.cxx:142
TensorShapeRef base_sizes() const
Return the base size.
Definition TensorBase.cxx:170
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:40
TensorShape add_shapes(S &&... shape)
Definition utils.h:298
Definition CrossRef.cxx:30
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:76
torch::TensorOptions & default_tensor_options()
Definition types.cxx:30
double Real
Definition types.h:31
torch::SmallVector< Size > TensorShape
Definition types.h:34
int64_t Size
Definition types.h:33
torch::IntArrayRef TensorShapeRef
Definition types.h:35