NEML2 1.4.0
Loading...
Searching...
No Matches
BatchTensorBase.h
1// Copyright 2023, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/misc/utils.h"
28
29namespace neml2
30{
// Forward declarations: the CRTP base class template defined below, and the general-purpose
// BatchTensor type that the base-dimension operations of BatchTensorBase return.
template <class Derived> class BatchTensorBase;
class BatchTensor;
36
44template <class Derived>
45class BatchTensorBase : public torch::Tensor
46{
47public:
49 BatchTensorBase() = default;
50
52 BatchTensorBase(const torch::Tensor & tensor, TorchSize batch_dim);
53
55 BatchTensorBase(const Derived & tensor);
56
58
60 [[nodiscard]] static Derived empty_like(const Derived & other);
62 [[nodiscard]] static Derived zeros_like(const Derived & other);
64 [[nodiscard]] static Derived ones_like(const Derived & other);
67 [[nodiscard]] static Derived full_like(const Derived & other, Real init);
68
90 [[nodiscard]] static Derived linspace(const Derived & start,
91 const Derived & end,
93 TorchSize dim = 0,
94 TorchSize batch_dim = -1);
96 [[nodiscard]] static Derived logspace(const Derived & start,
97 const Derived & end,
99 TorchSize dim = 0,
100 TorchSize batch_dim = -1,
101 Real base = 10);
102
104 bool batched() const;
105
107 TorchSize batch_dim() const;
108
111
113 TorchSize base_dim() const;
114
117
119 TorchSize batch_size(TorchSize index) const;
120
123
125 TorchSize base_size(TorchSize index) const;
126
128 TorchSize base_storage() const;
129
131 Derived batch_index(TorchSlice indices) const;
132
134 BatchTensor base_index(const TorchSlice & indices) const;
135
137 void batch_index_put(TorchSlice indices, const torch::Tensor & other);
138
140 void base_index_put(const TorchSlice & indices, const torch::Tensor & other);
141
144
147
149 template <class Derived2>
151
153 template <class Derived2>
155
158
161
164
167
170
172 Derived list_unsqueeze() const;
173
176
179
182
185
187 Derived clone(torch::MemoryFormat memory_format = torch::MemoryFormat::Contiguous) const;
188
190 Derived detach() const;
191
193 Derived to(const torch::TensorOptions & options) const;
194
196 Derived operator-() const;
197
199 Derived batch_sum(TorchSize d) const;
200
202 Derived list_sum() const;
203
204private:
206 TorchSize _batch_dim;
207};
208
209template <class Derived>
210template <class Derived2>
213{
214 return batch_expand(other.batch_sizes());
215}
216
217template <class Derived>
218template <class Derived2>
221{
222 return base_expand(other.base_sizes());
223}
224
225template <class Derived,
226 typename = typename std::enable_if<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
228operator+(const Derived & a, const Real & b)
229{
230 return Derived(torch::operator+(a, b), a.batch_dim());
231}
232
233template <
234 class Derived,
235 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
236Derived
237operator+(const Real & a, const Derived & b)
238{
239 return b + a;
240}
241
// Elementwise sum of two tensors of the same derived type. The result's batch dimension count is
// broadcast_batch_dim(a, b), i.e. the batch dimension after broadcasting a against b.
// NOTE(review): one source line of this body (between "{" and the return) is missing from this
// listing -- presumably a debug-mode broadcastability check; confirm against the header.
242template <
243 class Derived,
244 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
245Derived
246operator+(const Derived & a, const Derived & b)
247{
249 return Derived(torch::operator+(a, b), broadcast_batch_dim(a, b));
250}
251
252template <
253 class Derived,
254 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
255Derived
256operator-(const Derived & a, const Real & b)
257{
258 return Derived(torch::operator-(a, b), a.batch_dim());
259}
260
261template <
262 class Derived,
263 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
264Derived
265operator-(const Real & a, const Derived & b)
266{
267 return -b + a;
268}
269
// Elementwise difference of two tensors of the same derived type. The result's batch dimension
// count is broadcast_batch_dim(a, b), i.e. the batch dimension after broadcasting a against b.
// NOTE(review): one source line of this body (between "{" and the return) is missing from this
// listing -- presumably a debug-mode broadcastability check; confirm against the header.
270template <
271 class Derived,
272 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
273Derived
274operator-(const Derived & a, const Derived & b)
275{
277 return Derived(torch::operator-(a, b), broadcast_batch_dim(a, b));
278}
279
280template <
281 class Derived,
282 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
283Derived
284operator*(const Derived & a, const Real & b)
285{
286 return Derived(torch::operator*(a, b), a.batch_dim());
287}
288
289template <
290 class Derived,
291 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
292Derived
293operator*(const Real & a, const Derived & b)
294{
295 return b * a;
296}
297
298template <
299 class Derived,
300 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
301Derived
302operator/(const Derived & a, const Real & b)
303{
304 return Derived(torch::operator/(a, b), a.batch_dim());
305}
306
307template <
308 class Derived,
309 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
310Derived
311operator/(const Real & a, const Derived & b)
312{
313 return Derived(torch::operator/(a, b), b.batch_dim());
314}
315
// Elementwise quotient of two tensors of the same derived type. The result's batch dimension count
// is broadcast_batch_dim(a, b), i.e. the batch dimension after broadcasting a against b.
// NOTE(review): one source line of this body (between "{" and the return) is missing from this
// listing -- presumably a debug-mode broadcastability check; confirm against the header.
316template <
317 class Derived,
318 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
319Derived
320operator/(const Derived & a, const Derived & b)
321{
323 return Derived(torch::operator/(a, b), broadcast_batch_dim(a, b));
324}
325
326namespace math
327{
328template <
329 class Derived,
330 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
331Derived
332pow(const Derived & a, const Real & n)
333{
334 return Derived(torch::pow(a, n), a.batch_dim());
335}
336
337template <
338 class Derived,
339 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
341pow(const Real & a, const Derived & n)
342{
343 return Derived(torch::pow(a, n), n.batch_dim());
344}
345
// Elementwise power a^n of two tensors of the same derived type. The result's batch dimension count
// is broadcast_batch_dim(a, n), i.e. the batch dimension after broadcasting a against n.
// NOTE(review): two source lines of this definition (the return type before "pow(" and one body
// line before the return) are missing from this listing; confirm against the header.
346template <
347 class Derived,
348 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
350pow(const Derived & a, const Derived & n)
351{
353 return Derived(torch::pow(a, n), broadcast_batch_dim(a, n));
354}
355
356template <
357 class Derived,
358 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
360sign(const Derived & a)
361{
362 return Derived(torch::sign(a), a.batch_dim());
363}
364
365template <
366 class Derived,
367 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
369cosh(const Derived & a)
370{
371 return Derived(torch::cosh(a), a.batch_dim());
372}
373
374template <
375 class Derived,
376 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
378sinh(const Derived & a)
379{
380 return Derived(torch::sinh(a), a.batch_dim());
381}
382
383template <
384 class Derived,
385 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
387tanh(const Derived & a)
388{
389 return Derived(torch::tanh(a), a.batch_dim());
390}
391
// Elementwise selection: entries of a where condition holds, entries of b elsewhere. The result's
// batch dimension count is broadcast_batch_dim(a, b).
// NOTE(review): only a and b enter broadcast_batch_dim; if condition carries extra leading batch
// dimensions the recorded batch count may be too small -- confirm callers never do this.
// NOTE(review): two source lines of this definition (the return type before "where(" and one body
// line before the return) are missing from this listing; confirm against the header.
392template <
393 class Derived,
394 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
396where(const torch::Tensor & condition, const Derived & a, const Derived & b)
397{
399 return Derived(torch::where(condition, a, b), broadcast_batch_dim(a, b));
400}
401
408template <
409 class Derived,
410 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
413{
414 return (sign(a) + 1.0) / 2.0;
415}
416
417template <
418 class Derived,
419 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
422{
423 return Derived(torch::Tensor(a) * torch::Tensor(heaviside(a)), a.batch_dim());
424}
425
426template <
427 class Derived,
428 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
431{
432 return heaviside(a);
433}
434
435template <
436 class Derived,
437 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
439sqrt(const Derived & a)
440{
441 return Derived(torch::sqrt(a), a.batch_dim());
442}
443
444template <
445 class Derived,
446 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
448exp(const Derived & a)
449{
450 return Derived(torch::exp(a), a.batch_dim());
451}
452
453template <
454 class Derived,
455 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
457abs(const Derived & a)
458{
459 return Derived(torch::abs(a), a.batch_dim());
460}
461
462template <
463 class Derived,
464 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
466diff(const Derived & a, TorchSize n = 1, TorchSize dim = -1)
467{
468 return Derived(torch::diff(a, n, dim), a.batch_dim());
469}
470
471template <
472 class Derived,
473 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
476{
477 return Derived(torch::diag_embed(
478 a, offset, d1 < 0 ? d1 - a.base_dim() : d1, d2 < 0 ? d2 - a.base_dim() : d2),
479 a.batch_dim() + 1);
480}
481
482template <
483 class Derived,
484 typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
486log(const Derived & a)
487{
488 return Derived(torch::log(a), a.batch_dim());
489}
490
491} // namespace math
492} // namespace neml2
NEML2's enhanced tensor type.
Definition BatchTensorBase.h:46
Derived clone(torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
Clone (take ownership)
Definition BatchTensorBase.cxx:317
TorchSize base_storage() const
Return the flattened storage needed just for the base indices.
Definition BatchTensorBase.cxx:177
TorchShapeRef batch_sizes() const
Return the batch size.
Definition BatchTensorBase.cxx:149
Derived batch_expand_as(const Derived2 &other) const
Expand the batch to have the same shape as another tensor.
Definition BatchTensorBase.h:212
Derived list_sum() const
Sum on the list index (TODO: replace with class)
Definition BatchTensorBase.cxx:354
Derived batch_transpose(TorchSize d1, TorchSize d2) const
Transpose two batch dimensions.
Definition BatchTensorBase.cxx:290
bool batched() const
Whether the tensor is batched.
Definition BatchTensorBase.cxx:121
BatchTensor base_transpose(TorchSize d1, TorchSize d2) const
Transpose two base dimensions.
Definition BatchTensorBase.cxx:299
static Derived linspace(const Derived &start, const Derived &end, TorchSize nstep, TorchSize dim=0, TorchSize batch_dim=-1)
Create a new tensor by adding a new batch dimension with linear spacing between start and end.
Definition BatchTensorBase.cxx:80
Derived detach() const
Discard function graph.
Definition BatchTensorBase.cxx:324
Derived operator-() const
Negation.
Definition BatchTensorBase.cxx:338
TorchSize batch_dim() const
Return the number of batch dimensions.
Definition BatchTensorBase.cxx:128
static Derived logspace(const Derived &start, const Derived &end, TorchSize nstep, TorchSize dim=0, TorchSize batch_dim=-1, Real base=10)
log-space equivalent of the linspace named constructor
Definition BatchTensorBase.cxx:108
Derived batch_index(TorchSlice indices) const
Get a batch.
Definition BatchTensorBase.cxx:184
BatchTensor base_expand(TorchShapeRef base_size) const
Return a new view of the tensor with values broadcast along the base dimensions.
Definition BatchTensorBase.cxx:229
TorchShapeRef base_sizes() const
Return the base shape (the sizes of all base dimensions).
Definition BatchTensorBase.cxx:163
BatchTensor base_index(const TorchSlice &indices) const
Return an index sliced on the base dimensions.
Definition BatchTensorBase.cxx:193
Derived to(const torch::TensorOptions &options) const
Send to options.
Definition BatchTensorBase.cxx:331
TorchSize batch_size(TorchSize index) const
Return the length of some batch axis.
Definition BatchTensorBase.cxx:156
Derived batch_expand(TorchShapeRef batch_size) const
Return a new view of the tensor with values broadcast along the batch dimensions.
Definition BatchTensorBase.cxx:219
Derived batch_unsqueeze(TorchSize d) const
Unsqueeze a batch dimension.
Definition BatchTensorBase.cxx:267
void base_index_put(const TorchSlice &indices, const torch::Tensor &other)
Set an index sliced on the base dimensions to a value.
Definition BatchTensorBase.cxx:210
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition BatchTensorBase.cxx:59
BatchTensor base_reshape(TorchShapeRef base_shape) const
Reshape base dimensions.
Definition BatchTensorBase.cxx:260
static Derived empty_like(const Derived &other)
Empty tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition BatchTensorBase.cxx:52
TorchSize base_dim() const
Return the number of base dimensions.
Definition BatchTensorBase.cxx:142
BatchTensor base_movedim(TorchSize d1, TorchSize d2) const
Move two base dimensions.
Definition BatchTensorBase.cxx:308
BatchTensor base_unsqueeze(TorchSize d) const
Unsqueeze a base dimension.
Definition BatchTensorBase.cxx:282
static Derived full_like(const Derived &other, Real init)
Definition BatchTensorBase.cxx:73
BatchTensorBase(Real)=delete
BatchTensor base_expand_copy(TorchShapeRef base_size) const
Return a new tensor with values broadcast along the base dimensions.
Definition BatchTensorBase.cxx:246
Derived list_unsqueeze() const
Unsqueeze on the special list batch dimension.
Definition BatchTensorBase.cxx:275
TorchSize base_size(TorchSize index) const
Return the length of some base axis.
Definition BatchTensorBase.cxx:170
void batch_index_put(TorchSlice indices, const torch::Tensor &other)
Set an index sliced on the batch dimensions to a value.
Definition BatchTensorBase.cxx:202
Derived batch_expand_copy(TorchShapeRef batch_size) const
Return a new tensor with values broadcast along the batch dimensions.
Definition BatchTensorBase.cxx:239
BatchTensorBase()=default
Default constructor.
static Derived ones_like(const Derived &other)
Tensor of ones like another, i.e. same batch and base shapes, same tensor options, etc.
Definition BatchTensorBase.cxx:66
Derived batch_reshape(TorchShapeRef batch_shape) const
Reshape batch dimensions.
Definition BatchTensorBase.cxx:253
Derived2 base_expand_as(const Derived2 &other) const
Expand the base to have the same shape as another tensor.
Definition BatchTensorBase.h:220
Derived batch_sum(TorchSize d) const
Sum on a batch index.
Definition BatchTensorBase.cxx:345
Definition BatchTensor.h:32
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
Derived macaulay(const Derived &a)
Definition BatchTensorBase.h:421
Derived sinh(const Derived &a)
Definition BatchTensorBase.h:378
Derived heaviside(const Derived &a)
Definition BatchTensorBase.h:412
Derived tanh(const Derived &a)
Definition BatchTensorBase.h:387
Derived cosh(const Derived &a)
Definition BatchTensorBase.h:369
Derived diff(const Derived &a, TorchSize n=1, TorchSize dim=-1)
Definition BatchTensorBase.h:466
Derived where(const torch::Tensor &condition, const Derived &a, const Derived &b)
Definition BatchTensorBase.h:396
Derived log(const Derived &a)
Definition BatchTensorBase.h:486
Derived abs(const Derived &a)
Definition BatchTensorBase.h:457
Derived dmacaulay(const Derived &a)
Definition BatchTensorBase.h:430
Derived exp(const Derived &a)
Definition BatchTensorBase.h:448
Derived batch_diag_embed(const Derived &a, TorchSize offset=0, TorchSize d1=-2, TorchSize d2=-1)
Definition BatchTensorBase.h:475
Derived sign(const Derived &a)
Definition BatchTensorBase.h:360
Derived sqrt(const Derived &a)
Definition BatchTensorBase.h:439
Derived pow(const Derived &a, const Real &n)
Definition BatchTensorBase.h:332
Definition CrossRef.cxx:32
Derived operator-(const Derived &a, const Real &b)
Definition BatchTensorBase.h:256
BatchTensor operator*(const BatchTensor &a, const BatchTensor &b)
Definition BatchTensor.cxx:153
int64_t TorchSize
Definition types.h:35
TorchSize broadcast_batch_dim(T &&...)
The batch dimension after broadcasting.
double Real
Definition types.h:33
Derived operator+(const Derived &a, const Real &b)
Definition BatchTensorBase.h:228
torch::IntArrayRef TorchShapeRef
Definition types.h:37
void neml_assert_broadcastable_dbg(T &&...)
A helper function to assert (in Debug mode) that all tensors are broadcastable.
std::vector< at::indexing::TensorIndex > TorchSlice
Definition types.h:39
Derived operator/(const Derived &a, const Real &b)
Definition BatchTensorBase.h:302