27#include "neml2/tensors/BatchTensor.h"
104 "Base shape mismatch: trying to create a tensor with base shape ",
106 " from a tensor with base shape ",
115 "Base shape mismatch: trying to create a tensor with base shape ",
117 " from a tensor with shape ",
131 return Derived(torch::empty(const_base_sizes, options), 0);
137 const torch::TensorOptions & options)
147 return Derived(torch::zeros(const_base_sizes, options), 0);
153 const torch::TensorOptions & options)
163 return Derived(torch::ones(const_base_sizes, options), 0);
178 return Derived(torch::full(const_base_sizes,
init, options), 0);
185 const torch::TensorOptions & options)
NEML2's enhanced tensor type.
Definition BatchTensorBase.h:46
TorchSize batch_dim() const
Return the number of batch dimensions.
Definition BatchTensorBase.cxx:128
TorchShapeRef base_sizes() const
Return the base size.
Definition BatchTensorBase.cxx:163
Definition BatchTensor.h:32
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
FixedDimTensor inherits from BatchTensorBase and additionally templates on the base shape.
Definition FixedDimTensor.h:38
static Derived full(Real init, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor filled with a given value given base shape.
Definition FixedDimTensor.h:176
static Derived full(TorchShapeRef batch_shape, Real init, const torch::TensorOptions &options=default_tensor_options())
Full tensor given batch shape.
Definition FixedDimTensor.h:183
static BatchTensor identity_map(const torch::TensorOptions &)
Derived tensor classes should define identity_map where appropriate.
Definition FixedDimTensor.h:89
static constexpr TorchSize const_base_dim
The base dim.
Definition FixedDimTensor.h:44
static Derived ones(TorchShapeRef batch_shape, const torch::TensorOptions &options=default_tensor_options())
Unit tensor given batch shape.
Definition FixedDimTensor.h:168
FixedDimTensor()=default
Default constructor.
static const TorchSize const_base_storage
The base storage.
Definition FixedDimTensor.h:47
static Derived empty(const torch::TensorOptions &options=default_tensor_options())
Unbatched empty tensor.
Definition FixedDimTensor.h:129
FixedDimTensor(const torch::Tensor &tensor)
Construct from another torch::Tensor and infer batch dimension.
Definition FixedDimTensor.h:111
FixedDimTensor(const torch::Tensor &tensor, TorchSize batch_dim)
Construct from another torch::Tensor given batch dimension.
Definition FixedDimTensor.h:100
static Derived zeros(const torch::TensorOptions &options=default_tensor_options())
Unbatched zero tensor.
Definition FixedDimTensor.h:145
static Derived ones(const torch::TensorOptions &options=default_tensor_options())
Unbatched unit tensor.
Definition FixedDimTensor.h:161
static Derived zeros(TorchShapeRef batch_shape, const torch::TensorOptions &options=default_tensor_options())
Zero tensor given batch shape.
Definition FixedDimTensor.h:152
static const TorchShape const_base_sizes
The base shape.
Definition FixedDimTensor.h:41
static Derived empty(TorchShapeRef batch_shape, const torch::TensorOptions &options=default_tensor_options())
Empty tensor given batch shape.
Definition FixedDimTensor.h:136
TorchSize storage_size(TorchShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:32
TorchShape add_shapes(S &&... shape)
Definition utils.h:294
Definition CrossRef.cxx:32
const torch::TensorOptions default_tensor_options()
Definition types.cxx:30
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:85
int64_t TorchSize
Definition types.h:35
std::vector< TorchSize > TorchShape
Definition types.h:36
double Real
Definition types.h:33
torch::IntArrayRef TorchShapeRef
Definition types.h:37