NEML2 1.4.0
Loading...
Searching...
No Matches
Variable.cxx
1// Copyright 2023, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#include "neml2/tensors/Variable.h"
26
27namespace neml2
28{
29void
34
35void
37 const LabeledMatrix * deriv,
39{
40 if (value)
41 _value_storage = value;
42
43 if (deriv)
45
46 if (secderiv)
48}
49
50void
52{
53 neml_assert(other, "other != nullptr");
54 setup_views(other->_value_storage, other->_derivative_storage, other->_second_derivative_storage);
55}
56
57void
59{
60 if (out)
61 neml_assert(_value_storage, "Variable value storage not initialized.");
62 if (dout_din)
63 neml_assert(_derivative_storage, "Variable derivative storage not initialized.");
64 if (d2out_din2)
65 neml_assert(_second_derivative_storage, "Variable second derivative storage not initialized.");
66
67 if (out)
68 _raw_value = (*_value_storage)(name());
69
70 if (dout_din)
71 for (const auto & arg : args())
73
74 if (d2out_din2)
75 for (const auto & arg1 : args())
76 for (const auto & arg2 : args())
78}
79
80const LabeledVector &
82{
83 neml_assert_dbg(_value_storage, "Variable value storage not initialized.");
84 return *_value_storage;
85}
86
87const LabeledMatrix &
89{
90 neml_assert_dbg(_derivative_storage, "Variable derivative storage not initialized.");
91 return *_derivative_storage;
92}
93
94const LabeledTensor3D &
96{
97 neml_assert_dbg(_second_derivative_storage, "Variable 2nd derivative storage not initialized.");
99}
100
103{
104 neml_assert_dbg(_dvalue_d.count(x.name()),
105 "Error retrieving first derivative: ",
106 name(),
107 " does not depend on ",
108 x.name());
109
110 return Derivative(_dvalue_d[x.name()]);
111}
112
115{
116 neml_assert_dbg(_d2value_d.count(x1.name()),
117 "Error retrieving second derivative: ",
118 name(),
119 " does not depend on ",
120 x1.name());
121 neml_assert_dbg(_d2value_d[x1.name()].count(x2.name()),
122 "Error retrieving second derivative: d(",
123 name(),
124 ")/d(",
125 x1.name(),
126 ") does not depend on ",
127 x2.name());
128
129 return Derivative(_d2value_d[x1.name()][x2.name()]);
130}
131
132void
134{
135 _value.index_put_({torch::indexing::Slice()},
136 val.batch_expand_as(_value).base_reshape(_value.base_sizes()));
137}
138}
TorchShapeRef base_sizes() const
Return the base size.
Definition BatchTensorBase.cxx:163
Definition BatchTensor.h:32
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
Definition Variable.h:242
void operator=(const BatchTensor &val)
Definition Variable.cxx:133
A single-batched, logically 2D LabeledTensor.
Definition LabeledMatrix.h:38
A single-batched, logically 3D LabeledTensor.
Definition LabeledTensor3D.h:38
A single-batched, logically 1D LabeledTensor.
Definition LabeledVector.h:38
Definition Variable.h:41
BatchTensor _raw_value
The raw (flattened) variable value.
Definition Variable.h:129
const LabeledVector * _value_storage
The value storage that this variable is viewing into.
Definition Variable.h:138
std::map< VariableName, BatchTensor > _dvalue_d
The derivative of this variable w.r.t. arguments.
Definition Variable.h:132
Derivative d(const VariableBase &x)
Create a wrapper representing the derivative dy/dx.
Definition Variable.cxx:102
const std::vector< VariableName > & args() const
Names of the arguments, i.e. the variables this variable depends on.
Definition Variable.h:71
const LabeledMatrix * _derivative_storage
The derivative storage that this variable is viewing into.
Definition Variable.h:141
const LabeledTensor3D & second_derivative_storage() const
Definition Variable.cxx:95
const LabeledVector & value_storage() const
Accessors for storage.
Definition Variable.cxx:81
TorchShape _batch_sizes
Batch shape of this variable.
Definition Variable.h:123
void setup_views(const LabeledVector *value, const LabeledMatrix *deriv=nullptr, const LabeledTensor3D *secderiv=nullptr)
Setup the variable's views into blocks of the storage.
Definition Variable.cxx:36
const LabeledTensor3D * _second_derivative_storage
The second derivative storage that this variable is viewing into.
Definition Variable.h:144
const LabeledMatrix & derivative_storage() const
Definition Variable.cxx:88
const VariableName & name() const
Name of this variable.
Definition Variable.h:98
virtual void cache(TorchShapeRef batch_shape)
Cache the variable's batch shape.
Definition Variable.cxx:30
virtual void reinit_views(bool out, bool dout_din, bool d2out_din2)
Reinitialize variable views.
Definition Variable.cxx:58
std::map< VariableName, std::map< VariableName, BatchTensor > > _d2value_d
The second derivative of this variable w.r.t. arguments.
Definition Variable.h:135
Definition CrossRef.cxx:32
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:85
torch::IntArrayRef TorchShapeRef
Definition types.h:35
void neml_assert(bool assertion, Args &&... args)
Definition error.h:73