NEML2 1.4.0
Loading...
Searching...
No Matches
Variable.h
1// Copyright 2023, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/tensors/LabeledAxisAccessor.h"
28#include "neml2/tensors/tensors.h"
29#include "neml2/tensors/LabeledVector.h"
30#include "neml2/tensors/LabeledMatrix.h"
31#include "neml2/tensors/LabeledTensor3D.h"
32
33namespace neml2
34{
36
37// Forward declarations
38class Derivative;
39
41{
42public:
50
51 virtual ~VariableBase() = default;
52
54 virtual void cache(TorchShapeRef batch_shape);
55
57 void setup_views(const LabeledVector * value,
58 const LabeledMatrix * deriv = nullptr,
59 const LabeledTensor3D * secderiv = nullptr);
60
62 void setup_views(const VariableBase * other);
63
65 virtual void reinit_views(bool out, bool dout_din, bool d2out_din2);
66
68 virtual void requires_grad_(bool req = true) = 0;
69
71 const std::vector<VariableName> & args() const { return _args; }
72
74 void add_arg(const VariableBase & arg) { _args.push_back(arg.name()); }
75
77 void clear_args() { _args.clear(); }
78
80 Derivative d(const VariableBase & x);
81
83 Derivative d(const VariableBase & x1, const VariableBase & x2);
84
86 const LabeledVector & value_storage() const;
87 const LabeledMatrix & derivative_storage() const;
90
92 const BatchTensor & raw_value() const { return _raw_value; }
93
95 virtual const BatchTensor tensor() const = 0;
96
98 const VariableName & name() const { return _name; }
99
102
104 virtual TorchShapeRef base_sizes() const = 0;
105
107 TorchSize batch_dim() const { return _batch_sizes.size(); }
108
110 TorchSize base_dim() const { return base_sizes().size(); }
111
114
116 virtual TorchShapeRef sizes() const = 0;
117
118protected:
121
124
126 std::vector<VariableName> _args;
127
130
132 std::map<VariableName, BatchTensor> _dvalue_d;
133
135 std::map<VariableName, std::map<VariableName, BatchTensor>> _d2value_d;
136
139
142
145};
146
151template <typename T>
152class Variable : public VariableBase
153{
154public:
155 template <typename T2 = T, typename = typename std::enable_if_t<!std::is_same_v<BatchTensor, T2>>>
158 _base_sizes(T::const_base_sizes)
159 {
160 }
161
162 template <typename T2 = T, typename = typename std::enable_if_t<std::is_same_v<BatchTensor, T2>>>
168
169 virtual void reinit_views(bool out, bool dout_din, bool d2out_din2) override
170 {
172 if (out)
173 _value = T(_raw_value.view(sizes()), batch_dim());
174 }
175
176 virtual void requires_grad_(bool req = true) override { _value.requires_grad_(req); }
177
178 virtual TorchShapeRef base_sizes() const override { return _base_sizes; }
179
180 virtual TorchShapeRef sizes() const override { return _sizes; }
181
183 [[deprecated("Variable<T> must be assigned to references -- missing &")]] Variable(
184 const Variable<T> &)
185 {
186 }
187
189 [[deprecated("Variable<T> must be assigned to references -- missing &")]] void
191 {
192 }
193
201 {
202 _value.index_put_({torch::indexing::Slice()},
203 val.batch_expand(batch_sizes()).base_reshape(base_sizes()));
204 }
205
207 const T & value() const { return _value; }
208
210 virtual const BatchTensor tensor() const override { return _value; }
211
213 T operator-() const { return -_value; }
214
215 operator T() const { return _value; }
216
223
224 template <typename T2 = T, typename = typename std::enable_if_t<!std::is_same_v<T2, BatchTensor>>>
225 operator BatchTensor() const
226 {
227 return _value;
228 }
229
230protected:
233
236
239};
240
242{
243public:
245 : _value(val)
246 {
247 }
248
249 const BatchTensor & value() const { return _value; }
250
251 void operator=(const BatchTensor & val);
252
253private:
254 BatchTensor & _value;
255};
256
257// Everything below is just for convenience: We just forward operations to the the variable values
258// so that we can do
259//
260// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
261// var4 = (var1 - var2) * var3
262// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
263//
264// instead of the (ugly?) expression below
265//
266// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
267// var4 = (var1.v - var2.v) * var3.v
268// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
269#define FWD_VARIABLE_BINARY_OP(op) \
270 template <typename T1, \
271 typename T2, \
272 typename = typename std::enable_if_t<std::is_base_of_v<VariableBase, T1> || \
273 std::is_base_of_v<VariableBase, T2>>> \
274 auto op(const T1 & a, const T2 & b) \
275 { \
276 if constexpr (std::is_base_of_v<VariableBase, T1> && std::is_base_of_v<VariableBase, T2>) \
277 return op(a.value(), b.value()); \
278 \
279 if constexpr (std::is_base_of_v<VariableBase, T1> && !std::is_base_of_v<VariableBase, T2>) \
280 return op(a.value(), b); \
281 \
282 if constexpr (!std::is_base_of_v<VariableBase, T1> && std::is_base_of_v<VariableBase, T2>) \
283 return op(a, b.value()); \
284 } \
285 static_assert(true)
286FWD_VARIABLE_BINARY_OP(operator+);
287FWD_VARIABLE_BINARY_OP(operator-);
288FWD_VARIABLE_BINARY_OP(operator*);
289FWD_VARIABLE_BINARY_OP(operator/);
290}
Definition BatchTensor.h:32
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
Definition Variable.h:242
void operator=(const BatchTensor &val)
Definition Variable.cxx:133
Derivative(BatchTensor &val)
Definition Variable.h:244
const BatchTensor & value() const
Definition Variable.h:249
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:44
A single-batched, logically 2D LabeledTensor.
Definition LabeledMatrix.h:38
A single-batched, logically 3D LabeledTensor.
Definition LabeledTensor3D.h:38
A single-batched, logically 1D LabeledTensor.
Definition LabeledVector.h:38
Definition Variable.h:41
BatchTensor _raw_value
The raw (flattened) variable value.
Definition Variable.h:129
VariableBase(const VariableName &name_in)
Definition Variable.h:43
virtual TorchShapeRef base_sizes() const =0
Base shape.
const LabeledVector * _value_storage
The value storage that this variable is viewing into.
Definition Variable.h:138
std::map< VariableName, BatchTensor > _dvalue_d
The derivative of this variable w.r.t. arguments.
Definition Variable.h:132
Derivative d(const VariableBase &x)
Create a wrapper representing the derivative dy/dx.
Definition Variable.cxx:102
const std::vector< VariableName > & args() const
Arguments.
Definition Variable.h:71
virtual ~VariableBase()=default
TorchSize base_storage() const
Base storage.
Definition Variable.h:113
TorchShapeRef batch_sizes() const
Batch shape.
Definition Variable.h:101
const LabeledMatrix * _derivative_storage
The derivative storage that this variable is viewing into.
Definition Variable.h:141
const LabeledTensor3D & second_derivative_storage() const
Definition Variable.cxx:95
const LabeledVector & value_storage() const
Accessors for storage.
Definition Variable.cxx:81
std::vector< VariableName > _args
Names of the variables that this variable depends on.
Definition Variable.h:126
TorchSize batch_dim() const
Batch dimension.
Definition Variable.h:107
void add_arg(const VariableBase &arg)
Add an argument.
Definition Variable.h:74
virtual TorchShapeRef sizes() const =0
Total shape.
TorchShape _batch_sizes
Batch shape of this variable.
Definition Variable.h:123
virtual const BatchTensor tensor() const =0
Variable value of the logical shape.
void setup_views(const LabeledVector *value, const LabeledMatrix *deriv=nullptr, const LabeledTensor3D *secderiv=nullptr)
Setup the variable's views into blocks of the storage.
Definition Variable.cxx:36
const LabeledTensor3D * _second_derivative_storage
The second derivative storage that this variable is viewing into.
Definition Variable.h:144
const LabeledMatrix & derivative_storage() const
Definition Variable.cxx:88
const VariableName & name() const
Name of this variable.
Definition Variable.h:98
TorchSize base_dim() const
Base dimension.
Definition Variable.h:110
const VariableName _name
Name of the variable.
Definition Variable.h:120
const BatchTensor & raw_value() const
Raw flattened variable value.
Definition Variable.h:92
void clear_args()
Clear arguments.
Definition Variable.h:77
virtual void cache(TorchShapeRef batch_shape)
Cache the variable's batch shape.
Definition Variable.cxx:30
virtual void requires_grad_(bool req=true)=0
Set requires_grad for the underlying storage.
virtual void reinit_views(bool out, bool dout_din, bool d2out_din2)
Reinitialize variable views.
Definition Variable.cxx:58
std::map< VariableName, std::map< VariableName, BatchTensor > > _d2value_d
The second derivative of this variable w.r.t. arguments.
Definition Variable.h:135
Concrete definition of a variable.
Definition Variable.h:153
const TorchShape _base_sizes
Base shape of this variable.
Definition Variable.h:232
virtual void requires_grad_(bool req=true) override
Set requires_grad for the underlying storage.
Definition Variable.h:176
virtual const BatchTensor tensor() const override
Variable value of the logical shape.
Definition Variable.h:210
void operator=(const BatchTensor &val)
Set the raw value to store val.
Definition Variable.h:200
virtual void reinit_views(bool out, bool dout_din, bool d2out_din2) override
Reinitialize variable views.
Definition Variable.h:169
void operator=(const Variable< T > &)
Suppressed assignment operator to prevent accidental dereferencing.
Definition Variable.h:190
virtual void cache(TorchShapeRef batch_shape) override
Set the batch shape and base shape according to val.
Definition Variable.h:218
Variable(const VariableName &name_in, TorchShapeRef base_shape)
Definition Variable.h:163
T operator-() const
Negation.
Definition Variable.h:213
virtual TorchShapeRef sizes() const override
Total shape.
Definition Variable.h:180
const T & value() const
Variable value of the logical shape.
Definition Variable.h:207
virtual TorchShapeRef base_sizes() const override
Base shape.
Definition Variable.h:178
Variable(const VariableName &name_in)
Definition Variable.h:156
T _value
Variable value of the logical shape.
Definition Variable.h:238
TorchShape _sizes
Shape of this variable.
Definition Variable.h:235
Variable(const Variable< T > &)
Suppressed constructor to prevent accidental dereferencing.
Definition Variable.h:183
TorchSize storage_size(TorchShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:32
TorchShape add_shapes(S &&... shape)
Definition utils.h:294
Definition CrossRef.cxx:32
std::vector< TorchSize > TorchShape
Definition types.h:34
torch::IntArrayRef TorchShapeRef
Definition types.h:35