NEML2 1.4.0
Loading...
Searching...
No Matches
R2Base.cxx
1// Copyright 2023, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#include "neml2/misc/math.h"
26#include "neml2/tensors/R2Base.h"
27#include "neml2/tensors/R2.h"
28#include "neml2/tensors/Scalar.h"
29#include "neml2/tensors/Vec.h"
30#include "neml2/tensors/SR2.h"
31#include "neml2/tensors/R3.h"
32#include "neml2/tensors/R4.h"
33#include "neml2/tensors/Rot.h"
34#include "neml2/tensors/WR2.h"
35
36namespace neml2
37{
38template <class Derived>
39Derived
40R2Base<Derived>::fill(const Real & a, const torch::TensorOptions & options)
41{
42 return R2Base<Derived>::fill(Scalar(a, options));
43}
44
45template <class Derived>
48{
49 auto zero = torch::zeros_like(a);
50 return Derived(torch::stack({torch::stack({a, zero, zero}, -1),
51 torch::stack({zero, a, zero}, -1),
52 torch::stack({zero, zero, a}, -1)},
53 -2),
54 a.batch_dim());
55}
56
57template <class Derived>
60 const Real & a22,
61 const Real & a33,
62 const torch::TensorOptions & options)
64 return R2Base<Derived>::fill(Scalar(a11, options), Scalar(a22, options), Scalar(a33, options));
65}
66
67template <class Derived>
71 auto zero = torch::zeros_like(a11);
72 return Derived(torch::stack({torch::stack({a11, zero, zero}, -1),
73 torch::stack({zero, a22, zero}, -1),
74 torch::stack({zero, zero, a33}, -1)},
75 -2),
76 a11.batch_dim());
78
79template <class Derived>
82 const Real & a22,
83 const Real & a33,
84 const Real & a23,
85 const Real & a13,
86 const Real & a12,
87 const torch::TensorOptions & options)
88{
89 return R2Base<Derived>::fill(Scalar(a11, options),
90 Scalar(a22, options),
91 Scalar(a33, options),
92 Scalar(a23, options),
93 Scalar(a13, options),
94 Scalar(a12, options));
96
97template <class Derived>
100 const Scalar & a22,
101 const Scalar & a33,
102 const Scalar & a23,
103 const Scalar & a13,
104 const Scalar & a12)
106 return Derived(torch::stack({torch::stack({a11, a12, a13}, -1),
107 torch::stack({a12, a22, a23}, -1),
108 torch::stack({a13, a23, a33}, -1)},
109 -2),
110 a11.batch_dim());
112
113template <class Derived>
116 const Real & a12,
117 const Real & a13,
118 const Real & a21,
119 const Real & a22,
120 const Real & a23,
121 const Real & a31,
122 const Real & a32,
123 const Real & a33,
124 const torch::TensorOptions & options)
125{
127 Scalar(a12, options),
128 Scalar(a13, options),
129 Scalar(a21, options),
130 Scalar(a22, options),
131 Scalar(a23, options),
132 Scalar(a31, options),
133 Scalar(a32, options),
134 Scalar(a33, options));
135}
136
137template <class Derived>
140 const Scalar & a12,
141 const Scalar & a13,
142 const Scalar & a21,
143 const Scalar & a22,
144 const Scalar & a23,
145 const Scalar & a31,
146 const Scalar & a32,
147 const Scalar & a33)
148{
149 return Derived(torch::stack({torch::stack({a11, a12, a13}, -1),
150 torch::stack({a21, a22, a23}, -1),
151 torch::stack({a31, a32, a33}, -1)},
152 -2),
153 a11.batch_dim());
154}
155
156template <class Derived>
159{
160 const auto z = torch::zeros_like(v(0));
161 return Derived(torch::stack({torch::stack({z, -v(2), v(1)}, -1),
162 torch::stack({v(2), z, -v(0)}, -1),
163 torch::stack({-v(1), v(0), z}, -1)},
164 -2),
165 v.batch_dim());
166}
167
168template <class Derived>
170R2Base<Derived>::identity(const torch::TensorOptions & options)
171{
172 return Derived(torch::eye(3, options), 0);
173}
174
175template <class Derived>
178{
179 return rotate(r.euler_rodrigues());
180}
181
182template <class Derived>
185{
186 return R * R2(*this) * R.transpose();
187}
188
189template <class Derived>
190R3
192{
193 Derived R = r.euler_rodrigues();
194 R3 F = r.deuler_rodrigues();
195
196 return R3(torch::einsum("...itl,...tm,...jm", {F, *this, R}) +
197 torch::einsum("...ik,...kt,...jtl", {R, *this, F}),
198 broadcast_batch_dim(*this, F, R));
199}
200
201template <class Derived>
202R4
204{
205 auto I = R2::identity(R.options());
206 return torch::einsum("...ik,...jl", {I, R * this->transpose()}) +
207 torch::einsum("...jk,...il", {I, R * R2(*this)});
208}
209
210template <class Derived>
211Scalar
216
217template <class Derived>
220{
221 return torch::Tensor::inverse();
222}
223
224template <class Derived>
230
231template <class Derived1, class Derived2, typename, typename>
232Vec
233operator*(const Derived1 & A, const Derived2 & b)
234{
236 return Vec(torch::einsum("...ik,...k", {A, b}), broadcast_batch_dim(A, b));
237}
238
239template <class Derived1, class Derived2, typename, typename>
240R2
241operator*(const Derived1 & A, const Derived2 & B)
242{
244 return R2(torch::einsum("...ik,...kj", {A, B}), broadcast_batch_dim(A, B));
245}
246
247// template instantiation
248
249// derived classes
250template class R2Base<R2>;
251
252// products
253template Vec operator*(const R2 & A, const Vec & b);
254template R2 operator*(const R2 & A, const R2 & B);
255} // namespace neml2
BatchTensor base_transpose(TorchSize d1, TorchSize d2) const
Transpose two base dimensions.
Definition BatchTensorBase.cxx:299
BatchTensor base_index(const TorchSlice &indices) const
Return an index sliced on the base dimensions.
Definition BatchTensorBase.cxx:193
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
CrossRef()=default
Derived rotate(const Rot &r) const
Rotate using a Rodrigues vector.
Definition R2Base.cxx:177
R3 drotate(const Rot &r) const
Derivative of the rotated tensor w.r.t. the Rodrigues vector.
Definition R2Base.cxx:191
static Derived fill(const Real &a, const torch::TensorOptions &options=default_tensor_options())
Fill the diagonal with a11 = a22 = a33 = a.
Definition R2Base.cxx:40
static Derived identity(const torch::TensorOptions &options=default_tensor_options())
Identity tensor.
Definition R2Base.cxx:170
Derived transpose() const
Transpose.
Definition R2Base.cxx:226
Derived inverse() const
Inversion.
Definition R2Base.cxx:219
Scalar operator()(TorchSize i, TorchSize j) const
Accessor for the (i, j) component.
Definition R2Base.cxx:212
static Derived skew(const Vec &v)
Skew-symmetric matrix from a Vec.
Definition R2Base.cxx:158
A basic R2 (full second-order tensor).
Definition R2.h:42
The (logical) full third-order tensor.
Definition R3.h:41
The (logical) full fourth-order tensor.
Definition R4.h:43
Rotation stored as modified Rodrigues parameters.
Definition Rot.h:49
The (logical) scalar.
Definition Scalar.h:38
The (logical) vector.
Definition Vec.h:42
Definition CrossRef.cxx:32
BatchTensor operator*(const BatchTensor &a, const BatchTensor &b)
Definition BatchTensor.cxx:153
int64_t TorchSize
Definition types.h:35
void neml_assert_batch_broadcastable_dbg(T &&...)
A helper function to assert that (in Debug mode) all tensors are batch-broadcastable.
TorchSize broadcast_batch_dim(T &&...)
The batch dimension after broadcasting.
double Real
Definition types.h:33
void neml_assert_broadcastable_dbg(T &&...)
A helper function to assert (in Debug mode) that all tensors are broadcastable.