25#include "neml2/tensors/BatchTensorBase.h"
26#include "neml2/tensors/tensors.h"
27#include "neml2/tensors/macros.h"
31template <
class Derived>
33 : torch::Tensor(tensor),
39 " is smaller than the requested number of batch dimensions ",
43template <
class Derived>
45 : torch::Tensor(tensor),
46 _batch_dim(tensor.batch_dim())
50template <
class Derived>
57template <
class Derived>
64template <
class Derived>
71template <
class Derived>
78template <
class Derived>
86 using namespace torch::indexing;
93 auto diff = (end -
start).batch_unsqueeze(
dim);
103 return Derived(
res, batch_dim >= 0 ? batch_dim :
res.batch_dim());
106template <
class Derived>
119template <
class Derived>
126template <
class Derived>
133template <
class Derived>
140template <
class Derived>
144 return dim() - batch_dim();
147template <
class Derived>
151 return sizes().slice(0, _batch_dim);
154template <
class Derived>
158 return batch_sizes()[index >= 0 ? index : index + batch_dim()];
161template <
class Derived>
165 return sizes().slice(_batch_dim);
168template <
class Derived>
172 return base_sizes()[index >= 0 ? index : index + base_dim()];
175template <
class Derived>
182template <
class Derived>
186 indices.insert(indices.end(), base_dim(), torch::indexing::Slice());
187 auto res = this->index(indices);
191template <
class Derived>
200template <
class Derived>
204 indices.insert(indices.end(), base_dim(), torch::indexing::Slice());
208template <
class Derived>
217template <
class Derived>
222 auto net = batch_size.vec();
223 net.insert(
net.end(), base_dim(), -1);
227template <
class Derived>
232 auto net = base_size.vec();
233 net.insert(
net.begin(), batch_dim(), -1);
237template <
class Derived>
244template <
class Derived>
251template <
class Derived>
258template <
class Derived>
265template <
class Derived>
269 auto d2 = d >= 0 ? d : d - base_dim();
273template <
class Derived>
277 return batch_unsqueeze(-1);
280template <
class Derived>
284 auto d2 = d < 0 ? d : d + batch_dim();
285 return BatchTensor(torch::Tensor::unsqueeze(
d2), batch_dim());
288template <
class Derived>
293 torch::Tensor::transpose(
d1 < 0 ?
d1 - base_dim() :
d1,
d2 < 0 ?
d2 - base_dim() :
d2),
297template <
class Derived>
302 torch::Tensor::transpose(
d1 < 0 ?
d1 : _batch_dim +
d1,
d2 < 0 ?
d2 : _batch_dim +
d2),
306template <
class Derived>
311 torch::Tensor::movedim(
d1 < 0 ?
d1 : _batch_dim +
d1,
d2 < 0 ?
d2 : _batch_dim +
d2),
315template <
class Derived>
322template <
class Derived>
326 return Derived(torch::Tensor::detach(), _batch_dim);
329template <
class Derived>
333 return Derived(torch::Tensor::to(options), _batch_dim);
336template <
class Derived>
340 return Derived(-torch::Tensor(*
this), _batch_dim);
343template <
class Derived>
347 neml_assert_dbg(_batch_dim > 0,
"Must have a batch dimension to sum along");
348 auto d2 = d >= 0 ? d : d - base_dim();
349 return Derived(torch::sum(*
this,
d2), _batch_dim - 1);
352template <
class Derived>
356 return batch_sum(-1);
359#define BATCHTENSORBASE_INSTANTIATE(T) template class BatchTensorBase<T>
360FOR_ALL_BATCHTENSORBASE(BATCHTENSORBASE_INSTANTIATE);
Derived clone(torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
Clone (take ownership)
Definition BatchTensorBase.cxx:317
TorchSize base_storage() const
Return the flattened storage needed just for the base indices.
Definition BatchTensorBase.cxx:177
TorchShapeRef batch_sizes() const
Return the batch size.
Definition BatchTensorBase.cxx:149
Derived list_sum() const
Sum on the list index (TODO: replace with class)
Definition BatchTensorBase.cxx:354
Derived batch_transpose(TorchSize d1, TorchSize d2) const
Transpose two batch dimensions.
Definition BatchTensorBase.cxx:290
bool batched() const
Whether the tensor is batched.
Definition BatchTensorBase.cxx:121
BatchTensor base_transpose(TorchSize d1, TorchSize d2) const
Transpose two base dimensions.
Definition BatchTensorBase.cxx:299
static Derived linspace(const Derived &start, const Derived &end, TorchSize nstep, TorchSize dim=0, TorchSize batch_dim=-1)
Create a new tensor by adding a new batch dimension with linear spacing between start and end.
Definition BatchTensorBase.cxx:80
Derived detach() const
Discard function graph.
Definition BatchTensorBase.cxx:324
Derived operator-() const
Negation.
Definition BatchTensorBase.cxx:338
TorchSize batch_dim() const
Return the number of batch dimensions.
Definition BatchTensorBase.cxx:128
static Derived logspace(const Derived &start, const Derived &end, TorchSize nstep, TorchSize dim=0, TorchSize batch_dim=-1, Real base=10)
Log-space equivalent of the linspace named constructor.
Definition BatchTensorBase.cxx:108
Derived batch_index(TorchSlice indices) const
Get a batch.
Definition BatchTensorBase.cxx:184
BatchTensor base_expand(TorchShapeRef base_size) const
Return a new view of the tensor with values broadcast along the base dimensions.
Definition BatchTensorBase.cxx:229
TorchShapeRef base_sizes() const
Return the base size.
Definition BatchTensorBase.cxx:163
BatchTensor base_index(const TorchSlice &indices) const
Return an index sliced on the base dimensions.
Definition BatchTensorBase.cxx:193
Derived to(const torch::TensorOptions &options) const
Send to options.
Definition BatchTensorBase.cxx:331
TorchSize batch_size(TorchSize index) const
Return the length of some batch axis.
Definition BatchTensorBase.cxx:156
Derived batch_expand(TorchShapeRef batch_size) const
Return a new view of the tensor with values broadcast along the batch dimensions.
Definition BatchTensorBase.cxx:219
Derived batch_unsqueeze(TorchSize d) const
Unsqueeze a batch dimension.
Definition BatchTensorBase.cxx:267
void base_index_put(const TorchSlice &indices, const torch::Tensor &other)
Set an index sliced on the base dimensions to a value.
Definition BatchTensorBase.cxx:210
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition BatchTensorBase.cxx:59
BatchTensor base_reshape(TorchShapeRef base_shape) const
Reshape base dimensions.
Definition BatchTensorBase.cxx:260
static Derived empty_like(const Derived &other)
Empty tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition BatchTensorBase.cxx:52
TorchSize base_dim() const
Return the number of base dimensions.
Definition BatchTensorBase.cxx:142
BatchTensor base_movedim(TorchSize d1, TorchSize d2) const
Move a base dimension from position d1 to position d2.
Definition BatchTensorBase.cxx:308
BatchTensor base_unsqueeze(TorchSize d) const
Unsqueeze a base dimension.
Definition BatchTensorBase.cxx:282
static Derived full_like(const Derived &other, Real init)
Definition BatchTensorBase.cxx:73
BatchTensor base_expand_copy(TorchShapeRef base_size) const
Return a new tensor with values broadcast along the base dimensions.
Definition BatchTensorBase.cxx:246
Derived list_unsqueeze() const
Unsqueeze on the special list batch dimension.
Definition BatchTensorBase.cxx:275
TorchSize base_size(TorchSize index) const
Return the length of some base axis.
Definition BatchTensorBase.cxx:170
void batch_index_put(TorchSlice indices, const torch::Tensor &other)
Set an index sliced on the batch dimensions to a value.
Definition BatchTensorBase.cxx:202
Derived batch_expand_copy(TorchShapeRef batch_size) const
Return a new tensor with values broadcast along the batch dimensions.
Definition BatchTensorBase.cxx:239
BatchTensorBase()=default
Default constructor.
static Derived ones_like(const Derived &other)
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition BatchTensorBase.cxx:66
Derived batch_reshape(TorchShapeRef batch_shape) const
Reshape batch dimensions.
Definition BatchTensorBase.cxx:253
Derived batch_sum(TorchSize d) const
Sum on a batch index.
Definition BatchTensorBase.cxx:345
Definition BatchTensor.h:32
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
The (logical) scalar.
Definition Scalar.h:38
Derived pow(const Derived &a, const Real &n)
Definition BatchTensorBase.h:332
TorchSize storage_size(TorchShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:32
TorchShape add_shapes(S &&... shape)
Definition utils.h:294
Definition CrossRef.cxx:32
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:85
int64_t TorchSize
Definition types.h:35
TorchSize broadcast_batch_dim(T &&...)
The batch dimension after broadcasting.
double Real
Definition types.h:33
torch::IntArrayRef TorchShapeRef
Definition types.h:37
void neml_assert_broadcastable_dbg(T &&...)
A helper function to assert (in Debug mode) that all tensors are broadcastable.
std::vector< at::indexing::TensorIndex > TorchSlice
Definition types.h:39