NEML2 1.4.0
Loading...
Searching...
No Matches
Variable.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/tensors/LabeledAxisAccessor.h"
28#include "neml2/tensors/tensors.h"
29#include "neml2/tensors/LabeledVector.h"
30#include "neml2/tensors/LabeledMatrix.h"
31#include "neml2/tensors/LabeledTensor3D.h"
32
33namespace neml2
34{
35using VariableName = LabeledAxisAccessor;
36
37// Forward declarations
38class Derivative;
39class Model;
40
42{
43public:
45
46 virtual ~VariableBase() = default;
47
49 virtual void cache(TensorShapeRef batch_shape);
50
52 virtual void setup_views(const LabeledVector * value,
53 const LabeledMatrix * deriv = nullptr,
54 const LabeledTensor3D * secderiv = nullptr);
55
57 virtual void setup_views(const VariableBase * other);
58
60 virtual void requires_grad_(bool req = true) = 0;
61
63 Derivative d(const VariableBase & x);
64
66 Derivative d(const VariableBase & x1, const VariableBase & x2);
67
69 const Tensor & raw_value() const { return _raw_value; }
70
72 virtual const Tensor tensor() const = 0;
73
75 const VariableName & name() const { return _name; }
76
78 const Model & owner() const { return *_owner; }
79
81 const VariableBase * src() const { return _src; }
82
85
87 virtual TensorShapeRef base_sizes() const = 0;
88
90 Size batch_dim() const { return _batch_sizes.size(); }
91
93 Size base_dim() const { return base_sizes().size(); }
94
97
99 virtual TensorShapeRef sizes() const = 0;
100
102 virtual TensorType type() const = 0;
103
106 bool is_state() const { return _is_state; }
107 bool is_old_state() const { return _is_old_state; }
108 bool is_force() const { return _is_force; }
109 bool is_old_force() const { return _is_old_force; }
110 bool is_residual() const { return _is_residual; }
111 bool is_parameter() const { return _is_parameter; }
112 bool is_other() const { return _is_other; }
115
117 // Note that the check depends on whether we are currently solving a nonlinear system
118 bool is_dependent() const;
119
120protected:
123
125 const Model * _owner;
126
129
132
134 std::map<VariableName, Tensor> _dvalue_d;
135
137 std::map<VariableName, std::map<VariableName, Tensor>> _d2value_d;
138
141
144 const bool _is_state;
145 const bool _is_old_state;
146 const bool _is_force;
147 const bool _is_old_force;
148 const bool _is_residual;
149 const bool _is_parameter;
150 const bool _is_other;
153};
154
159template <typename T>
160class Variable : public VariableBase
161{
162public:
163 template <typename T2 = T, typename = typename std::enable_if_t<!std::is_same_v<Tensor, T2>>>
165 const Model * owner,
168 _type(type),
169 _base_sizes(T::const_base_sizes)
170 {
171 }
172
173 template <typename T2 = T, typename = typename std::enable_if_t<std::is_same_v<Tensor, T2>>>
175 const Model * owner,
177 TensorType type = TensorType::kTensor)
179 _type(type),
181 {
182 }
183
184 virtual void setup_views(const LabeledVector * value,
185 const LabeledMatrix * deriv = nullptr,
186 const LabeledTensor3D * secderiv = nullptr) override
187 {
189 if (value)
190 _value = T(_raw_value.view(sizes()), batch_dim());
191 }
192
193 virtual void setup_views(const VariableBase * other) override
194 {
196 _value = T(_raw_value.view(sizes()), batch_dim());
197 }
198
199 virtual void requires_grad_(bool req = true) override { _value.requires_grad_(req); }
200
201 virtual TensorShapeRef base_sizes() const override { return _base_sizes; }
202
203 virtual TensorShapeRef sizes() const override { return _sizes; }
204
206 [[deprecated("Variable<T> must be assigned to references -- missing &")]] Variable(
207 const Variable<T> &)
208 {
209 }
210
212 [[deprecated("Variable<T> must be assigned to references -- missing &")]] void
214 {
215 }
216
223 void operator=(const Tensor & val)
224 {
225 _value.index_put_({torch::indexing::Slice()},
226 val.batch_expand(batch_sizes()).base_reshape(base_sizes()));
227 }
228
230 const T & value() const { return _value; }
231
233 virtual const Tensor tensor() const override { return _value; }
234
235 virtual TensorType type() const override { return _type; }
236
238 T operator-() const { return -_value; }
239
240 operator T() const { return _value; }
241
248
249 template <typename T2 = T, typename = typename std::enable_if_t<!std::is_same_v<T2, Tensor>>>
250 operator Tensor() const
251 {
252 return _value;
253 }
254
255protected:
258
261
264
267};
268
270{
271public:
273 : _value(val)
274 {
275 }
276
277 const Tensor & value() const { return _value; }
278
279 Derivative & operator=(const Tensor & val);
280
281private:
282 Tensor & _value;
283};
284
285// Everything below is just for convenience: the binary operators +, -, * and / are forwarded to
286// the variable values whenever at least one operand derives from VariableBase, so that we can write
287//
288// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
289// var4 = (var1 - var2) * var3
290// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
291//
292// instead of the (ugly?) expression below
293//
294// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
295// var4 = (var1.value() - var2.value()) * var3.value()
296// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
297#define FWD_VARIABLE_BINARY_OP(op) \
298 template <typename T1, \
299 typename T2, \
300 typename = typename std::enable_if_t<std::is_base_of_v<VariableBase, T1> || \
301 std::is_base_of_v<VariableBase, T2>>> \
302 auto op(const T1 & a, const T2 & b) \
303 { \
304 if constexpr (std::is_base_of_v<VariableBase, T1> && std::is_base_of_v<VariableBase, T2>) \
305 return op(a.value(), b.value()); \
306 \
307 if constexpr (std::is_base_of_v<VariableBase, T1> && !std::is_base_of_v<VariableBase, T2>) \
308 return op(a.value(), b); \
309 \
310 if constexpr (!std::is_base_of_v<VariableBase, T1> && std::is_base_of_v<VariableBase, T2>) \
311 return op(a, b.value()); \
312 } \
313 static_assert(true)
314FWD_VARIABLE_BINARY_OP(operator+);
315FWD_VARIABLE_BINARY_OP(operator-);
316FWD_VARIABLE_BINARY_OP(operator*);
317FWD_VARIABLE_BINARY_OP(operator/);
318}
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:56
Definition Variable.h:270
Derivative & operator=(const Tensor &val)
Definition Variable.cxx:140
Derivative(Tensor &val)
Definition Variable.h:272
const Tensor & value() const
Definition Variable.h:277
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:47
A single-batched, logically 2D LabeledTensor.
Definition LabeledMatrix.h:38
A single-batched, logically 3D LabeledTensor.
Definition LabeledTensor3D.h:38
A single-batched, logically 1D LabeledTensor.
Definition LabeledVector.h:38
The base class for all constitutive models.
Definition Model.h:55
Definition Tensor.h:32
Definition Variable.h:42
Derivative d(const VariableBase &x)
Create a wrapper representing the derivative dy/dx.
Definition Variable.cxx:91
virtual ~VariableBase()=default
bool is_old_force() const
Definition Variable.h:109
const Model * _owner
The model which declared this variable.
Definition Variable.h:125
bool is_parameter() const
Definition Variable.h:111
bool is_state() const
Definition Variable.h:106
const Model & owner() const
The owner of this variable.
Definition Variable.h:78
const bool _is_residual
Definition Variable.h:148
const bool _is_state
Definition Variable.h:144
Size base_storage() const
Base storage.
Definition Variable.h:96
const VariableBase * src() const
The source variable.
Definition Variable.h:81
TensorShape _batch_sizes
Batch shape of this variable.
Definition Variable.h:128
Size batch_dim() const
Batch dimension.
Definition Variable.h:90
const Tensor & raw_value() const
Raw flattened variable value.
Definition Variable.h:69
TensorShapeRef batch_sizes() const
Batch shape.
Definition Variable.h:84
Tensor _raw_value
The raw (flattened) variable value.
Definition Variable.h:131
std::map< VariableName, std::map< VariableName, Tensor > > _d2value_d
The second derivative of this variable w.r.t. arguments.
Definition Variable.h:137
VariableBase(const VariableName &name_in, const Model *owner)
Definition Variable.cxx:30
const bool _is_old_force
Definition Variable.h:147
const bool _is_other
Definition Variable.h:150
bool is_solve_dependent() const
Definition Variable.h:113
const bool _is_force
Definition Variable.h:146
bool is_residual() const
Definition Variable.h:110
virtual void setup_views(const LabeledVector *value, const LabeledMatrix *deriv=nullptr, const LabeledTensor3D *secderiv=nullptr)
Setup the variable's views into blocks of the storage.
Definition Variable.cxx:53
virtual TensorShapeRef sizes() const =0
Total shape.
const bool _is_solve_dependent
Definition Variable.h:151
const VariableName & name() const
Name of this variable.
Definition Variable.h:75
const VariableBase * _src
The source variable this variable follows.
Definition Variable.h:140
const bool _is_parameter
Definition Variable.h:149
virtual const Tensor tensor() const =0
Variable value of the logical shape.
bool is_dependent() const
Check if the derivative with respect to this variable should be evaluated.
Definition Variable.cxx:85
const bool _is_old_state
Definition Variable.h:145
const VariableName _name
Name of the variable.
Definition Variable.h:122
virtual TensorShapeRef base_sizes() const =0
Base shape.
Size base_dim() const
Base dimension.
Definition Variable.h:93
bool is_force() const
Definition Variable.h:108
bool is_other() const
Definition Variable.h:112
virtual TensorType type() const =0
Variable type.
std::map< VariableName, Tensor > _dvalue_d
The derivative of this variable w.r.t. arguments.
Definition Variable.h:134
bool is_old_state() const
Definition Variable.h:107
virtual void cache(TensorShapeRef batch_shape)
Cache the variable's batch shape.
Definition Variable.cxx:47
virtual void requires_grad_(bool req=true)=0
Set requires_grad for the underlying storage.
Concrete definition of a variable.
Definition Variable.h:161
void operator=(const Tensor &val)
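Set the raw value to store val: val is expanded to the batch shape, reshaped to the base shape, and written into the variable's value (a view of the raw value).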
Definition Variable.h:223
virtual void setup_views(const VariableBase *other) override
Setup the variable's views following another variable.
Definition Variable.h:193
virtual void cache(TensorShapeRef batch_shape) override
Cache the batch shape given by batch_shape (the base shape is a const member fixed at construction).
Definition Variable.h:243
virtual void requires_grad_(bool req=true) override
Set requires_grad for the underlying storage.
Definition Variable.h:199
Variable(const VariableName &name_in, const Model *owner, TensorShapeRef base_shape, TensorType type=TensorType::kTensor)
Definition Variable.h:174
const TensorType _type
Variable tensor type.
Definition Variable.h:257
virtual void setup_views(const LabeledVector *value, const LabeledMatrix *deriv=nullptr, const LabeledTensor3D *secderiv=nullptr) override
Setup the variable's views into blocks of the storage.
Definition Variable.h:184
TensorShape _sizes
Shape of this variable.
Definition Variable.h:263
void operator=(const Variable< T > &)
Deprecated copy-assignment operator that flags accidental copies (a Variable<T> must be bound to a reference -- missing &).
Definition Variable.h:213
T operator-() const
Negation.
Definition Variable.h:238
const TensorShape _base_sizes
Base shape of this variable.
Definition Variable.h:260
virtual const Tensor tensor() const override
Variable value of the logical shape.
Definition Variable.h:233
virtual TensorShapeRef base_sizes() const override
Base shape.
Definition Variable.h:201
const T & value() const
Variable value of the logical shape.
Definition Variable.h:230
Variable(const VariableName &name_in, const Model *owner, TensorType type=TensorTypeEnum< T2 >::value)
Definition Variable.h:164
virtual TensorShapeRef sizes() const override
Total shape.
Definition Variable.h:203
T _value
Variable value of the logical shape.
Definition Variable.h:266
virtual TensorType type() const override
Variable type.
Definition Variable.h:235
Variable(const Variable< T > &)
Deprecated copy constructor that flags accidental copies (a Variable<T> must be bound to a reference -- missing &).
Definition Variable.h:206
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:40
TensorShape add_shapes(S &&... shape)
Definition utils.h:298
Definition CrossRef.cxx:30
torch::SmallVector< Size > TensorShape
Definition types.h:34
LabeledAxisAccessor VariableName
Definition parser_utils.h:33
TensorType
Definition tensors.h:57
torch::IntArrayRef TensorShapeRef
Definition types.h:35
Definition tensors.h:64