27#include "neml2/misc/utils.h"
32template <
class Derived>
44template <
class Derived>
149 template <
class Derived2>
153 template <
class Derived2>
193 Derived to(
const torch::TensorOptions & options)
const;
209template <
class Derived>
210template <
class Derived2>
214 return batch_expand(
other.batch_sizes());
217template <
class Derived>
218template <
class Derived2>
222 return base_expand(
other.base_sizes());
226 typename =
typename std::enable_if<std::is_base_of_v<BatchTensorBase<Derived>,
Derived>>>
230 return Derived(torch::operator+(a, b), a.batch_dim());
235 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
244 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
254 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
258 return Derived(torch::operator-(a, b), a.batch_dim());
263 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
272 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
282 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
286 return Derived(torch::operator*(a, b), a.batch_dim());
291 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
300 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
304 return Derived(torch::operator/(a, b), a.batch_dim());
309 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
313 return Derived(torch::operator/(a, b), b.batch_dim());
318 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
330 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
334 return Derived(torch::pow(a,
n), a.batch_dim());
339 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>,
Derived>>>
343 return Derived(torch::pow(a,
n),
n.batch_dim());
348 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>,
Derived>>>
358 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>,
Derived>>>
362 return Derived(torch::sign(a), a.batch_dim());
367 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>,
Derived>>>
371 return Derived(torch::cosh(a), a.batch_dim());
376 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>,
Derived>>>
380 return Derived(torch::sinh(a), a.batch_dim());
385 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>,
Derived>>>
389 return Derived(torch::tanh(a), a.batch_dim());
394 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>,
Derived>>>
410 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>,
Derived>>>
414 return (
sign(a) + 1.0) / 2.0;
419 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>,
Derived>>>
423 return Derived(torch::Tensor(a) * torch::Tensor(
heaviside(a)), a.batch_dim());
428 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>,
Derived>>>
437 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>,
Derived>>>
441 return Derived(torch::sqrt(a), a.batch_dim());
446 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>,
Derived>>>
450 return Derived(torch::exp(a), a.batch_dim());
455 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>,
Derived>>>
459 return Derived(torch::abs(a), a.batch_dim());
464 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>,
Derived>>>
468 return Derived(torch::diff(a,
n,
dim), a.batch_dim());
473 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>,
Derived>>>
477 return Derived(torch::diag_embed(
484 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>,
Derived>>>
488 return Derived(torch::log(a), a.batch_dim());
NEML2's enhanced tensor type.
Definition BatchTensorBase.h:46
Derived clone(torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
Clone (take ownership)
Definition BatchTensorBase.cxx:317
TorchSize base_storage() const
Return the flattened storage needed just for the base indices.
Definition BatchTensorBase.cxx:177
TorchShapeRef batch_sizes() const
Return the batch size.
Definition BatchTensorBase.cxx:149
Derived batch_expand_as(const Derived2 &other) const
Expand the batch to have the same shape as another tensor.
Definition BatchTensorBase.h:212
Derived list_sum() const
Sum on the list index (TODO: replace with class)
Definition BatchTensorBase.cxx:354
Derived batch_transpose(TorchSize d1, TorchSize d2) const
Transpose two batch dimensions.
Definition BatchTensorBase.cxx:290
bool batched() const
Whether the tensor is batched.
Definition BatchTensorBase.cxx:121
BatchTensor base_transpose(TorchSize d1, TorchSize d2) const
Transpose two base dimensions.
Definition BatchTensorBase.cxx:299
static Derived linspace(const Derived &start, const Derived &end, TorchSize nstep, TorchSize dim=0, TorchSize batch_dim=-1)
Create a new tensor by adding a new batch dimension with linear spacing between start and end.
Definition BatchTensorBase.cxx:80
Derived detach() const
Discard function graph.
Definition BatchTensorBase.cxx:324
Derived operator-() const
Negation.
Definition BatchTensorBase.cxx:338
TorchSize batch_dim() const
Return the number of batch dimensions.
Definition BatchTensorBase.cxx:128
static Derived logspace(const Derived &start, const Derived &end, TorchSize nstep, TorchSize dim=0, TorchSize batch_dim=-1, Real base=10)
log-space equivalent of the linspace named constructor
Definition BatchTensorBase.cxx:108
Derived batch_index(TorchSlice indices) const
Get a batch.
Definition BatchTensorBase.cxx:184
BatchTensor base_expand(TorchShapeRef base_size) const
Return a new view of the tensor with values broadcast along the base dimensions.
Definition BatchTensorBase.cxx:229
TorchShapeRef base_sizes() const
Return the base size.
Definition BatchTensorBase.cxx:163
BatchTensor base_index(const TorchSlice &indices) const
Return an index sliced on the base dimensions.
Definition BatchTensorBase.cxx:193
Derived to(const torch::TensorOptions &options) const
Send to options.
Definition BatchTensorBase.cxx:331
TorchSize batch_size(TorchSize index) const
Return the length of some batch axis.
Definition BatchTensorBase.cxx:156
Derived batch_expand(TorchShapeRef batch_size) const
Return a new view of the tensor with values broadcast along the batch dimensions.
Definition BatchTensorBase.cxx:219
Derived batch_unsqueeze(TorchSize d) const
Unsqueeze a batch dimension.
Definition BatchTensorBase.cxx:267
void base_index_put(const TorchSlice &indices, const torch::Tensor &other)
Set a index sliced on the base dimensions to a value.
Definition BatchTensorBase.cxx:210
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition BatchTensorBase.cxx:59
BatchTensor base_reshape(TorchShapeRef base_shape) const
Reshape base dimensions.
Definition BatchTensorBase.cxx:260
static Derived empty_like(const Derived &other)
Empty tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition BatchTensorBase.cxx:52
TorchSize base_dim() const
Return the number of base dimensions.
Definition BatchTensorBase.cxx:142
BatchTensor base_movedim(TorchSize d1, TorchSize d2) const
Move two base dimensions.
Definition BatchTensorBase.cxx:308
BatchTensor base_unsqueeze(TorchSize d) const
Unsqueeze a base dimension.
Definition BatchTensorBase.cxx:282
static Derived full_like(const Derived &other, Real init)
Definition BatchTensorBase.cxx:73
BatchTensorBase(Real)=delete
BatchTensor base_expand_copy(TorchShapeRef base_size) const
Return a new tensor with values broadcast along the base dimensions.
Definition BatchTensorBase.cxx:246
Derived list_unsqueeze() const
Unsqueeze on the special list batch dimension.
Definition BatchTensorBase.cxx:275
TorchSize base_size(TorchSize index) const
Return the length of some base axis.
Definition BatchTensorBase.cxx:170
void batch_index_put(TorchSlice indices, const torch::Tensor &other)
Set a index sliced on the batch dimensions to a value.
Definition BatchTensorBase.cxx:202
Derived batch_expand_copy(TorchShapeRef batch_size) const
Return a new tensor with values broadcast along the batch dimensions.
Definition BatchTensorBase.cxx:239
BatchTensorBase()=default
Default constructor.
static Derived ones_like(const Derived &other)
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition BatchTensorBase.cxx:66
Derived batch_reshape(TorchShapeRef batch_shape) const
Reshape batch dimensions.
Definition BatchTensorBase.cxx:253
Derived2 base_expand_as(const Derived2 &other) const
Expand the base to have the same shape as another tensor.
Definition BatchTensorBase.h:220
Derived batch_sum(TorchSize d) const
Sum on a batch index.
Definition BatchTensorBase.cxx:345
Definition BatchTensor.h:32
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
Derived macaulay(const Derived &a)
Definition BatchTensorBase.h:421
Derived sinh(const Derived &a)
Definition BatchTensorBase.h:378
Derived heaviside(const Derived &a)
Definition BatchTensorBase.h:412
Derived tanh(const Derived &a)
Definition BatchTensorBase.h:387
Derived cosh(const Derived &a)
Definition BatchTensorBase.h:369
Derived diff(const Derived &a, TorchSize n=1, TorchSize dim=-1)
Definition BatchTensorBase.h:466
Derived where(const torch::Tensor &condition, const Derived &a, const Derived &b)
Definition BatchTensorBase.h:396
Derived log(const Derived &a)
Definition BatchTensorBase.h:486
Derived abs(const Derived &a)
Definition BatchTensorBase.h:457
Derived dmacaulay(const Derived &a)
Definition BatchTensorBase.h:430
Derived exp(const Derived &a)
Definition BatchTensorBase.h:448
Derived batch_diag_embed(const Derived &a, TorchSize offset=0, TorchSize d1=-2, TorchSize d2=-1)
Definition BatchTensorBase.h:475
Derived sign(const Derived &a)
Definition BatchTensorBase.h:360
Derived sqrt(const Derived &a)
Definition BatchTensorBase.h:439
Derived pow(const Derived &a, const Real &n)
Definition BatchTensorBase.h:332
Definition CrossRef.cxx:32
Derived operator-(const Derived &a, const Real &b)
Definition BatchTensorBase.h:256
BatchTensor operator*(const BatchTensor &a, const BatchTensor &b)
Definition BatchTensor.cxx:153
int64_t TorchSize
Definition types.h:35
TorchSize broadcast_batch_dim(T &&...)
The batch dimension after broadcasting.
double Real
Definition types.h:33
Derived operator+(const Derived &a, const Real &b)
Definition BatchTensorBase.h:228
torch::IntArrayRef TorchShapeRef
Definition types.h:37
void neml_assert_broadcastable_dbg(T &&...)
A helper function to assert (in Debug mode) that all tensors are broadcastable.
std::vector< at::indexing::TensorIndex > TorchSlice
Definition types.h:39
Derived operator/(const Derived &a, const Real &b)
Definition BatchTensorBase.h:302