25#include "neml2/misc/math.h"
26#include "neml2/misc/error.h"
27#include "neml2/tensors/tensors.h"
40 _full_to_mandel_factor =
43 _mandel_to_full_factor =
53 _skew_to_full_factor =
67 return get()._full_to_mandel_map;
73 return get()._mandel_to_full_map;
79 return get()._full_to_mandel_factor;
85 return get()._mandel_to_full_factor;
91 return get()._full_to_skew_map;
97 return get()._skew_to_full_map;
103 return get()._full_to_skew_factor;
109 return get()._skew_to_full_factor;
114 const torch::Tensor &
rmap,
118 using namespace torch::indexing;
142 const torch::Tensor &
rmap,
146 using namespace torch::indexing;
148 auto batch_dim =
reduced.batch_dim();
209 "The batch shape of the parameter must be the same as the batch shape "
210 "of the output. However, the batch shape of the parameter is ",
212 ", and the batch shape of the output is ",
228 v.index_put_({torch::indexing::Ellipsis,
i}, 1.0);
229 const auto dyfi_dp = torch::autograd::grad({
yf},
252 a,
offset,
d1 < 0 ?
d1 :
d1 + a.batch_dim() + 1,
d2 < 0 ?
d2 :
d2 + a.batch_dim() + 1),
263 return SR2(W *
E -
E * W);
271 return SSR4(
R4(torch::einsum(
"...ia,...jb->...ijab", {W,
I}) -
272 torch::einsum(
"...ia,...bj->...ijab", {
I, W})));
280 return SWR4(
R4(torch::einsum(
"...ia,...bj->...ijab", {
I,
E}) -
281 torch::einsum(
"...ia,...jb->...ijab", {
E,
I})));
290 return WR2(A *
B -
B * A);
298 return WSR4(
R4(torch::einsum(
"...ia,...bj->...ijab", {
I,
B}) -
299 torch::einsum(
"...ia,...jb->...ijab", {
B,
I})));
307 return WSR4(
R4(torch::einsum(
"...ia,...jb->...ijab", {A,
I}) -
308 torch::einsum(
"...ia,...bj->...ijab", {
I, A})));
317 "v in vector_norm has base dimension ",
319 " instead of 0 or 1.");
322 if (
v.base_dim() == 0)
326 v, 2, -1,
false, c10::nullopt),
336std::tuple<BatchTensor, BatchTensor>
TorchSize batch_dim() const
Return the number of batch dimensions.
Definition BatchTensorBase.cxx:128
static BatchTensor zeros_like(const BatchTensor &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition BatchTensorBase.cxx:59
Definition BatchTensor.h:32
static BatchTensor empty(const TorchShapeRef &base_shape, const torch::TensorOptions &options=default_tensor_options())
Unbatched empty tensor given base shape.
Definition BatchTensor.cxx:30
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
static R2 identity(const torch::TensorOptions &options=default_tensor_options())
Identity.
Definition R2Base.cxx:170
A basic R2.
Definition R2.h:42
The (logical) full fourth order tensor.
Definition R4.h:43
The (logical) symmetric second order tensor.
Definition SR2.h:46
The (logical) symmetric fourth order tensor, with symmetry in the first two dimensions as well as in...
Definition SSR4.h:44
The (logical) symmetric fourth order tensor, with symmetry in the first two dimensions and skew-symm...
Definition SWR4.h:40
A skew rank 2, represented as an axial vector.
Definition WR2.h:43
The (logical) symmetric fourth order tensor, with skew symmetry in the first two dimensions and symm...
Definition WSR4.h:40
BatchTensor solve(const BatchTensor &A, const BatchTensor &B)
Solve the linear system A X = B.
Definition math.cxx:331
std::tuple< BatchTensor, BatchTensor > lu_factor(const BatchTensor &A, bool pivot)
Definition math.cxx:337
BatchTensor vector_norm(const BatchTensor &v)
Vector norm of a vector. Falls back to math::abs if v is a Scalar.
Definition math.cxx:314
BatchTensor lu_solve(const BatchTensor &LU, const BatchTensor &pivots, const BatchTensor &B, bool left, bool adjoint)
Definition math.cxx:344
SWR4 d_skew_and_sym_to_sym_d_skew(const SR2 &e)
Derivative of w_ik e_kj - e_ik w_kj wrt. w.
Definition math.cxx:276
SSR4 d_skew_and_sym_to_sym_d_sym(const WR2 &w)
Derivative of w_ik e_kj - e_ik w_kj wrt. e.
Definition math.cxx:267
BatchTensor mandel_to_full(const BatchTensor &mandel, TorchSize dim)
Convert a BatchTensor from Mandel notation to full notation.
Definition math.cxx:176
constexpr Real invsqrt2
Definition math.h:41
BatchTensor skew_to_full(const BatchTensor &skew, TorchSize dim)
Convert a BatchTensor from skew vector notation to full notation.
Definition math.cxx:196
BatchTensor base_diag_embed(const BatchTensor &a, TorchSize offset, TorchSize d1, TorchSize d2)
Definition math.cxx:248
BatchTensor reduced_to_full(const BatchTensor &reduced, const torch::Tensor &rmap, const torch::Tensor &rfactors, TorchSize dim)
Convert a BatchTensor from reduced notation to full notation.
Definition math.cxx:141
WSR4 d_multiply_and_make_skew_d_first(const SR2 &b)
Derivative of a_ik b_kj - b_ik a_kj wrt a.
Definition math.cxx:294
SR2 skew_and_sym_to_sym(const SR2 &e, const WR2 &w)
Product w_ik e_kj - e_ik w_kj with e SR2 and w WR2.
Definition math.cxx:257
WR2 multiply_and_make_skew(const SR2 &a, const SR2 &b)
Shortcut product a_ik b_kj - b_ik a_kj with both SR2.
Definition math.cxx:285
BatchTensor full_to_reduced(const BatchTensor &full, const torch::Tensor &rmap, const torch::Tensor &rfactors, TorchSize dim)
Generic function to reduce two axes to one with some map.
Definition math.cxx:113
WSR4 d_multiply_and_make_skew_d_second(const SR2 &a)
Derivative of a_ik b_kj - b_ik a_kj wrt b.
Definition math.cxx:303
Derived abs(const Derived &a)
Definition BatchTensorBase.h:457
BatchTensor full_to_mandel(const BatchTensor &full, TorchSize dim)
Convert a BatchTensor from full notation to Mandel notation.
Definition math.cxx:166
BatchTensor jacrev(const BatchTensor &y, const BatchTensor &p)
Use automatic differentiation (AD) to calculate the derivatives w.r.t. the parameter.
Definition math.cxx:206
constexpr Real sqrt2
Definition math.h:40
BatchTensor full_to_skew(const BatchTensor &full, TorchSize dim)
Convert a BatchTensor from full notation to skew vector notation.
Definition math.cxx:186
TorchSize storage_size(TorchShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:32
TorchShape add_shapes(S &&... shape)
Definition utils.h:294
Definition CrossRef.cxx:32
const torch::TensorOptions default_tensor_options()
Definition types.cxx:30
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:85
int64_t TorchSize
Definition types.h:33
const torch::TensorOptions default_integer_tensor_options()
We similarly want to have a default integer scalar type for some types of tensors.
Definition types.cxx:42
std::vector< at::indexing::TensorIndex > TorchSlice
Definition types.h:37
void neml_assert(bool assertion, Args &&... args)
Definition error.h:73
A helper class to hold static data of type torch::Tensor.
Definition math.h:64
static const torch::Tensor & skew_to_full_map()
Definition math.cxx:95
static const torch::Tensor & full_to_mandel_factor()
Definition math.cxx:77
static const torch::Tensor & full_to_skew_factor()
Definition math.cxx:101
static const torch::Tensor & full_to_skew_map()
Definition math.cxx:89
ConstantTensors()
Definition math.cxx:33
static ConstantTensors & get()
Definition math.cxx:58
static const torch::Tensor & mandel_to_full_factor()
Definition math.cxx:83
static const torch::Tensor & mandel_to_full_map()
Definition math.cxx:71
static const torch::Tensor & full_to_mandel_map()
Definition math.cxx:65
static const torch::Tensor & skew_to_full_factor()
Definition math.cxx:107