// NEML2 1.4.0 -- VariableStore.cxx
// Copyright 2023, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
24
25#include "neml2/models/VariableStore.h"
26
27namespace neml2
28{
30 : _object(object),
31 _options(options),
32 _input_axis(declare_axis("input")),
33 _output_axis(declare_axis("output"))
34{
35}
36
38VariableStore::declare_axis(const std::string & name)
39{
40 neml_assert(!_axes.has_key(name),
41 "Trying to declare an axis named ",
42 name,
43 ", but an axis with the same name already exists.");
44
45 auto axis = std::make_unique<LabeledAxis>();
46 return *_axes.set_pointer(name, std::move(axis));
47}
48
49void
55
58{
59 return _input_views.query_value(name);
60}
61
64{
65 return _output_views.query_value(name);
66}
67
68void
70{
71 for (auto && [name, var] : input_views())
72 var.cache(batch_shape);
73 for (auto && [name, var] : output_views())
74 var.cache(batch_shape);
75}
76
77void
79 const torch::TensorOptions & options,
80 bool in,
81 bool out,
82 bool dout_din,
83 bool d2out_din2)
84{
85 // Allocate input storage only if this is a host model
86 if (in && _object->host() == _object)
87 _in = LabeledVector::zeros(batch_shape, {&input_axis()}, options);
88
89 // Allocate output storage
90 if (out)
91 _out = LabeledVector::zeros(batch_shape, {&output_axis()}, options);
92
93 if (dout_din)
94 _dout_din = LabeledMatrix::zeros(batch_shape, {&output_axis(), &input_axis()}, options);
95
96 if (d2out_din2)
97 _d2out_din2 = LabeledTensor3D::zeros(
98 batch_shape, {&output_axis(), &input_axis(), &input_axis()}, options);
99
100 if (in)
102
104}
105
106void
108{
109 for (auto && [name, var] : input_views())
110 var.setup_views(&_object->host<VariableStore>()->input_storage());
111}
112
113void
115{
116 for (auto && [name, var] : output_views())
117 var.setup_views(&_out, &_dout_din, &_d2out_din2);
118}
119
120void
122{
123 for (auto && [name, var] : input_views())
124 var.reinit_views(true, false, false);
125}
126
127void
129{
130 for (auto && [name, var] : output_views())
131 var.reinit_views(out, dout_din, d2out_din2);
132}
133
134void
136{
137 bool out_detached = false;
138 bool dout_din_detached = false;
139 bool d2out_din2_detached = false;
140
141 // Detach and zero per request
142 if (out)
143 {
144 if (_out.tensor().requires_grad())
145 {
146 _out.tensor().detach_();
147 out_detached = true;
148 }
149 }
150
151 if (dout_din)
152 {
153 if (_dout_din.tensor().requires_grad())
154 {
155 _dout_din.tensor().detach_();
157 }
158 _dout_din.zero_();
159 }
160
161 if (d2out_din2)
162 {
163 if (_d2out_din2.tensor().requires_grad())
164 {
165 _d2out_din2.tensor().detach_();
166 d2out_din2_detached = true;
167 }
168 _d2out_din2.zero_();
169 }
170
171 // If the storage is detached in-place, we need to reconfigure all the views.
172 reinit_output_views(out_detached, dout_din_detached, d2out_din2_detached);
173}
174} // namespace neml2
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:44
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
void setup_layout()
Definition LabeledAxis.cxx:158
void zero_()
Zero out this tensor.
Definition LabeledTensor.cxx:152
static LabeledVector zeros(TorchShapeRef batch_shape, const std::vector< const LabeledAxis * > &axes, const torch::TensorOptions &options=default_tensor_options())
Setup new storage with zeros.
Definition LabeledTensor.cxx:109
const BatchTensor & tensor() const
Definition LabeledTensor.h:107
The base class of all "manufacturable" objects in the NEML2 library.
Definition NEML2Object.h:38
const T * host() const
Get a readonly pointer to the host.
Definition NEML2Object.h:90
A custom map-like data structure. The keys are strings, and the values can be nonhomogeneously typed.
Definition OptionSet.h:59
Definition Variable.h:41
Definition VariableStore.h:37
virtual void setup_input_views()
Tell each input variable view which tensor storage(s) to view into.
Definition VariableStore.cxx:107
Storage< VariableName, VariableBase > & output_views()
Definition VariableStore.h:109
virtual void reinit_output_views(bool out, bool dout_din=true, bool d2out_din2=true)
Create the views for output variables, and optionally for the derivative and second derivatives.
Definition VariableStore.cxx:128
LabeledAxis & output_axis()
Definition VariableStore.h:97
VariableBase * input_view(const VariableName &)
Get the view of an input variable.
Definition VariableStore.cxx:57
virtual void setup_output_views()
Tell each output variable view which tensor storage(s) to view into.
Definition VariableStore.cxx:114
virtual void allocate_variables(TorchShapeRef batch_shape, const torch::TensorOptions &options, bool in, bool out, bool dout_din, bool d2out_din2)
Allocate variable storages given the batch shape and tensor options.
Definition VariableStore.cxx:78
virtual void detach_and_zero(bool out, bool dout_din=true, bool d2out_din2=true)
Detach the tensor storages and set each element in the tensor to 0.
Definition VariableStore.cxx:135
Storage< VariableName, VariableBase > & input_views()
Definition VariableStore.h:103
virtual void setup_layout()
Setup the layouts of all the registered axes.
Definition VariableStore.cxx:50
LabeledAxis & declare_axis(const std::string &name)
Definition VariableStore.cxx:38
VariableBase * output_view(const VariableName &)
Get the view of an output variable.
Definition VariableStore.cxx:63
LabeledVector & input_storage()
Definition VariableStore.h:115
virtual void cache(TorchShapeRef batch_shape)
Cache the variable's batch shape.
Definition VariableStore.cxx:69
LabeledAxis & input_axis()
Definition VariableStore.h:91
virtual void reinit_input_views()
Create the views for input variables.
Definition VariableStore.cxx:121
VariableStore(const OptionSet &options, NEML2Object *object)
Definition VariableStore.cxx:29
Definition CrossRef.cxx:32
torch::IntArrayRef TorchShapeRef
Definition types.h:37
void neml_assert(bool assertion, Args &&... args)
Definition error.h:73