27#include "neml2/misc/utils.h"
32template <
class Derived>
44template <
class Derived>
108 using torch::Tensor::detach_;
110 Derived to(
const torch::TensorOptions & options)
const;
112 using torch::Tensor::copy_;
114 using torch::Tensor::zero_;
116 using torch::Tensor::requires_grad;
118 using torch::Tensor::requires_grad_;
127 using torch::Tensor::options;
129 using torch::Tensor::scalar_type;
131 using torch::Tensor::device;
133 using torch::Tensor::dim;
135 using torch::Tensor::sizes;
137 using torch::Tensor::size;
160 using torch::Tensor::index;
161 using torch::Tensor::index_put_;
179 template <
class Derived2>
182 template <
class Derived2>
207template <
class Derived>
208template <
class Derived2>
212 return batch_expand(
other.batch_sizes());
215template <
class Derived>
216template <
class Derived2>
220 return base_expand(
other.base_sizes());
224 typename =
typename std::enable_if<std::is_base_of_v<TensorBase<Derived>,
Derived>>>
228 return Derived(torch::operator+(a, b), a.batch_dim());
231template <
class Derived,
232 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
239template <
class Derived,
240 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
248template <
class Derived,
249 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
253 return Derived(torch::operator-(a, b), a.batch_dim());
256template <
class Derived,
257 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
264template <
class Derived,
265 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
273template <
class Derived,
274 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
278 return Derived(torch::operator*(a, b), a.batch_dim());
281template <
class Derived,
282 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
289template <
class Derived,
290 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
294 return Derived(torch::operator/(a, b), a.batch_dim());
297template <
class Derived,
298 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
302 return Derived(torch::operator/(a, b), b.batch_dim());
305template <
class Derived,
306 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:56
NEML2's enhanced tensor type.
Definition TensorBase.h:46
Derived clone(torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
Definition TensorBase.cxx:114
Derived batch_expand_copy(TensorShapeRef batch_size) const
Return a new tensor with values broadcast along the batch dimensions.
Definition TensorBase.cxx:250
Derived batch_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the batch dimensions.
Definition TensorBase.cxx:191
neml2::Tensor base_transpose(Size d1, Size d2) const
Transpose two base dimensions.
Definition TensorBase.cxx:303
static Derived logspace(const Derived &start, const Derived &end, Size nstep, Size dim=0, Size batch_dim=-1, Real base=10)
log-space equivalent of the linspace named constructor
Definition TensorBase.cxx:105
Derived batch_expand_as(const Derived2 &other) const
Expand the batch to have the same shape as another tensor.
Definition TensorBase.h:210
Size base_storage() const
Return the flattened storage needed just for the base indices.
Definition TensorBase.cxx:184
TensorBase()=default
Default constructor.
bool batched() const
Whether the tensor is batched.
Definition TensorBase.cxx:135
Size batch_size(Size index) const
Return the size of a batch axis.
Definition TensorBase.cxx:163
Size batch_dim() const
Return the number of batch dimensions.
Definition TensorBase.cxx:142
Derived detach() const
Discard function graph.
Definition TensorBase.cxx:121
Derived batch_expand(TensorShapeRef batch_size) const
Definition TensorBase.cxx:230
neml2::Tensor base_expand(TensorShapeRef base_size) const
Return a new view of the tensor with values broadcast along the base dimensions.
Definition TensorBase.cxx:240
Derived operator-() const
Negation.
Definition TensorBase.cxx:312
TensorShapeRef batch_sizes() const
Return the batch size.
Definition TensorBase.cxx:156
Derived batch_transpose(Size d1, Size d2) const
Transpose two batch dimensions.
Definition TensorBase.cxx:294
neml2::Tensor base_unsqueeze(Size d) const
Unsqueeze a base dimension.
Definition TensorBase.cxx:286
Derived batch_reshape(TensorShapeRef batch_shape) const
Reshape batch dimensions.
Definition TensorBase.cxx:264
neml2::Tensor base_expand_copy(TensorShapeRef base_size) const
Return a new tensor with values broadcast along the base dimensions.
Definition TensorBase.cxx:257
Derived to(const torch::TensorOptions &options) const
Change tensor options.
Definition TensorBase.cxx:128
Size base_size(Size index) const
Return the size of a base axis.
Definition TensorBase.cxx:177
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBase.cxx:58
static Derived empty_like(const Derived &other)
Empty tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBase.cxx:51
Derived batch_unsqueeze(Size d) const
Unsqueeze a batch dimension.
Definition TensorBase.cxx:278
static Derived full_like(const Derived &other, Real init)
Definition TensorBase.cxx:72
void batch_index_put_(indexing::TensorIndicesRef indices, const torch::Tensor &other)
Set values by slicing on the batch dimensions.
Definition TensorBase.cxx:210
Size base_dim() const
Return the number of base dimensions.
Definition TensorBase.cxx:149
TensorShapeRef base_sizes() const
Return the base size.
Definition TensorBase.cxx:170
static Derived linspace(const Derived &start, const Derived &end, Size nstep, Size dim=0, Size batch_dim=-1)
Create a new tensor by adding a new batch dimension with linear spacing between start and end.
Definition TensorBase.cxx:79
neml2::Tensor base_reshape(TensorShapeRef base_shape) const
Reshape base dimensions.
Definition TensorBase.cxx:271
void base_index_put_(indexing::TensorIndicesRef indices, const torch::Tensor &other)
Set values by slicing on the base dimensions.
Definition TensorBase.cxx:220
static Derived ones_like(const Derived &other)
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBase.cxx:65
neml2::Tensor base_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the base dimensions.
Definition TensorBase.cxx:201
Derived2 base_expand_as(const Derived2 &other) const
Expand the base to have the same shape as another tensor.
Definition TensorBase.h:218
torch::ArrayRef< TensorIndex > TensorIndicesRef
Definition types.h:42
Definition CrossRef.cxx:30
Vec operator*(const Derived1 &A, const Derived2 &b)
matrix-vector product
Definition R2Base.cxx:233
Derived operator-(const Derived &a, const Scalar &b)
Definition Scalar.h:79
Derived operator+(const Derived &a, const Scalar &b)
Definition Scalar.h:58
Size broadcast_batch_dim(T &&...)
The batch dimension after broadcasting.
double Real
Definition types.h:31
int64_t Size
Definition types.h:33
Derived operator/(const Derived &a, const Scalar &b)
Definition Scalar.h:123
void neml_assert_broadcastable_dbg(T &&...)
A helper function to assert (in Debug mode) that all tensors are broadcastable.
torch::IntArrayRef TensorShapeRef
Definition types.h:35