NEML2 1.4.0
Loading...
Searching...
No Matches
TensorBase.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/misc/utils.h"
28
29namespace neml2
30{
// Forward declarations.
// Tensor is only declared here (its definition lives in Tensor.h); TensorBase is the
// CRTP base class template defined further down in this header.
class Tensor;

template <class Derived>
class TensorBase;
36
44template <class Derived>
45class TensorBase : public torch::Tensor
46{
47public:
49 TensorBase() = default;
50
52 TensorBase(const torch::Tensor & tensor, Size batch_dim);
53
55 TensorBase(const Derived & tensor);
56
57 TensorBase(Real) = delete;
58
60 [[nodiscard]] static Derived empty_like(const Derived & other);
62 [[nodiscard]] static Derived zeros_like(const Derived & other);
64 [[nodiscard]] static Derived ones_like(const Derived & other);
67 [[nodiscard]] static Derived full_like(const Derived & other, Real init);
68
90 [[nodiscard]] static Derived linspace(
91 const Derived & start, const Derived & end, Size nstep, Size dim = 0, Size batch_dim = -1);
93 [[nodiscard]] static Derived logspace(const Derived & start,
94 const Derived & end,
95 Size nstep,
96 Size dim = 0,
97 Size batch_dim = -1,
98 Real base = 10);
99
101 // These methods should be mirrored in LabeledTensor
104 Derived clone(torch::MemoryFormat memory_format = torch::MemoryFormat::Contiguous) const;
106 Derived detach() const;
108 using torch::Tensor::detach_;
110 Derived to(const torch::TensorOptions & options) const;
112 using torch::Tensor::copy_;
114 using torch::Tensor::zero_;
116 using torch::Tensor::requires_grad;
118 using torch::Tensor::requires_grad_;
120 Derived operator-() const;
122
124 // These methods should be mirrored in LabeledTensor
127 using torch::Tensor::options;
129 using torch::Tensor::scalar_type;
131 using torch::Tensor::device;
133 using torch::Tensor::dim;
135 using torch::Tensor::sizes;
137 using torch::Tensor::size;
139 bool batched() const;
141 Size batch_dim() const;
143 Size base_dim() const;
147 Size batch_size(Size index) const;
151 Size base_size(Size index) const;
153 Size base_storage() const;
155
157 // These methods should be mirrored in LabeledTensor
160 using torch::Tensor::index;
161 using torch::Tensor::index_put_;
167 void batch_index_put_(indexing::TensorIndicesRef indices, const torch::Tensor & other);
169 void base_index_put_(indexing::TensorIndicesRef indices, const torch::Tensor & other);
171
179 template <class Derived2>
182 template <class Derived2>
201
202private:
204 Size _batch_dim;
205};
206
207template <class Derived>
208template <class Derived2>
211{
212 return batch_expand(other.batch_sizes());
213}
214
215template <class Derived>
216template <class Derived2>
219{
220 return base_expand(other.base_sizes());
221}
222
223template <class Derived,
224 typename = typename std::enable_if<std::is_base_of_v<TensorBase<Derived>, Derived>>>
226operator+(const Derived & a, const Real & b)
227{
228 return Derived(torch::operator+(a, b), a.batch_dim());
229}
230
231template <class Derived,
232 typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
233Derived
234operator+(const Real & a, const Derived & b)
235{
236 return b + a;
237}
238
239template <class Derived,
240 typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
241Derived
242operator+(const Derived & a, const Derived & b)
243{
245 return Derived(torch::operator+(a, b), broadcast_batch_dim(a, b));
246}
247
248template <class Derived,
249 typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
250Derived
251operator-(const Derived & a, const Real & b)
252{
253 return Derived(torch::operator-(a, b), a.batch_dim());
254}
255
256template <class Derived,
257 typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
258Derived
259operator-(const Real & a, const Derived & b)
260{
261 return -b + a;
262}
263
264template <class Derived,
265 typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
266Derived
267operator-(const Derived & a, const Derived & b)
268{
270 return Derived(torch::operator-(a, b), broadcast_batch_dim(a, b));
271}
272
273template <class Derived,
274 typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
275Derived
276operator*(const Derived & a, const Real & b)
277{
278 return Derived(torch::operator*(a, b), a.batch_dim());
279}
280
281template <class Derived,
282 typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
283Derived
284operator*(const Real & a, const Derived & b)
285{
286 return b * a;
287}
288
289template <class Derived,
290 typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
291Derived
292operator/(const Derived & a, const Real & b)
293{
294 return Derived(torch::operator/(a, b), a.batch_dim());
295}
296
297template <class Derived,
298 typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
299Derived
300operator/(const Real & a, const Derived & b)
301{
302 return Derived(torch::operator/(a, b), b.batch_dim());
303}
304
305template <class Derived,
306 typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
307Derived
308operator/(const Derived & a, const Derived & b)
309{
311 return Derived(torch::operator/(a, b), broadcast_batch_dim(a, b));
312}
313} // namespace neml2
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:56
NEML2's enhanced tensor type.
Definition TensorBase.h:46
Derived clone(torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
Definition TensorBase.cxx:114
Derived batch_expand_copy(TensorShapeRef batch_size) const
Return a new tensor with values broadcast along the batch dimensions.
Definition TensorBase.cxx:250
Derived batch_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the batch dimensions.
Definition TensorBase.cxx:191
neml2::Tensor base_transpose(Size d1, Size d2) const
Transpose two base dimensions.
Definition TensorBase.cxx:303
static Derived logspace(const Derived &start, const Derived &end, Size nstep, Size dim=0, Size batch_dim=-1, Real base=10)
Log-space equivalent of the linspace named constructor.
Definition TensorBase.cxx:105
Derived batch_expand_as(const Derived2 &other) const
Expand the batch to have the same shape as another tensor.
Definition TensorBase.h:210
Size base_storage() const
Return the flattened storage needed just for the base indices.
Definition TensorBase.cxx:184
TensorBase()=default
Default constructor.
bool batched() const
Whether the tensor is batched.
Definition TensorBase.cxx:135
Size batch_size(Size index) const
Return the size of a batch axis.
Definition TensorBase.cxx:163
Size batch_dim() const
Return the number of batch dimensions.
Definition TensorBase.cxx:142
Derived detach() const
Discard function graph.
Definition TensorBase.cxx:121
Derived batch_expand(TensorShapeRef batch_size) const
Definition TensorBase.cxx:230
neml2::Tensor base_expand(TensorShapeRef base_size) const
Return a new view of the tensor with values broadcast along the base dimensions.
Definition TensorBase.cxx:240
Derived operator-() const
Negation.
Definition TensorBase.cxx:312
TensorBase(Real)=delete
TensorShapeRef batch_sizes() const
Return the batch sizes, i.e. the shape of the batch dimensions.
Definition TensorBase.cxx:156
Derived batch_transpose(Size d1, Size d2) const
Transpose two batch dimensions.
Definition TensorBase.cxx:294
neml2::Tensor base_unsqueeze(Size d) const
Unsqueeze a base dimension.
Definition TensorBase.cxx:286
Derived batch_reshape(TensorShapeRef batch_shape) const
Reshape batch dimensions.
Definition TensorBase.cxx:264
neml2::Tensor base_expand_copy(TensorShapeRef base_size) const
Return a new tensor with values broadcast along the base dimensions.
Definition TensorBase.cxx:257
Derived to(const torch::TensorOptions &options) const
Change tensor options.
Definition TensorBase.cxx:128
Size base_size(Size index) const
Return the size of a base axis.
Definition TensorBase.cxx:177
static Derived zeros_like(const Derived &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBase.cxx:58
static Derived empty_like(const Derived &other)
Empty tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBase.cxx:51
Derived batch_unsqueeze(Size d) const
Unsqueeze a batch dimension.
Definition TensorBase.cxx:278
static Derived full_like(const Derived &other, Real init)
Definition TensorBase.cxx:72
void batch_index_put_(indexing::TensorIndicesRef indices, const torch::Tensor &other)
Set values by slicing on the batch dimensions.
Definition TensorBase.cxx:210
Size base_dim() const
Return the number of base dimensions.
Definition TensorBase.cxx:149
TensorShapeRef base_sizes() const
Return the base sizes, i.e. the shape of the base dimensions.
Definition TensorBase.cxx:170
static Derived linspace(const Derived &start, const Derived &end, Size nstep, Size dim=0, Size batch_dim=-1)
Create a new tensor by adding a new batch dimension with linear spacing between start and end.
Definition TensorBase.cxx:79
neml2::Tensor base_reshape(TensorShapeRef base_shape) const
Reshape base dimensions.
Definition TensorBase.cxx:271
void base_index_put_(indexing::TensorIndicesRef indices, const torch::Tensor &other)
Set values by slicing on the base dimensions.
Definition TensorBase.cxx:220
static Derived ones_like(const Derived &other)
Tensor of ones like another, i.e. same batch and base shapes, same tensor options, etc.
Definition TensorBase.cxx:65
neml2::Tensor base_index(indexing::TensorIndicesRef indices) const
Get a tensor by slicing on the base dimensions.
Definition TensorBase.cxx:201
Derived2 base_expand_as(const Derived2 &other) const
Expand the base to have the same shape as another tensor.
Definition TensorBase.h:218
Definition Tensor.h:32
torch::ArrayRef< TensorIndex > TensorIndicesRef
Definition types.h:42
Definition CrossRef.cxx:30
Vec operator*(const Derived1 &A, const Derived2 &b)
matrix-vector product
Definition R2Base.cxx:233
Derived operator-(const Derived &a, const Scalar &b)
Definition Scalar.h:79
Derived operator+(const Derived &a, const Scalar &b)
Definition Scalar.h:58
Size broadcast_batch_dim(T &&...)
The batch dimension after broadcasting.
double Real
Definition types.h:31
int64_t Size
Definition types.h:33
Derived operator/(const Derived &a, const Scalar &b)
Definition Scalar.h:123
void neml_assert_broadcastable_dbg(T &&...)
A helper function to assert (in Debug mode) that all tensors are broadcastable.
torch::IntArrayRef TensorShapeRef
Definition types.h:35