25#include "neml2/misc/math.h"
26#include "neml2/tensors/R2Base.h"
27#include "neml2/tensors/R2.h"
28#include "neml2/tensors/Scalar.h"
29#include "neml2/tensors/Vec.h"
30#include "neml2/tensors/SR2.h"
31#include "neml2/tensors/R3.h"
32#include "neml2/tensors/R4.h"
33#include "neml2/tensors/Rot.h"
34#include "neml2/tensors/WR2.h"
38template <
class Derived>
45template <
class Derived>
49 auto zero = torch::zeros_like(a);
57template <
class Derived>
62 const torch::TensorOptions & options)
67template <
class Derived>
71 auto zero = torch::zeros_like(
a11);
79template <
class Derived>
87 const torch::TensorOptions & options)
97template <
class Derived>
113template <
class Derived>
124 const torch::TensorOptions & options)
137template <
class Derived>
156template <
class Derived>
160 const auto z = torch::zeros_like(
v(0));
161 return Derived(torch::stack({torch::stack({z, -
v(2),
v(1)}, -1),
162 torch::stack({
v(2), z, -
v(0)}, -1),
163 torch::stack({-
v(1),
v(0), z}, -1)},
168template <
class Derived>
172 return Derived(torch::eye(3, options), 0);
175template <
class Derived>
179 return rotate(
r.euler_rodrigues());
182template <
class Derived>
186 return R *
R2(*
this) *
R.transpose();
189template <
class Derived>
194 R3 F =
r.deuler_rodrigues();
196 return R3(torch::einsum(
"...itl,...tm,...jm", {
F, *
this,
R}) +
197 torch::einsum(
"...ik,...kt,...jtl", {
R, *
this,
F}),
201template <
class Derived>
206 return torch::einsum(
"...ik,...jl", {
I,
R * this->transpose()}) +
207 torch::einsum(
"...jk,...il", {
I,
R *
R2(*
this)});
210template <
class Derived>
217template <
class Derived>
221 return torch::Tensor::inverse();
224template <
class Derived>
231template <
class Derived1,
class Derived2,
typename,
typename>
239template <
class Derived1,
class Derived2,
typename,
typename>
250template class R2Base<R2>;
BatchTensor base_transpose(TorchSize d1, TorchSize d2) const
Transpose two base dimensions.
Definition BatchTensorBase.cxx:299
BatchTensor base_index(const TorchSlice &indices) const
Return an index sliced on the base dimensions.
Definition BatchTensorBase.cxx:193
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
Derived rotate(const Rot &r) const
Rotate using a Rodrigues vector.
Definition R2Base.cxx:177
R3 drotate(const Rot &r) const
Derivative of the rotated tensor w.r.t. the Rodrigues vector.
Definition R2Base.cxx:191
static Derived fill(const Real &a, const torch::TensorOptions &options=default_tensor_options())
Fill the diagonals with a11 = a22 = a33 = a.
Definition R2Base.cxx:40
static Derived identity(const torch::TensorOptions &options=default_tensor_options())
Identity.
Definition R2Base.cxx:170
Derived transpose() const
transpose
Definition R2Base.cxx:226
Derived inverse() const
Inversion.
Definition R2Base.cxx:219
Scalar operator()(TorchSize i, TorchSize j) const
Accessor.
Definition R2Base.cxx:212
static Derived skew(const Vec &v)
Skew matrix from Vec.
Definition R2Base.cxx:158
A basic R2.
Definition R2.h:42
The (logical) full third order tensor.
Definition R3.h:41
The (logical) full fourth order tensor.
Definition R4.h:43
Rotation stored as modified Rodrigues parameters.
Definition Rot.h:49
The (logical) scalar.
Definition Scalar.h:38
The (logical) vector.
Definition Vec.h:42
Definition CrossRef.cxx:32
BatchTensor operator*(const BatchTensor &a, const BatchTensor &b)
Definition BatchTensor.cxx:153
int64_t TorchSize
Definition types.h:35
void neml_assert_batch_broadcastable_dbg(T &&...)
A helper function to assert that (in Debug mode) all tensors are batch-broadcastable.
TorchSize broadcast_batch_dim(T &&...)
The batch dimension after broadcasting.
double Real
Definition types.h:33
void neml_assert_broadcastable_dbg(T &&...)
A helper function to assert (in Debug mode) that all tensors are broadcastable.