25#include "neml2/tensors/user_tensors/Orientation.h"
27#include "neml2/tensors/Quaternion.h"
29using namespace torch::indexing;
// Interior of the option-declaration routine for Orientation. Each option is
// given a default value via options.set<T>(name) = value and a docstring via
// options.set(name).doc(). The leading numbers on the lines below are line
// numbers carried over from the rendered source listing, not code.
// NOTE(review): the enclosing function signature and braces are not visible in
// this listing.
40 options.
doc() =
"An orientation, internally defined as a set of Modified Rodrigues parameters "
41 "given by \\f$ r = n \\tan{\\frac{\\theta}{4}} \\f$ with \\f$ n \\f$ the axis of "
42 "rotation and \\f$ \\theta \\f$ the rotation angle about that axis. However, "
43 "this class provides a variety of ways to define the orientation in terms of "
44 "other, more common representations.";
// "input_type": selects how fill() builds the rotations; defaults to Euler angles.
46 options.
set<std::string>(
"input_type") =
"euler_angles";
47 options.
set(
"input_type").doc() =
48 "The method used to define the angles, 'euler_angles' or 'random'";
// "angle_convention": Euler angle convention. The default is lowercase "kocks";
// fill_euler_angles() compares against lowercase "bunge" and "roe".
50 options.
set<std::string>(
"angle_convention") =
"kocks";
51 options.
set(
"angle_convention").doc() =
"Euler angle convention, 'Kocks', 'Roe', or 'Bunge'";
// "angle_type": unit of the input angles; defaults to degrees.
53 options.
set<std::string>(
"angle_type") =
"degrees";
54 options.
set(
"angle_type").doc() =
"Type of angles, either 'degrees' or 'radians'";
// "values": flat list of angles; empty by default. fill_euler_angles() reshapes
// it to rows of three.
56 options.
set<std::vector<Real>>(
"values") = {};
57 options.
set(
"values").doc() =
"Input Euler angles, as a flattened n-by-3 matrix";
// "normalize": off by default; when true fill() swaps in the shadow parameters
// wherever the squared norm is not below 1.
59 options.
set<
bool>(
"normalize") =
false;
60 options.
set(
"normalize").doc() =
61 "If true do a shadow parameter replacement of the underlying MRP representation to move the "
62 "inputs farther away from the singularity";
// NOTE(review): only the docstring for "random_seed" appears in this listing;
// the statement giving it a type and default value is not visible here. fill()
// reads it as TorchSize and fill_random() only seeds when it is >= 0 -- confirm
// the default against the real file.
65 options.
set(
"random_seed").doc() =
"Random seed for random angle generation";
// "quantity": batch size; defaults to 1. Used by both input types in fill().
67 options.
set<
unsigned int>(
"quantity") = 1;
68 options.
set(
"quantity").doc() =
"Number (batch size) of random orientations";
// Build the rotation(s) requested by the given options.
//
// Reads "input_type" and dispatches:
//   - Euler-angle input: fill_euler_angles() on "values", then expanded to
//     "quantity" via expand_as_needed().
//   - "random": fill_random() with "quantity" and "random_seed".
//   - anything else: throws NEMLException.
// When "normalize" is true, the result is post-processed with a shadow
// parameter replacement (see the last statement).
//
// NOTE(review): this listing omits several lines of the function (return type,
// braces, the declaration of R, the first `if` condition, and the final return
// for the non-normalized case) -- consult the real file before editing.
80Orientation::fill(
const OptionSet & options)
const
82 std::string
input_type = options.
get<std::string>(
"input_type");
// Euler-angle path (presumably guarded by input_type == "euler_angles"; the
// condition line is not visible here). "values" is converted to a tensor and
// the resulting rotations are batch-expanded to "quantity".
87 R = expand_as_needed(fill_euler_angles(torch::tensor(options.
get<std::vector<Real>>(
"values"),
89 options.
get<std::string>(
"angle_convention"),
90 options.
get<std::string>(
"angle_type")),
91 options.
get<
unsigned int>(
"quantity"));
// Random path: "quantity" rotations, seeded by "random_seed".
93 else if (input_type ==
"random")
95 R = fill_random(options.
get<
unsigned int>(
"quantity"), options.
get<
TorchSize>(
"random_seed"));
// Any other input_type is rejected.
98 throw NEMLException(
"Unknown Orientation input_type " + input_type);
// Optional normalization: keep R where its squared norm is below 1, otherwise
// use R.shadow(). The condition is unsqueezed on the last axis so it broadcasts
// over the parameter components.
100 if (options.
get<
bool>(
"normalize"))
101 return math::where((R.norm_sq() < 1.0).unsqueeze(-1), R, R.shadow());
// Convert a flat tensor of Euler angles into rotations.
//
// vals             : angle values; the total element count must be a multiple of 3.
// angle_convention : "bunge" and "roe" get an angle remapping below; the visible
//                    error message shows other unknown names are rejected.
// angle_type       : "degrees" values are converted to radians; the visible
//                    assertion message says only 'degrees' or 'radians' are valid.
//
// The angles are reshaped to rows of three, remapped according to the
// convention, turned into one 3x3 matrix per row, and handed to fill_matrix().
//
// NOTE(review): this listing omits lines of the function (return type, braces,
// the heads of the assertion calls, and the else branches) -- consult the real
// file before editing.
107Orientation::fill_euler_angles(
const torch::Tensor & vals,
108 std::string angle_convention,
109 std::string angle_type)
const
// Argument check: number of values must be divisible by 3.
// NOTE(review): the message below spells "divisable"; it is a runtime string, so
// it is left untouched here.
112 (torch::numel(vals) % 3) == 0,
113 "Orientation input values should have length divisable by 3 for input type 'euler_angles'");
// One row per orientation, three angles per row.
114 auto ten = vals.reshape({-1, 3});
116 if (angle_type ==
"degrees")
117 ten = torch::deg2rad(ten);
120 "Orientation angle_type must be either 'degrees' or 'radians'");
// Convention handling: Bunge and Roe angles are remapped in place before the
// shared matrix formula is applied (presumably converting them to the Kocks
// convention, which is the default -- confirm).
122 if (angle_convention ==
"bunge")
124 ten.index_put_({Ellipsis, 0}, torch::fmod(ten.index({Ellipsis, 0}) - M_PI / 2.0, 2.0 * M_PI));
125 ten.index_put_({Ellipsis, 1}, torch::fmod(ten.index({Ellipsis, 1}), M_PI));
126 ten.index_put_({Ellipsis, 2}, torch::fmod(M_PI / 2.0 - ten.index({Ellipsis, 2}), 2.0 * M_PI));
128 else if (angle_convention ==
"roe")
130 ten.index_put_({Ellipsis, 2}, M_PI - ten.index({Ellipsis, 2}));
134 "Unknown Orientation angle_convention " + angle_convention);
// Assemble an (n, 3, 3) tensor of matrices from the three angle columns
// a, b, c, using the same tensor options as the input.
137 auto M = torch::zeros({ten.sizes()[0], 3, 3}, vals.options());
138 auto a = ten.index({Ellipsis, 0});
139 auto b = ten.index({Ellipsis, 1});
140 auto c = ten.index({Ellipsis, 2});
142 M.index_put_({Ellipsis, 0, 0},
143 -torch::sin(c) * torch::sin(a) - torch::cos(c) * torch::cos(a) * torch::cos(b));
144 M.index_put_({Ellipsis, 0, 1},
145 torch::sin(c) * torch::cos(a) - torch::cos(c) * torch::sin(a) * torch::cos(b));
146 M.index_put_({Ellipsis, 0, 2}, torch::cos(c) * torch::sin(b));
147 M.index_put_({Ellipsis, 1, 0},
148 torch::cos(c) * torch::sin(a) - torch::sin(c) * torch::cos(a) * torch::cos(b));
149 M.index_put_({Ellipsis, 1, 1},
150 -torch::cos(c) * torch::cos(a) - torch::sin(c) * torch::sin(a) * torch::cos(b));
151 M.index_put_({Ellipsis, 1, 2}, torch::sin(c) * torch::sin(b));
152 M.index_put_({Ellipsis, 2, 0}, torch::cos(a) * torch::sin(b));
153 M.index_put_({Ellipsis, 2, 1}, torch::sin(a) * torch::sin(b));
154 M.index_put_({Ellipsis, 2, 2}, torch::cos(b));
// Wrap as R2 (second argument presumably the batch dimension count -- confirm)
// and convert the matrices to the internal rotation representation.
157 return fill_matrix(R2(M, 1));
// Convert a batch of 3x3 matrices into rotations by way of Rodrigues-style
// components, which are then passed to fill_rodrigues().
//
// M : batch of 3x3 matrices, indexed as M[..., row, col].
//
// NOTE(review): the return type and braces are not visible in this listing.
161Orientation::fill_matrix(
const R2 & M)
const
// Rotation angle from the trace: theta = acos((tr(M) - 1) / 2).
// NOTE(review): the acos argument is not clamped, so round-off that pushes it
// slightly outside [-1, 1] would produce NaN.
164 auto trace = M.index({Ellipsis, 0, 0}) + M.index({Ellipsis, 1, 1}) + M.index({Ellipsis, 2, 2});
165 auto theta = torch::acos((trace - 1.0) / 2.0);
// Factor applied to the antisymmetric differences below. For a rotation matrix
// those differences are proportional to sin(theta) times the axis, so the
// product presumably yields axis * tan(theta / 2) -- confirm sign convention.
// At theta == 0 the expression is 0/0, so those entries are overwritten with 0.
// NOTE(review): theta == pi (sin(theta) == 0, tan(theta / 2) unbounded) is not
// guarded here.
168 auto scale = torch::tan(theta / 2.0) / (2.0 * torch::sin(theta));
169 scale.index_put_({theta == 0}, 0.0);
170 auto rx = (M.index({Ellipsis, 2, 1}) - M.index({Ellipsis, 1, 2})) * scale;
171 auto ry = (M.index({Ellipsis, 0, 2}) - M.index({Ellipsis, 2, 0})) * scale;
172 auto rz = (M.index({Ellipsis, 1, 0}) - M.index({Ellipsis, 0, 1})) * scale;
174 return fill_rodrigues(rx, ry, rz);
// Build a Rot from three component scalars (rx, ry, rz) by dividing each by
// f = sqrt(rx^2 + ry^2 + rz^2 + 1) + 1.
//
// If the inputs are Rodrigues components n * tan(theta / 2), this division gives
// n * tan(theta / 4), since tan(t / 2) = tan(t) / (1 + sqrt(1 + tan(t)^2)) for
// |t| < pi / 2. That matches the Modified Rodrigues form described in the class
// docstring.
//
// NOTE(review): the return type and braces are not visible in this listing.
178Orientation::fill_rodrigues(
const Scalar & rx,
const Scalar & ry,
const Scalar & rz)
const
// Squared norm of the input components.
181 auto ns = rx * rx + ry * ry + rz * rz;
// The constants are created with ns's dtype so the sum stays in that dtype.
182 auto f = torch::sqrt(torch::Tensor(ns) + torch::tensor(1.0, ns.dtype())) +
183 torch::tensor(1.0, ns.dtype());
// Stack the scaled components along dimension 1. The trailing 1 passed to Rot
// is presumably the batch dimension count -- confirm.
186 return Rot(torch::stack({rx / f, ry / f, rz / f}, 1), 1);
// Generate random rotations by sampling quaternions, converting them to
// matrices, and passing those through fill_matrix().
//
// n           : presumably the number of rotations to generate; its use is on
//               lines not visible here -- confirm.
// random_seed : when >= 0, torch's global RNG is seeded with it first; a
//               negative value leaves the RNG state untouched.
//
// NOTE(review): this listing omits lines of the function (return type, braces,
// and the definitions of u0, u1, u2). They are presumably uniform samples in
// [0, 1) of length n -- confirm against the real file.
190Orientation::fill_random(
unsigned int n, TorchSize random_seed)
const
192 if (random_seed >= 0)
193 torch::manual_seed(random_seed);
// Quaternion components from u0, u1, u2. By construction
// w^2 + x^2 + y^2 + z^2 = (1 - u0) + u0 = 1, so each quaternion is unit length.
// The form resembles Shoemake's method for uniform random rotations -- confirm.
198 auto w = torch::sqrt(1.0 - u0) * torch::sin(2.0 * M_PI * u1);
199 auto x = torch::sqrt(1.0 - u0) * torch::cos(2.0 * M_PI * u1);
200 auto y = torch::sqrt(u0) * torch::sin(2.0 * M_PI * u2);
201 auto z = torch::sqrt(u0) * torch::cos(2.0 * M_PI * u2);
// Components are stacked along dimension 1 in (w, x, y, z) order.
203 auto quats = Quaternion(torch::stack({w, x, y,
z}, 1), 1);
205 return fill_matrix(quats.to_R2());
// Return `input` batch-expanded to a batch shape of {inp_size}.
//
// input    : rotations to expand.
// inp_size : target batch size; fill() passes the "quantity" option here.
//
// NOTE(review): the return type and braces are not visible in this listing, and
// any intervening lines may have been dropped. Whether batch_expand accepts an
// input whose batch size is neither 1 nor inp_size is not determinable from
// here -- confirm.
209Orientation::expand_as_needed(
const Rot & input,
unsigned int inp_size)
const
212 return input.batch_expand({inp_size});
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
A custom map-like data structure. The keys are strings, and the values can be nonhomogeneously typed.
Definition OptionSet.h:59
const std::string & doc() const
A readonly reference to the option set's docstring.
Definition OptionSet.h:91
const T & get(const std::string &) const
Definition OptionSet.h:422
T & set(const std::string &)
Definition OptionSet.h:436
Create batch of rotations, with various methods.
Definition Orientation.h:37
Orientation(const OptionSet &options)
Construct a new Orientation object.
Definition Orientation.cxx:73
static OptionSet expected_options()
Definition Orientation.cxx:36
Rotation stored as modified Rodrigues parameters.
Definition Rot.h:49
Definition UserTensor.h:33
static OptionSet expected_options()
Definition UserTensor.cxx:30
constexpr Real z
Definition crystallography.h:40
constexpr Real a
Definition crystallography.h:36
constexpr Real b
Definition crystallography.h:37
Derived where(const torch::Tensor &condition, const Derived &a, const Derived &b)
Definition BatchTensorBase.h:396
Definition CrossRef.cxx:32
const torch::TensorOptions default_tensor_options()
Definition types.cxx:30
int64_t TorchSize
Definition types.h:35
void neml_assert(bool assertion, Args &&... args)
Definition error.h:73