27#include "neml2/tensors/Tensor.h"
38constexpr Real eps = std::numeric_limits<at::scalar_value_type<Real>::type>::epsilon();
44constexpr Size mandel_index[6][2] = {{0, 0}, {1, 1}, {2, 2}, {1, 2}, {0, 2}, {0, 1}};
47constexpr Real skew_factor[3][3] = {{0.0, -1.0, 1.0}, {1.0, 0.0, -1.0}, {-1.0, 1.0, 0.0}};
52 return i < 3 ? 1.0 :
sqrt2;
80 torch::Tensor _full_to_mandel_map;
81 torch::Tensor _mandel_to_full_map;
82 torch::Tensor _full_to_mandel_factor;
83 torch::Tensor _mandel_to_full_factor;
84 torch::Tensor _full_to_skew_map;
85 torch::Tensor _skew_to_full_map;
86 torch::Tensor _full_to_skew_factor;
87 torch::Tensor _skew_to_full_factor;
112 const torch::Tensor &
rmap,
128 const torch::Tensor &
rmap,
232template <
class T,
typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
238 auto d2 = d >= 0 ? d : d -
tensors.begin()->base_dim();
242template <
class T,
typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
248 auto d2 = d < 0 ? d : d +
tensors.begin()->batch_dim();
252template <
class T,
typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
258 auto d2 = d >= 0 ? d : d -
tensors.begin()->base_dim();
262template <
class T,
typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
268 auto d2 = d < 0 ? d : d +
tensors.begin()->batch_dim();
272template <
class T,
typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
276 neml_assert_dbg(a.batch_dim() > 0,
"Must have a batch dimension to sum along");
277 auto d2 = d >= 0 ? d : d - a.base_dim();
278 return T(torch::sum(a,
d2), a.batch_dim() - 1);
281template <
class T,
typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
285 auto d2 = d < 0 ? d : d + a.batch_dim();
286 return T(torch::sum(a,
d2), a.batch_dim());
289template <
class T,
typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
293 return T(torch::pow(a,
n), a.batch_dim());
300template <
class T,
typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
304 return T(torch::sign(a), a.batch_dim());
307template <
class T,
typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
311 return T(torch::cosh(a), a.batch_dim());
314template <
class T,
typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
318 return T(torch::sinh(a), a.batch_dim());
321template <
class T,
typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
325 return T(torch::tanh(a), a.batch_dim());
328template <
class T,
typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
342template <
class T,
typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
346 return (
sign(a) + 1.0) / 2.0;
349template <
class T,
typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
353 return T(torch::Tensor(a) * torch::Tensor(
heaviside(a)), a.batch_dim());
356template <
class T,
typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
363template <
class T,
typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
367 return T(torch::sqrt(a), a.batch_dim());
370template <
class T,
typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
374 return T(torch::exp(a), a.batch_dim());
377template <
class T,
typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
381 return T(torch::abs(a), a.batch_dim());
384template <
class T,
typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
388 return T(torch::diff(a,
n, dim), a.batch_dim());
391template <
class T,
typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
395 return T(torch::diag_embed(
400template <
class T,
typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
404 return T(torch::log(a), a.batch_dim());
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:56
The (logical) symmetric second order tensor.
Definition SR2.h:46
The (logical) symmetric fourth order tensor, with symmetry in the first two dimensions as well as in...
Definition SSR4.h:44
The (logical) symmetric fourth order tensor, with symmetry in the first two dimensions and skew-symm...
Definition SWR4.h:40
A skew rank 2, represented as an axial vector.
Definition WR2.h:43
The (logical) symmetric fourth order tensor, with skew symmetry in the first two dimensions and symm...
Definition WSR4.h:40
Tensor solve(const Tensor &A, const Tensor &B)
Solve the linear system A X = B.
Definition math.cxx:347
Tensor vector_norm(const Tensor &v)
Vector norm of a vector. Falls back to math::abs if v is a Scalar.
Definition math.cxx:324
Tensor lu_solve(const Tensor &LU, const Tensor &pivots, const Tensor &B, bool left, bool adjoint)
Definition math.cxx:360
Tensor inv(const Tensor &m)
Inverse of a square matrix.
Definition math.cxx:341
std::tuple< Tensor, Tensor > lu_factor(const Tensor &A, bool pivot)
Definition math.cxx:353
Tensor full_to_reduced(const Tensor &full, const torch::Tensor &rmap, const torch::Tensor &rfactors, Size dim)
Generic function to reduce two axes to one with some map.
Definition math.cxx:113
T dmacaulay(const T &a)
Definition math.h:358
Tensor reduced_to_full(const Tensor &reduced, const torch::Tensor &rmap, const torch::Tensor &rfactors, Size dim)
Convert a Tensor from reduced notation to full notation.
Definition math.cxx:139
constexpr Real skew_factor[3][3]
Definition math.h:47
T batch_sum(const T &a, Size d=0)
Definition math.h:274
constexpr Size mandel_index[6][2]
Definition math.h:44
T cosh(const T &a)
Definition math.h:309
constexpr Real eps
Definition math.h:38
SWR4 d_skew_and_sym_to_sym_d_skew(const SR2 &e)
Derivative of w_ik e_kj - e_ik w_kj wrt. w.
Definition math.cxx:273
SSR4 d_skew_and_sym_to_sym_d_sym(const WR2 &w)
Derivative of w_ik e_kj - e_ik w_kj wrt. e.
Definition math.cxx:264
constexpr Size skew_reverse_index[3][3]
Definition math.h:46
T heaviside(const T &a)
Definition math.h:344
Tensor full_to_skew(const Tensor &full, Size dim)
Convert a Tensor from full notation to skew vector notation.
Definition math.cxx:182
constexpr Real invsqrt2
Definition math.h:41
Tensor mandel_to_full(const Tensor &mandel, Size dim)
Convert a Tensor from Mandel notation to full notation.
Definition math.cxx:172
Tensor base_diag_embed(const Tensor &a, Size offset, Size d1, Size d2)
Definition math.cxx:245
T exp(const T &a)
Definition math.h:372
neml2::Tensor base_cat(const std::vector< T > &tensors, Size d=0)
Definition math.h:244
T log(const T &a)
Definition math.h:402
WSR4 d_multiply_and_make_skew_d_first(const SR2 &b)
Derivative of a_ik b_kj - b_ik a_kj wrt a.
Definition math.cxx:291
SR2 skew_and_sym_to_sym(const SR2 &e, const WR2 &w)
Product w_ik e_kj - e_ik w_kj with e SR2 and w WR2.
Definition math.cxx:254
T sinh(const T &a)
Definition math.h:316
T tanh(const T &a)
Definition math.h:323
constexpr Size mandel_reverse_index[3][3]
Definition math.h:43
WR2 multiply_and_make_skew(const SR2 &a, const SR2 &b)
Shortcut product a_ik b_kj - b_ik a_kj with both SR2.
Definition math.cxx:282
T batch_stack(const std::vector< T > &tensors, Size d=0)
Definition math.h:254
T sqrt(const T &a)
Definition math.h:365
constexpr Real mandel_factor(Size i)
Definition math.h:50
neml2::Tensor base_stack(const std::vector< T > &tensors, Size d=0)
Definition math.h:264
T diff(const T &a, Size n=1, Size dim=-1)
Definition math.h:386
T batch_diag_embed(const T &a, Size offset=0, Size d1=-2, Size d2=-1)
Definition math.h:393
T abs(const T &a)
Definition math.h:379
T batch_cat(const std::vector< T > &tensors, Size d=0)
Definition math.h:234
WSR4 d_multiply_and_make_skew_d_second(const SR2 &a)
Derivative of a_ik b_kj - b_ik a_kj wrt b.
Definition math.cxx:300
T base_sum(const T &a, Size d=0)
Definition math.h:283
Tensor pow(const Real &a, const Tensor &n)
Definition math.cxx:309
Tensor skew_to_full(const Tensor &skew, Size dim)
Convert a Tensor from skew vector notation to full notation.
Definition math.cxx:192
T where(const torch::Tensor &condition, const T &a, const T &b)
Definition math.h:330
Tensor full_to_mandel(const Tensor &full, Size dim)
Convert a Tensor from full notation to Mandel notation.
Definition math.cxx:162
T macaulay(const T &a)
Definition math.h:351
T sign(const T &a)
Definition math.h:302
constexpr Real sqrt2
Definition math.h:40
Tensor jacrev(const Tensor &y, const Tensor &p)
Use automatic differentiation (AD) to calculate the derivatives with respect to the parameter.
Definition math.cxx:202
Definition CrossRef.cxx:30
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:76
Size broadcast_batch_dim(T &&...)
The batch dimension after broadcasting.
double Real
Definition types.h:31
int64_t Size
Definition types.h:33
void neml_assert_broadcastable_dbg(T &&...)
A helper function to assert (in Debug mode) that all tensors are broadcastable.
A helper class to hold static data of type torch::Tensor.
Definition math.h:64
static const torch::Tensor & skew_to_full_map()
Definition math.cxx:95
static const torch::Tensor & full_to_mandel_factor()
Definition math.cxx:77
static const torch::Tensor & full_to_skew_factor()
Definition math.cxx:101
static const torch::Tensor & full_to_skew_map()
Definition math.cxx:89
ConstantTensors()
Definition math.cxx:36
static ConstantTensors & get()
Definition math.cxx:58
static const torch::Tensor & mandel_to_full_factor()
Definition math.cxx:83
static const torch::Tensor & mandel_to_full_map()
Definition math.cxx:71
static const torch::Tensor & full_to_mandel_map()
Definition math.cxx:65
static const torch::Tensor & skew_to_full_factor()
Definition math.cxx:107