NEML2 1.4.0
Loading...
Searching...
No Matches
VecBase.h
1// Copyright 2023, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/tensors/FixedDimTensor.h"
28#include "neml2/tensors/Scalar.h"
29#include "neml2/tensors/R2.h"
30
31namespace neml2
32{
// Forward declarations. The declarations visible in this header only name
// these types in member-function signatures (rotate/drotate), so their full
// definitions are not required at this point.
class Rot;
class R3;
35
41template <class Derived>
42class VecBase : public FixedDimTensor<Derived, 3>
43{
44public:
46
47 [[nodiscard]] static Derived
48 fill(const Real & v1,
49 const Real & v2,
50 const Real & v3,
51 const torch::TensorOptions & options = default_tensor_options());
52
53 [[nodiscard]] static Derived fill(const Scalar & v1, const Scalar & v2, const Scalar & v3);
54
56 [[nodiscard]] static R2
57 identity_map(const torch::TensorOptions & options = default_tensor_options());
58
61
63 template <class Derived2>
64 Scalar dot(const VecBase<Derived2> & v) const;
65
67 template <class Derived2>
69
71 template <class Derived2>
72 R2 outer(const VecBase<Derived2> & v) const;
73
75 Scalar norm_sq() const;
76
78 Scalar norm() const;
79
81 Derived rotate(const Rot & r) const;
82
84 Derived rotate(const R2 & R) const;
85
87 R2 drotate(const Rot & r) const;
88
90 R3 drotate(const R2 & R) const;
91};
92
93template <class Derived>
94template <class Derived2>
97{
99 auto res = torch::linalg_vecdot(*this, v);
100 return Scalar(res, res.dim());
101}
102
103template <class Derived>
104template <class Derived2>
107{
109
110 auto batch_dim = broadcast_batch_dim(*this, v);
111 auto pair = torch::broadcast_tensors({*this, v});
112
113 return Derived(torch::linalg_cross(pair[0], pair[1]), batch_dim);
114}
115
116template <class Derived>
117template <class Derived2>
118R2
120{
121 return torch::matmul(this->unsqueeze(-1), v.unsqueeze(-2));
122}
123} // namespace neml2
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
CrossRef()=default
FixedDimTensor inherits from BatchTensorBase and additionally templates on the base shape.
Definition FixedDimTensor.h:38
FixedDimTensor()=default
Default constructor.
A basic R2.
Definition R2.h:42
The (logical) full third order tensor.
Definition R3.h:41
Rotation stored as modified Rodrigues parameters.
Definition Rot.h:49
The (logical) scalar.
Definition Scalar.h:38
Base class for the (logical) vector.
Definition VecBase.h:43
Derived rotate(const Rot &r) const
Rotate using a Rodrigues vector.
Definition VecBase.cxx:78
R2 outer(const VecBase< Derived2 > &v) const
outer product
Definition VecBase.h:119
Scalar operator()(TorchSize i) const
Accessor.
Definition VecBase.cxx:57
Derived cross(const VecBase< Derived2 > &v) const
cross product
Definition VecBase.h:106
static R2 identity_map(const torch::TensorOptions &options=default_tensor_options())
The derivative of a vector with respect to itself.
Definition VecBase.cxx:50
Scalar norm_sq() const
Norm squared.
Definition VecBase.cxx:64
Scalar dot(const VecBase< Derived2 > &v) const
dot product
Definition VecBase.h:96
static Derived fill(const Real &v1, const Real &v2, const Real &v3, const torch::TensorOptions &options=default_tensor_options())
Definition VecBase.cxx:33
Scalar norm() const
Norm.
Definition VecBase.cxx:71
R2 drotate(const Rot &r) const
Derivative of the rotated vector w.r.t. the Rodrigues vector.
Definition VecBase.cxx:92
Definition CrossRef.cxx:32
const torch::TensorOptions default_tensor_options()
Definition types.cxx:30
int64_t TorchSize
Definition types.h:35
TorchSize broadcast_batch_dim(T &&...)
The batch dimension after broadcasting.
double Real
Definition types.h:33
void neml_assert_broadcastable_dbg(T &&...)
A helper function to assert (in Debug mode) that all tensors are broadcastable.