NEML2 1.4.0
Loading...
Searching...
No Matches
BatchTensorBase.cxx
1// Copyright 2023, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#include "neml2/tensors/BatchTensorBase.h"
26#include "neml2/tensors/tensors.h"
27#include "neml2/tensors/macros.h"
28
29namespace neml2
30{
31template <class Derived>
32BatchTensorBase<Derived>::BatchTensorBase(const torch::Tensor & tensor, TorchSize batch_dim)
33 : torch::Tensor(tensor),
34 _batch_dim(batch_dim)
35{
36 neml_assert_dbg((TorchSize)sizes().size() >= _batch_dim,
37 "Tensor dimension ",
38 sizes().size(),
39 " is smaller than the requested number of batch dimensions ",
40 _batch_dim);
41}
42
43template <class Derived>
45 : torch::Tensor(tensor),
46 _batch_dim(tensor.batch_dim())
47{
48}
49
50template <class Derived>
53{
54 return Derived(torch::empty_like(other), other.batch_dim());
56
57template <class Derived>
61 return Derived(torch::zeros_like(other), other.batch_dim());
63
64template <class Derived>
68 return Derived(torch::ones_like(other), other.batch_dim());
69}
70
71template <class Derived>
74{
75 return Derived(torch::full_like(other, init), other.batch_dim());
76}
77
78template <class Derived>
81 const Derived & start, const Derived & end, TorchSize nstep, TorchSize dim, TorchSize batch_dim)
82{
84 neml_assert_dbg(nstep > 0, "nstep must be positive.");
85
86 using namespace torch::indexing;
87
88 auto res = start.batch_unsqueeze(dim);
89
90 if (nstep > 1)
91 {
92 auto Bd = broadcast_batch_dim(start, end);
93 auto diff = (end - start).batch_unsqueeze(dim);
94
96 net.push_back(Ellipsis);
97 net.insert(net.end(), Bd - dim, None);
98 Scalar steps = torch::arange(nstep, diff.options()).index(net) / (nstep - 1);
99
100 res = res + steps * diff;
101 }
102
103 return Derived(res, batch_dim >= 0 ? batch_dim : res.batch_dim());
105
106template <class Derived>
109 const Derived & end,
112 TorchSize batch_dim,
114{
115 auto exponent = Derived::linspace(start, end, nstep, dim, batch_dim);
117}
118
119template <class Derived>
120bool
123 return _batch_dim;
124}
126template <class Derived>
129{
130 return _batch_dim;
132
133template <class Derived>
136{
137 return _batch_dim;
138}
139
140template <class Derived>
144 return dim() - batch_dim();
145}
147template <class Derived>
150{
151 return sizes().slice(0, _batch_dim);
152}
153
154template <class Derived>
158 return batch_sizes()[index >= 0 ? index : index + batch_dim()];
159}
161template <class Derived>
164{
165 return sizes().slice(_batch_dim);
167
168template <class Derived>
171{
172 return base_sizes()[index >= 0 ? index : index + base_dim()];
173}
174
175template <class Derived>
179 return utils::storage_size(base_sizes());
180}
182template <class Derived>
185{
186 indices.insert(indices.end(), base_dim(), torch::indexing::Slice());
187 auto res = this->index(indices);
188 return Derived(res, res.dim() - base_dim());
189}
191template <class Derived>
194{
195 TorchSlice indices2(batch_dim(), torch::indexing::Slice());
196 indices2.insert(indices2.end(), indices.begin(), indices.end());
197 return BatchTensor(this->index(indices2), batch_dim());
198}
200template <class Derived>
201void
203{
204 indices.insert(indices.end(), base_dim(), torch::indexing::Slice());
205 this->index_put_(indices, other);
206}
207
208template <class Derived>
209void
210BatchTensorBase<Derived>::base_index_put(const TorchSlice & indices, const torch::Tensor & other)
211{
212 TorchSlice indices2(batch_dim(), torch::indexing::Slice());
213 indices2.insert(indices2.end(), indices.begin(), indices.end());
214 this->index_put_(indices2, other);
215}
216
217template <class Derived>
220{
221 // We don't want to touch the base dimensions, so put -1 for them.
222 auto net = batch_size.vec();
223 net.insert(net.end(), base_dim(), -1);
224 return Derived(expand(net), batch_size.size());
225}
226
227template <class Derived>
230{
231 // We don't want to touch the batch dimensions, so put -1 for them.
232 auto net = base_size.vec();
233 net.insert(net.begin(), batch_dim(), -1);
234 return BatchTensor(expand(net), batch_dim());
235}
236
237template <class Derived>
240{
241 return Derived(batch_expand(batch_size).contiguous(), batch_size.size());
242}
243
244template <class Derived>
247{
248 return BatchTensor(base_expand(base_size).contiguous(), batch_dim());
249}
250
251template <class Derived>
257
258template <class Derived>
264
265template <class Derived>
268{
269 auto d2 = d >= 0 ? d : d - base_dim();
270 return Derived(unsqueeze(d2), _batch_dim + 1);
271}
272
273template <class Derived>
276{
277 return batch_unsqueeze(-1);
278}
279
280template <class Derived>
283{
284 auto d2 = d < 0 ? d : d + batch_dim();
285 return BatchTensor(torch::Tensor::unsqueeze(d2), batch_dim());
286}
287
288template <class Derived>
291{
292 return Derived(
293 torch::Tensor::transpose(d1 < 0 ? d1 - base_dim() : d1, d2 < 0 ? d2 - base_dim() : d2),
294 _batch_dim);
295}
296
297template <class Derived>
300{
301 return BatchTensor(
302 torch::Tensor::transpose(d1 < 0 ? d1 : _batch_dim + d1, d2 < 0 ? d2 : _batch_dim + d2),
303 _batch_dim);
304}
305
306template <class Derived>
309{
310 return BatchTensor(
311 torch::Tensor::movedim(d1 < 0 ? d1 : _batch_dim + d1, d2 < 0 ? d2 : _batch_dim + d2),
312 _batch_dim);
313}
314
315template <class Derived>
318{
319 return Derived(torch::Tensor::clone(memory_format), _batch_dim);
320}
321
322template <class Derived>
325{
326 return Derived(torch::Tensor::detach(), _batch_dim);
327}
328
329template <class Derived>
331BatchTensorBase<Derived>::to(const torch::TensorOptions & options) const
332{
333 return Derived(torch::Tensor::to(options), _batch_dim);
334}
335
336template <class Derived>
339{
340 return Derived(-torch::Tensor(*this), _batch_dim);
341}
342
343template <class Derived>
346{
347 neml_assert_dbg(_batch_dim > 0, "Must have a batch dimension to sum along");
348 auto d2 = d >= 0 ? d : d - base_dim();
349 return Derived(torch::sum(*this, d2), _batch_dim - 1);
350}
351
352template <class Derived>
355{
356 return batch_sum(-1);
357}
358
359#define BATCHTENSORBASE_INSTANTIATE(T) template class BatchTensorBase<T>
360FOR_ALL_BATCHTENSORBASE(BATCHTENSORBASE_INSTANTIATE);
361} // end namespace neml2
Derived clone(torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
Clone (take ownership)
Definition BatchTensorBase.cxx:317
TorchSize base_storage() const
Return the flattened storage needed just for the base indices.
Definition BatchTensorBase.cxx:177
TorchShapeRef batch_sizes() const
Return the batch size.
Definition BatchTensorBase.cxx:149
Derived list_sum() const
Sum on the list index (TODO: replace with class)
Definition BatchTensorBase.cxx:354
Derived batch_transpose(TorchSize d1, TorchSize d2) const
Transpose two batch dimensions.
Definition BatchTensorBase.cxx:290
bool batched() const
Whether the tensor is batched.
Definition BatchTensorBase.cxx:121
BatchTensor base_transpose(TorchSize d1, TorchSize d2) const
Transpose two base dimensions.
Definition BatchTensorBase.cxx:299
static Derived linspace(const Derived &start, const Derived &end, TorchSize nstep, TorchSize dim=0, TorchSize batch_dim=-1)
Create a new tensor by adding a new batch dimension with linear spacing between start and end.
Definition BatchTensorBase.cxx:80
Derived detach() const
Discard function graph.
Definition BatchTensorBase.cxx:324
Derived operator-() const
Negation.
Definition BatchTensorBase.cxx:338
TorchSize batch_dim() const
Return the number of batch dimensions.
Definition BatchTensorBase.cxx:128
static Derived logspace(const Derived &start, const Derived &end, TorchSize nstep, TorchSize dim=0, TorchSize batch_dim=-1, Real base=10)
Log-space equivalent of the linspace named constructor.
Definition BatchTensorBase.cxx:108
Derived batch_index(TorchSlice indices) const
Return an index sliced on the batch dimensions.
Definition BatchTensorBase.cxx:184
BatchTensor base_expand(TorchShapeRef base_size) const
Return a new view of the tensor with values broadcast along the base dimensions.
Definition BatchTensorBase.cxx:229
TorchShapeRef base_sizes() const
Return the base size.
Definition BatchTensorBase.cxx:163
BatchTensor base_index(const TorchSlice &indices) const
Return an index sliced on the base dimensions.
Definition BatchTensorBase.cxx:193
Derived to(const torch::TensorOptions &options) const
Convert to the given tensor options.
Definition BatchTensorBase.cxx:331
TorchSize batch_size(TorchSize index) const
Return the length of some batch axis.
Definition BatchTensorBase.cxx:156
Derived batch_expand(TorchShapeRef batch_size) const
Return a new view of the tensor with values broadcast along the batch dimensions.
Definition BatchTensorBase.cxx:219
Derived batch_unsqueeze(TorchSize d) const
Unsqueeze a batch dimension.
Definition BatchTensorBase.cxx:267
void base_index_put(const TorchSlice &indices, const torch::Tensor &other)
Set an index slice on the base dimensions to a value.
Definition BatchTensorBase.cxx:210
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition BatchTensorBase.cxx:59
BatchTensor base_reshape(TorchShapeRef base_shape) const
Reshape base dimensions.
Definition BatchTensorBase.cxx:260
static Derived empty_like(const Derived &other)
Empty tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition BatchTensorBase.cxx:52
TorchSize base_dim() const
Return the number of base dimensions.
Definition BatchTensorBase.cxx:142
BatchTensor base_movedim(TorchSize d1, TorchSize d2) const
Move a base dimension from position d1 to position d2.
Definition BatchTensorBase.cxx:308
BatchTensor base_unsqueeze(TorchSize d) const
Unsqueeze a base dimension.
Definition BatchTensorBase.cxx:282
static Derived full_like(const Derived &other, Real init)
Definition BatchTensorBase.cxx:73
BatchTensor base_expand_copy(TorchShapeRef base_size) const
Return a new tensor with values broadcast along the base dimensions.
Definition BatchTensorBase.cxx:246
Derived list_unsqueeze() const
Unsqueeze on the special list batch dimension.
Definition BatchTensorBase.cxx:275
TorchSize base_size(TorchSize index) const
Return the length of some base axis.
Definition BatchTensorBase.cxx:170
void batch_index_put(TorchSlice indices, const torch::Tensor &other)
Set an index slice on the batch dimensions to a value.
Definition BatchTensorBase.cxx:202
Derived batch_expand_copy(TorchShapeRef batch_size) const
Return a new tensor with values broadcast along the batch dimensions.
Definition BatchTensorBase.cxx:239
BatchTensorBase()=default
Default constructor.
static Derived ones_like(const Derived &other)
Unit tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition BatchTensorBase.cxx:66
Derived batch_reshape(TorchShapeRef batch_shape) const
Reshape batch dimensions.
Definition BatchTensorBase.cxx:253
Derived batch_sum(TorchSize d) const
Sum on a batch index.
Definition BatchTensorBase.cxx:345
Definition BatchTensor.h:32
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
CrossRef()=default
The (logical) scalar.
Definition Scalar.h:38
Derived pow(const Derived &a, const Real &n)
Definition BatchTensorBase.h:332
TorchSize storage_size(TorchShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:32
TorchShape add_shapes(S &&... shape)
Definition utils.h:294
Definition CrossRef.cxx:32
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:85
int64_t TorchSize
Definition types.h:33
TorchSize broadcast_batch_dim(T &&...)
The batch dimension after broadcasting.
double Real
Definition types.h:31
torch::IntArrayRef TorchShapeRef
Definition types.h:35
void neml_assert_broadcastable_dbg(T &&...)
A helper function to assert (in Debug mode) that all tensors are broadcastable.
std::vector< at::indexing::TensorIndex > TorchSlice
Definition types.h:37