NEML2 1.4.0
Loading...
Searching...
No Matches
math.cxx
1// Copyright 2023, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#include "neml2/misc/math.h"
26#include "neml2/misc/error.h"
27#include "neml2/tensors/tensors.h"
28
29namespace neml2
30{
31namespace math
32{
34{
35 _full_to_mandel_map = torch::tensor({0, 4, 8, 5, 2, 1}, default_integer_tensor_options());
36
37 _mandel_to_full_map =
38 torch::tensor({0, 5, 4, 5, 1, 3, 4, 3, 2}, default_integer_tensor_options());
39
40 _full_to_mandel_factor =
41 torch::tensor({1.0, 1.0, 1.0, sqrt2, sqrt2, sqrt2}, default_tensor_options());
42
43 _mandel_to_full_factor =
44 torch::tensor({1.0, invsqrt2, invsqrt2, invsqrt2, 1.0, invsqrt2, invsqrt2, invsqrt2, 1.0},
46
47 _full_to_skew_map = torch::tensor({7, 2, 3}, default_integer_tensor_options());
48
49 _skew_to_full_map = torch::tensor({0, 2, 1, 2, 0, 0, 1, 0, 0}, default_integer_tensor_options());
50
51 _full_to_skew_factor = torch::tensor({1.0, 1.0, 1.0}, default_tensor_options());
52
53 _skew_to_full_factor =
54 torch::tensor({0.0, -1.0, 1.0, 1.0, 0.0, -1.0, -1.0, 1.0, 0.0}, default_tensor_options());
55}
56
59{
60 static ConstantTensors cts;
61 return cts;
62}
63
64const torch::Tensor &
66{
67 return get()._full_to_mandel_map;
68}
69
70const torch::Tensor &
72{
73 return get()._mandel_to_full_map;
74}
75
76const torch::Tensor &
78{
79 return get()._full_to_mandel_factor;
80}
81
82const torch::Tensor &
84{
85 return get()._mandel_to_full_factor;
86}
87
88const torch::Tensor &
90{
91 return get()._full_to_skew_map;
92}
93
94const torch::Tensor &
96{
97 return get()._skew_to_full_map;
98}
99
100const torch::Tensor &
102{
103 return get()._full_to_skew_factor;
104}
105
106const torch::Tensor &
108{
109 return get()._skew_to_full_factor;
110}
111
114 const torch::Tensor & rmap,
115 const torch::Tensor & rfactors,
117{
118 using namespace torch::indexing;
119
120 auto batch_dim = full.batch_dim();
121 auto starting_dim = batch_dim + dim;
122 auto trailing_dim = full.dim() - starting_dim - 2; // 2 comes from the reduced axes (3,3)
123 auto starting_shape = full.sizes().slice(0, starting_dim);
124 auto trailing_shape = full.sizes().slice(starting_dim + 2);
125
127 net.push_back(Ellipsis);
128 net.insert(net.end(), trailing_dim, None);
129 auto map =
130 rmap.index(net).expand(utils::add_shapes(starting_shape, rmap.sizes()[0], trailing_shape));
131 auto factor = rfactors.to(full).index(net);
132
133 return BatchTensor(
134 factor * torch::gather(full.reshape(utils::add_shapes(starting_shape, 9, trailing_shape)),
136 map),
137 batch_dim);
138}
139
142 const torch::Tensor & rmap,
143 const torch::Tensor & rfactors,
145{
146 using namespace torch::indexing;
147
148 auto batch_dim = reduced.batch_dim();
149 auto starting_dim = batch_dim + dim;
150 auto trailing_dim = reduced.dim() - starting_dim - 1; // There's only 1 axis to unsqueeze
151 auto starting_shape = reduced.sizes().slice(0, starting_dim);
152 auto trailing_shape = reduced.sizes().slice(starting_dim + 1);
153
155 net.push_back(Ellipsis);
156 net.insert(net.end(), trailing_dim, None);
157 auto map = rmap.index(net).expand(utils::add_shapes(starting_shape, 9, trailing_shape));
158 auto factor = rfactors.to(reduced).index(net);
159
160 return BatchTensor((factor * torch::gather(reduced, starting_dim, map))
162 batch_dim);
163}
164
167{
168 return full_to_reduced(
169 full,
170 ConstantTensors::full_to_mandel_map().to(full.options().dtype(NEML2_INT_DTYPE)),
171 ConstantTensors::full_to_mandel_factor().to(full.options()),
172 dim);
173}
174
177{
178 return reduced_to_full(
179 mandel,
180 ConstantTensors::mandel_to_full_map().to(mandel.options().dtype(NEML2_INT_DTYPE)),
182 dim);
183}
184
187{
188 return full_to_reduced(
189 full,
190 ConstantTensors::full_to_skew_map().to(full.options().dtype(NEML2_INT_DTYPE)),
191 ConstantTensors::full_to_skew_factor().to(full.options()),
192 dim);
193}
194
197{
198 return reduced_to_full(
199 skew,
200 ConstantTensors::skew_to_full_map().to(skew.options().dtype(NEML2_INT_DTYPE)),
201 ConstantTensors::skew_to_full_factor().to(skew.options()),
202 dim);
203}
204
207{
208 neml_assert(p.batch_sizes() == y.batch_sizes(),
209 "The batch shape of the parameter must be the same as the batch shape "
210 "of the output. However, the batch shape of the parameter is ",
211 p.batch_sizes(),
212 ", and the batch shape of the output is ",
213 y.batch_sizes());
214
215 // flatten y to handle arbitrarily shaped output
216 auto yf = BatchTensor(
217 y.reshape(utils::add_shapes(y.batch_sizes(), utils::storage_size(y.base_sizes()))),
218 y.batch_dim());
219
220 neml_assert_dbg(yf.base_dim() == 1, "Flattened output must be flat.");
221
223 yf.batch_sizes(), utils::add_shapes(yf.base_sizes(), p.base_sizes()), yf.options());
224
225 for (TorchSize i = 0; i < yf.base_sizes()[0]; i++)
226 {
228 v.index_put_({torch::indexing::Ellipsis, i}, 1.0);
229 const auto dyfi_dp = torch::autograd::grad({yf},
230 {p},
231 {v},
232 /*retain_graph=*/true,
233 /*create_graph=*/false,
234 /*allow_unused=*/false)[0];
235 if (dyfi_dp.defined())
236 dyf_dp.base_index_put({i, torch::indexing::Ellipsis}, dyfi_dp);
237 }
238
239 // Reshape the derivative back to the correct shape
240 const auto dy_dp = BatchTensor(
241 dyf_dp.reshape(utils::add_shapes(y.batch_sizes(), y.base_sizes(), p.base_sizes())),
242 y.batch_dim());
243
244 return dy_dp;
245}
246
249{
250 return BatchTensor(
251 torch::diag_embed(
252 a, offset, d1 < 0 ? d1 : d1 + a.batch_dim() + 1, d2 < 0 ? d2 : d2 + a.batch_dim() + 1),
253 a.batch_dim());
254}
255
256SR2
257skew_and_sym_to_sym(const SR2 & e, const WR2 & w)
258{
259 // In NEML we used an unrolled form, I don't think I ever found
260 // a nice direct notation for this one
261 auto E = R2(e);
262 auto W = R2(w);
263 return SR2(W * E - E * W);
264}
265
266SSR4
268{
269 auto I = R2::identity(w.options());
270 auto W = R2(w);
271 return SSR4(R4(torch::einsum("...ia,...jb->...ijab", {W, I}) -
272 torch::einsum("...ia,...bj->...ijab", {I, W})));
273}
274
275SWR4
277{
278 auto I = R2::identity(e.options());
279 auto E = R2(e);
280 return SWR4(R4(torch::einsum("...ia,...bj->...ijab", {I, E}) -
281 torch::einsum("...ia,...jb->...ijab", {E, I})));
282}
283
284WR2
285multiply_and_make_skew(const SR2 & a, const SR2 & b)
286{
287 auto A = R2(a);
288 auto B = R2(b);
289
290 return WR2(A * B - B * A);
291}
292
293WSR4
295{
296 auto I = R2::identity(b.options());
297 auto B = R2(b);
298 return WSR4(R4(torch::einsum("...ia,...bj->...ijab", {I, B}) -
299 torch::einsum("...ia,...jb->...ijab", {B, I})));
300}
301
302WSR4
304{
305 auto I = R2::identity(a.options());
306 auto A = R2(a);
307 return WSR4(R4(torch::einsum("...ia,...jb->...ijab", {A, I}) -
308 torch::einsum("...ia,...bj->...ijab", {I, A})));
309}
310
311namespace linalg
312{
315{
316 neml_assert_dbg(v.base_dim() == 0 || v.base_dim() == 1,
317 "v in vector_norm has base dimension ",
318 v.base_dim(),
319 " instead of 0 or 1.");
320
321 // If the vector is a logical scalar just return its absolute value
322 if (v.base_dim() == 0)
323 return math::abs(v);
324
325 return BatchTensor(torch::linalg::vector_norm(
326 v, /*order=*/2, /*dim=*/-1, /*keepdim=*/false, /*dtype=*/c10::nullopt),
327 v.batch_dim());
328}
329
331solve(const BatchTensor & A, const BatchTensor & B)
332{
333 return BatchTensor(torch::linalg::solve(A, B, /*left=*/true), A.batch_dim());
334}
335
336std::tuple<BatchTensor, BatchTensor>
338{
339 auto [LU, pivots] = torch::linalg_lu_factor(A, pivot);
340 return {BatchTensor(LU, A.batch_dim()), BatchTensor(pivots, A.batch_dim())};
341}
342
345 const BatchTensor & pivots,
346 const BatchTensor & B,
347 bool left,
348 bool adjoint)
349{
350 return BatchTensor(torch::linalg_lu_solve(LU, pivots, B, left, adjoint), B.batch_dim());
351}
352} // namespace linalg
353} // namespace math
354} // namespace neml2
TorchSize batch_dim() const
Return the number of batch dimensions.
Definition BatchTensorBase.cxx:128
static BatchTensor zeros_like(const BatchTensor &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition BatchTensorBase.cxx:59
Definition BatchTensor.h:32
static BatchTensor empty(const TorchShapeRef &base_shape, const torch::TensorOptions &options=default_tensor_options())
Unbatched empty tensor given base shape.
Definition BatchTensor.cxx:30
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
static R2 identity(const torch::TensorOptions &options=default_tensor_options())
Identity.
Definition R2Base.cxx:170
A basic R2.
Definition R2.h:42
The (logical) full fourth order tensor.
Definition R4.h:43
The (logical) symmetric second order tensor.
Definition SR2.h:46
The (logical) symmetric fourth order tensor, with symmetry in the first two dimensions as well as in...
Definition SSR4.h:44
The (logical) symmetric fourth order tensor, with symmetry in the first two dimensions and skew-symm...
Definition SWR4.h:40
A skew rank 2, represented as an axial vector.
Definition WR2.h:43
The (logical) symmetric fourth order tensor, with skew symmetry in the first two dimensions and symm...
Definition WSR4.h:40
BatchTensor solve(const BatchTensor &A, const BatchTensor &B)
Solve the linear system A X = B.
Definition math.cxx:331
std::tuple< BatchTensor, BatchTensor > lu_factor(const BatchTensor &A, bool pivot)
Definition math.cxx:337
BatchTensor vector_norm(const BatchTensor &v)
Vector norm of a vector. Falls back to math::abs if v is a Scalar.
Definition math.cxx:314
BatchTensor lu_solve(const BatchTensor &LU, const BatchTensor &pivots, const BatchTensor &B, bool left, bool adjoint)
Definition math.cxx:344
SWR4 d_skew_and_sym_to_sym_d_skew(const SR2 &e)
Derivative of w_ik e_kj - e_ik w_kj wrt. w.
Definition math.cxx:276
SSR4 d_skew_and_sym_to_sym_d_sym(const WR2 &w)
Derivative of w_ik e_kj - e_ik w_kj wrt. e.
Definition math.cxx:267
BatchTensor mandel_to_full(const BatchTensor &mandel, TorchSize dim)
Convert a BatchTensor from Mandel notation to full notation.
Definition math.cxx:176
constexpr Real invsqrt2
Definition math.h:41
BatchTensor skew_to_full(const BatchTensor &skew, TorchSize dim)
Convert a BatchTensor from skew vector notation to full notation.
Definition math.cxx:196
BatchTensor base_diag_embed(const BatchTensor &a, TorchSize offset, TorchSize d1, TorchSize d2)
Definition math.cxx:248
BatchTensor reduced_to_full(const BatchTensor &reduced, const torch::Tensor &rmap, const torch::Tensor &rfactors, TorchSize dim)
Convert a BatchTensor from reduced notation to full notation.
Definition math.cxx:141
WSR4 d_multiply_and_make_skew_d_first(const SR2 &b)
Derivative of a_ik b_kj - b_ik a_kj wrt a.
Definition math.cxx:294
SR2 skew_and_sym_to_sym(const SR2 &e, const WR2 &w)
Product w_ik e_kj - e_ik w_kj with e SR2 and w WR2.
Definition math.cxx:257
WR2 multiply_and_make_skew(const SR2 &a, const SR2 &b)
Shortcut product a_ik b_kj - b_ik a_kj with both SR2.
Definition math.cxx:285
BatchTensor full_to_reduced(const BatchTensor &full, const torch::Tensor &rmap, const torch::Tensor &rfactors, TorchSize dim)
Generic function to reduce two axes to one with some map.
Definition math.cxx:113
WSR4 d_multiply_and_make_skew_d_second(const SR2 &a)
Derivative of a_ik b_kj - b_ik a_kj wrt b.
Definition math.cxx:303
Derived abs(const Derived &a)
Definition BatchTensorBase.h:457
BatchTensor full_to_mandel(const BatchTensor &full, TorchSize dim)
Convert a BatchTensor from full notation to Mandel notation.
Definition math.cxx:166
BatchTensor jacrev(const BatchTensor &y, const BatchTensor &p)
Use automatic differentiation (AD) to calculate the derivatives w.r.t. to the parameter.
Definition math.cxx:206
constexpr Real sqrt2
Definition math.h:40
BatchTensor full_to_skew(const BatchTensor &full, TorchSize dim)
Convert a BatchTensor from full notation to skew vector notation.
Definition math.cxx:186
TorchSize storage_size(TorchShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:32
TorchShape add_shapes(S &&... shape)
Definition utils.h:294
Definition CrossRef.cxx:32
const torch::TensorOptions default_tensor_options()
Definition types.cxx:30
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:85
int64_t TorchSize
Definition types.h:33
const torch::TensorOptions default_integer_tensor_options()
We similarly want to have a default integer scalar type for some types of tensors.
Definition types.cxx:42
std::vector< at::indexing::TensorIndex > TorchSlice
Definition types.h:37
void neml_assert(bool assertion, Args &&... args)
Definition error.h:73
A helper class to hold static data of type torch::Tensor.
Definition math.h:64
static const torch::Tensor & skew_to_full_map()
Definition math.cxx:95
static const torch::Tensor & full_to_mandel_factor()
Definition math.cxx:77
static const torch::Tensor & full_to_skew_factor()
Definition math.cxx:101
static const torch::Tensor & full_to_skew_map()
Definition math.cxx:89
ConstantTensors()
Definition math.cxx:33
static ConstantTensors & get()
Definition math.cxx:58
static const torch::Tensor & mandel_to_full_factor()
Definition math.cxx:83
static const torch::Tensor & mandel_to_full_map()
Definition math.cxx:71
static const torch::Tensor & full_to_mandel_map()
Definition math.cxx:65
static const torch::Tensor & skew_to_full_factor()
Definition math.cxx:107