NEML2 1.4.0
math.cxx
// Copyright 2023, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#include "neml2/misc/math.h"
#include "neml2/misc/error.h"
#include "neml2/tensors/tensors.h"

#include <torch/autograd.h>
#include <torch/linalg.h>

namespace neml2
{
namespace math
{
ConstantTensors::ConstantTensors()
{
  _full_to_mandel_map = torch::tensor({0, 4, 8, 5, 2, 1}, default_integer_tensor_options());

  _mandel_to_full_map =
      torch::tensor({0, 5, 4, 5, 1, 3, 4, 3, 2}, default_integer_tensor_options());

  _full_to_mandel_factor =
      torch::tensor({1.0, 1.0, 1.0, sqrt2, sqrt2, sqrt2}, default_tensor_options());

  _mandel_to_full_factor =
      torch::tensor({1.0, invsqrt2, invsqrt2, invsqrt2, 1.0, invsqrt2, invsqrt2, invsqrt2, 1.0},
                    default_tensor_options());

  _full_to_skew_map = torch::tensor({7, 2, 3}, default_integer_tensor_options());

  _skew_to_full_map = torch::tensor({0, 2, 1, 2, 0, 0, 1, 0, 0}, default_integer_tensor_options());

  _full_to_skew_factor = torch::tensor({1.0, 1.0, 1.0}, default_tensor_options());

  _skew_to_full_factor =
      torch::tensor({0.0, -1.0, 1.0, 1.0, 0.0, -1.0, -1.0, 1.0, 0.0}, default_tensor_options());
}
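
// Editor's note (not part of the original source): these maps encode the Mandel
// and axial-vector conventions. A symmetric 3x3 tensor A, flattened row-major,
// reduces to the Mandel 6-vector
//   (A_00, A_11, A_22, sqrt(2) A_12, sqrt(2) A_02, sqrt(2) A_01),
// i.e. flat indices {0, 4, 8, 5, 2, 1}; the sqrt(2) factors make the dot
// product of two Mandel vectors equal the full double contraction A : B.
// A skew 3x3 tensor W reduces to the axial vector w = (W_21, W_02, W_10),
// i.e. flat indices {7, 2, 3}, and expands back per the map/factor pair as
//   [[  0, -w2,  w1],
//    [ w2,   0, -w0],
//    [-w1,  w0,   0]].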

ConstantTensors &
ConstantTensors::get()
{
  static ConstantTensors cts;
  return cts;
}

const torch::Tensor &
ConstantTensors::full_to_mandel_map()
{
  return get()._full_to_mandel_map;
}

const torch::Tensor &
ConstantTensors::mandel_to_full_map()
{
  return get()._mandel_to_full_map;
}

const torch::Tensor &
ConstantTensors::full_to_mandel_factor()
{
  return get()._full_to_mandel_factor;
}

const torch::Tensor &
ConstantTensors::mandel_to_full_factor()
{
  return get()._mandel_to_full_factor;
}

const torch::Tensor &
ConstantTensors::full_to_skew_map()
{
  return get()._full_to_skew_map;
}

const torch::Tensor &
ConstantTensors::skew_to_full_map()
{
  return get()._skew_to_full_map;
}

const torch::Tensor &
ConstantTensors::full_to_skew_factor()
{
  return get()._full_to_skew_factor;
}

const torch::Tensor &
ConstantTensors::skew_to_full_factor()
{
  return get()._skew_to_full_factor;
}

// Generic function to reduce two axes to one with some map
BatchTensor
full_to_reduced(const BatchTensor & full,
                const torch::Tensor & rmap,
                const torch::Tensor & rfactors,
                TorchSize dim)
{
  using namespace torch::indexing;

  auto batch_dim = full.batch_dim();
  auto starting_dim = batch_dim + dim;
  auto trailing_dim = full.dim() - starting_dim - 2; // 2 comes from the reduced axes (3,3)
  auto starting_shape = full.sizes().slice(0, starting_dim);
  auto trailing_shape = full.sizes().slice(starting_dim + 2);

  TorchSlice net(starting_dim, None);
  net.push_back(Ellipsis);
  net.insert(net.end(), trailing_dim, None);
  auto map =
      rmap.index(net).expand(utils::add_shapes(starting_shape, rmap.sizes()[0], trailing_shape));
  auto factor = rfactors.to(full).index(net);

  return BatchTensor(
      factor * torch::gather(full.reshape(utils::add_shapes(starting_shape, 9, trailing_shape)),
                             starting_dim,
                             map),
      batch_dim);
}
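
// Editor's sketch (not part of the original source): what the gather above does
// for the Mandel map on a single unbatched tensor. The function name and use of
// torch::arange are illustrative assumptions.
static void
example_full_to_reduced_gather()
{
  // A stands in for a row-major flattened (3,3) tensor, entries 0..8
  auto A = torch::arange(9, default_tensor_options());
  auto map = ConstantTensors::full_to_mandel_map();
  // torch::gather along dim 0 picks A[map[k]] for k = 0..5, i.e. the
  // (0,0), (1,1), (2,2), (1,2), (0,2), (0,1) entries of the flattened tensor
  auto reduced = torch::gather(A, 0, map);
}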

// Convert a BatchTensor from reduced notation to full notation
BatchTensor
reduced_to_full(const BatchTensor & reduced,
                const torch::Tensor & rmap,
                const torch::Tensor & rfactors,
                TorchSize dim)
{
  using namespace torch::indexing;

  auto batch_dim = reduced.batch_dim();
  auto starting_dim = batch_dim + dim;
  auto trailing_dim = reduced.dim() - starting_dim - 1; // There's only 1 axis to unsqueeze
  auto starting_shape = reduced.sizes().slice(0, starting_dim);
  auto trailing_shape = reduced.sizes().slice(starting_dim + 1);

  TorchSlice net(starting_dim, None);
  net.push_back(Ellipsis);
  net.insert(net.end(), trailing_dim, None);
  auto map = rmap.index(net).expand(utils::add_shapes(starting_shape, 9, trailing_shape));
  auto factor = rfactors.to(reduced).index(net);

  return BatchTensor((factor * torch::gather(reduced, starting_dim, map))
                         .reshape(utils::add_shapes(starting_shape, 3, 3, trailing_shape)),
                     batch_dim);
}

// Convert a BatchTensor from full notation to Mandel notation
BatchTensor
full_to_mandel(const BatchTensor & full, TorchSize dim)
{
  return full_to_reduced(
      full,
      ConstantTensors::full_to_mandel_map().to(full.options().dtype(NEML2_INT_DTYPE)),
      ConstantTensors::full_to_mandel_factor().to(full.options()),
      dim);
}

// Convert a BatchTensor from Mandel notation to full notation
BatchTensor
mandel_to_full(const BatchTensor & mandel, TorchSize dim)
{
  return reduced_to_full(
      mandel,
      ConstantTensors::mandel_to_full_map().to(mandel.options().dtype(NEML2_INT_DTYPE)),
      ConstantTensors::mandel_to_full_factor().to(mandel.options()),
      dim);
}
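
// Editor's sketch (not part of the original source): a Mandel round trip on a
// batched tensor. The batch size and use of torch::rand are assumptions made
// for illustration only.
static void
example_mandel_round_trip()
{
  // One batch dimension of size 5; base shape (3, 3). Symmetrize first: the
  // round trip reproduces the input only for symmetric tensors, since
  // full_to_mandel reads a single representative off-diagonal entry.
  auto A = torch::rand({5, 3, 3}, default_tensor_options());
  auto full = BatchTensor(0.5 * (A + A.transpose(-1, -2)), 1);
  auto mandel = full_to_mandel(full, 0); // base shape (6)
  auto back = mandel_to_full(mandel, 0); // base shape (3, 3), matches `full`
}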

// Convert a BatchTensor from full notation to skew vector notation
BatchTensor
full_to_skew(const BatchTensor & full, TorchSize dim)
{
  return full_to_reduced(
      full,
      ConstantTensors::full_to_skew_map().to(full.options().dtype(NEML2_INT_DTYPE)),
      ConstantTensors::full_to_skew_factor().to(full.options()),
      dim);
}

// Convert a BatchTensor from skew vector notation to full notation
BatchTensor
skew_to_full(const BatchTensor & skew, TorchSize dim)
{
  return reduced_to_full(
      skew,
      ConstantTensors::skew_to_full_map().to(skew.options().dtype(NEML2_INT_DTYPE)),
      ConstantTensors::skew_to_full_factor().to(skew.options()),
      dim);
}
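
// Editor's sketch (not part of the original source): a skew round trip. A
// skew-symmetric matrix W reduces to its axial vector w = (W_21, W_02, W_10)
// and expands back to W. Shapes and torch::rand are illustrative assumptions.
static void
example_skew_round_trip()
{
  auto w = BatchTensor(torch::rand({4, 3}, default_tensor_options()), 1);
  auto W = skew_to_full(w, 0);    // base shape (3, 3), skew-symmetric
  auto back = full_to_skew(W, 0); // recovers w
}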

// Use automatic differentiation (AD) to calculate the derivatives w.r.t. the parameter
BatchTensor
jacrev(const BatchTensor & y, const BatchTensor & p)
{
  neml_assert(p.batch_sizes() == y.batch_sizes(),
              "The batch shape of the parameter must be the same as the batch shape "
              "of the output. However, the batch shape of the parameter is ",
              p.batch_sizes(),
              ", and the batch shape of the output is ",
              y.batch_sizes());

  // flatten y to handle arbitrarily shaped output
  auto yf = BatchTensor(
      y.reshape(utils::add_shapes(y.batch_sizes(), utils::storage_size(y.base_sizes()))),
      y.batch_dim());

  neml_assert_dbg(yf.base_dim() == 1, "Flattened output must be flat.");

  auto dyf_dp = BatchTensor::empty(
      yf.batch_sizes(), utils::add_shapes(yf.base_sizes(), p.base_sizes()), yf.options());

  for (TorchSize i = 0; i < yf.base_sizes()[0]; i++)
  {
    auto v = BatchTensor::zeros_like(yf);
    v.index_put_({torch::indexing::Ellipsis, i}, 1.0);
    const auto dyfi_dp = torch::autograd::grad({yf},
                                               {p},
                                               {v},
                                               /*retain_graph=*/true,
                                               /*create_graph=*/false,
                                               /*allow_unused=*/false)[0];
    if (dyfi_dp.defined())
      dyf_dp.base_index_put({i, torch::indexing::Ellipsis}, dyfi_dp);
  }

  // Reshape the derivative back to the correct shape
  const auto dy_dp = BatchTensor(
      dyf_dp.reshape(utils::add_shapes(y.batch_sizes(), y.base_sizes(), p.base_sizes())),
      y.batch_dim());

  return dy_dp;
}
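
// Editor's sketch (not part of the original source): jacrev on a simple
// elementwise function. The shapes and torch::rand input are assumptions;
// the parameter must track gradients for autograd to differentiate through it.
static void
example_jacrev()
{
  auto p = BatchTensor(torch::rand({10, 3}, default_tensor_options().requires_grad(true)), 1);
  auto y = BatchTensor(p * p, 1); // y_i = p_i^2 elementwise, base shape (3)
  auto dy_dp = jacrev(y, p);      // base shape (3, 3), equal to diag(2 p)
}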

BatchTensor
base_diag_embed(const BatchTensor & a, TorchSize offset, TorchSize d1, TorchSize d2)
{
  return BatchTensor(
      torch::diag_embed(
          a, offset, d1 < 0 ? d1 : d1 + a.batch_dim() + 1, d2 < 0 ? d2 : d2 + a.batch_dim() + 1),
      a.batch_dim());
}
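
// Editor's sketch (not part of the original source): embed the base dimension
// of a batched vector as a diagonal matrix. Negative d1/d2 already count from
// the end, so they pass through unchanged; non-negative ones are shifted past
// the batch dimensions by the branch above. Shapes here are assumptions.
static void
example_base_diag_embed()
{
  auto a = BatchTensor(torch::rand({4, 3}, default_tensor_options()), 1);
  auto D = base_diag_embed(a, /*offset=*/0, /*d1=*/-2, /*d2=*/-1); // base shape (3, 3)
}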

SR2
skew_and_sym_to_sym(const SR2 & e, const WR2 & w)
{
  // Compute the product w_ik e_kj - e_ik w_kj. In NEML we used an unrolled
  // form; a nice direct notation for this one was never found.
  auto E = R2(e);
  auto W = R2(w);
  return SR2(W * E - E * W);
}

// Derivative of w_ik e_kj - e_ik w_kj w.r.t. e
SSR4
d_skew_and_sym_to_sym_d_sym(const WR2 & w)
{
  auto I = R2::identity(w.options());
  auto W = R2(w);
  return SSR4(R4(torch::einsum("...ia,...jb->...ijab", {W, I}) -
                 torch::einsum("...ia,...bj->...ijab", {I, W})));
}

// Derivative of w_ik e_kj - e_ik w_kj w.r.t. w
SWR4
d_skew_and_sym_to_sym_d_skew(const SR2 & e)
{
  auto I = R2::identity(e.options());
  auto E = R2(e);
  return SWR4(R4(torch::einsum("...ia,...bj->...ijab", {I, E}) -
                 torch::einsum("...ia,...jb->...ijab", {E, I})));
}

// Shortcut product a_ik b_kj - b_ik a_kj with both a and b SR2
WR2
multiply_and_make_skew(const SR2 & a, const SR2 & b)
{
  auto A = R2(a);
  auto B = R2(b);

  return WR2(A * B - B * A);
}

// Derivative of a_ik b_kj - b_ik a_kj w.r.t. a
WSR4
d_multiply_and_make_skew_d_first(const SR2 & b)
{
  auto I = R2::identity(b.options());
  auto B = R2(b);
  return WSR4(R4(torch::einsum("...ia,...bj->...ijab", {I, B}) -
                 torch::einsum("...ia,...jb->...ijab", {B, I})));
}

// Derivative of a_ik b_kj - b_ik a_kj w.r.t. b
WSR4
d_multiply_and_make_skew_d_second(const SR2 & a)
{
  auto I = R2::identity(a.options());
  auto A = R2(a);
  return WSR4(R4(torch::einsum("...ia,...jb->...ijab", {A, I}) -
                 torch::einsum("...ia,...bj->...ijab", {I, A})));
}

namespace linalg
{
// Vector norm of a vector. Falls back to math::abs if v is a Scalar.
BatchTensor
vector_norm(const BatchTensor & v)
{
  neml_assert_dbg(v.base_dim() == 0 || v.base_dim() == 1,
                  "v in vector_norm has base dimension ",
                  v.base_dim(),
                  " instead of 0 or 1.");

  // If the vector is a logical scalar just return its absolute value
  if (v.base_dim() == 0)
    return math::abs(v);

  return BatchTensor(torch::linalg::vector_norm(
                         v, /*order=*/2, /*dim=*/-1, /*keepdim=*/false, /*dtype=*/c10::nullopt),
                     v.batch_dim());
}

// Inverse of a square matrix
BatchTensor
inv(const BatchTensor & m)
{
  return BatchTensor(torch::linalg::inv(m), m.batch_dim());
}

// Solve the linear system A X = B
BatchTensor
solve(const BatchTensor & A, const BatchTensor & B)
{
  return BatchTensor(torch::linalg::solve(A, B, /*left=*/true), A.batch_dim());
}
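
// Editor's note (not part of the original source): when the end goal is
// X = A^-1 B, solve(A, B) is generally preferable to inv(A) followed by a
// matrix product, for both performance and numerical stability.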

std::tuple<BatchTensor, BatchTensor>
lu_factor(const BatchTensor & A, bool pivot)
{
  auto [LU, pivots] = torch::linalg_lu_factor(A, pivot);
  return {BatchTensor(LU, A.batch_dim()), BatchTensor(pivots, A.batch_dim())};
}
351
354 const BatchTensor & pivots,
355 const BatchTensor & B,
356 bool left,
357 bool adjoint)
358{
359 return BatchTensor(torch::linalg_lu_solve(LU, pivots, B, left, adjoint), B.batch_dim());
360}
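
// Editor's sketch (not part of the original source): factor once, then reuse
// the factorization for multiple right-hand sides. Shapes and torch::rand are
// illustrative assumptions.
static void
example_lu()
{
  auto A = BatchTensor(torch::rand({8, 3, 3}, default_tensor_options()), 1);
  auto B = BatchTensor(torch::rand({8, 3, 1}, default_tensor_options()), 1);
  auto [LU, pivots] = lu_factor(A, /*pivot=*/true);
  auto X = lu_solve(LU, pivots, B, /*left=*/true, /*adjoint=*/false);
  // X satisfies A X = B, same as solve(A, B), but LU/pivots can be reused
}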
} // namespace linalg
} // namespace math
} // namespace neml2