NEML2 1.4.0
Loading...
Searching...
No Matches
LabeledAxis.cxx
1// Copyright 2023, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#include "neml2/tensors/LabeledAxis.h"
26
27namespace neml2
28{
30 : _offset(0)
31{
32}
33
35 : _variables(other._variables),
36 _subaxes(other._subaxes),
37 _layout(other._layout),
38 _offset(other._offset)
39{
40}
41
44{
45 add(*this, sz, accessor.vec().begin(), accessor.vec().end());
46 return *this;
47}
48
49void
52 const std::vector<std::string>::const_iterator & cur,
53 const std::vector<std::string>::const_iterator & end) const
54{
55 if (cur == end - 1)
56 {
57 if (!axis.has_variable(*cur))
58 axis._variables.emplace(*cur, sz);
59 }
60 else
61 {
62 axis.add<LabeledAxis>(*cur);
63 add(axis.subaxis(*cur), sz, cur + 1, end);
64 }
65}
66
67LabeledAxis &
68LabeledAxis::rename(const std::string & original, const std::string & rename)
69{
70 // This could be a variable name
71 auto var = _variables.find(original);
72 if (var != _variables.end())
73 {
74 auto sz = var->second;
75 _variables.erase(var);
76 _variables.emplace(rename, sz);
77 return *this;
78 }
79
80 // or a sub-axis name
81 auto subaxis = _subaxes.find(original);
82 if (subaxis != _subaxes.end())
83 {
84 auto axis = subaxis->second;
85 _subaxes.erase(subaxis);
86 _subaxes.emplace(rename, axis);
87 return *this;
88 }
89
90 return *this;
91}
92
94LabeledAxis::remove(const std::string & name)
95{
96 // This could be a variable name
97 auto count = _variables.erase(name);
98 if (count)
99 return *this;
100
101 // or a sub-axis name
102 count += _subaxes.erase(name);
103
104 // If nothing has been removed, we should probably notify the user.
105 neml_assert_dbg(count, "Nothing removed in LabeledAxis::remove, did you mispell the name?");
106
107 return *this;
108}
109
112{
113 _variables.clear();
114 _subaxes.clear();
115 _layout.clear();
116 _offset = 0;
117
118 return *this;
119}
120
121std::vector<LabeledAxisAccessor>
123{
124 std::vector<LabeledAxisAccessor> merged_vars;
125 merge(other, {}, merged_vars);
126 return merged_vars;
127}
128
129void
131 std::vector<std::string> subaxes,
132 std::vector<LabeledAxisAccessor> & merged_vars)
133{
134 // First merge the variables
135 for (const auto & [name, sz] : other._variables)
136 if (!has_variable(name))
137 {
138 _variables.emplace(name, sz);
139 auto new_var = subaxes;
140 new_var.push_back(name);
141 merged_vars.push_back({new_var});
142 }
143
144 // Then merge the subaxes
145 for (auto & [name, subaxis] : other._subaxes)
146 {
147 auto found = _subaxes.find(name);
148 if (found == _subaxes.end())
149 _subaxes.emplace(name, std::make_shared<LabeledAxis>());
150
151 auto new_subaxes = subaxes;
152 new_subaxes.push_back(name);
153 _subaxes[name]->merge(*subaxis, new_subaxes, merged_vars);
154 }
155}
156
157void
159{
160 _offset = 0;
161 _layout.clear();
162
163 // First emplace all the variables
164 for (auto & [name, sz] : _variables)
165 {
166 std::pair<TorchSize, TorchSize> range = {_offset, _offset + sz};
167 _layout.emplace(name, range);
168 _offset += sz;
169 }
170
171 // Then subaxes
172 for (auto & [name, axis] : _subaxes)
173 {
174 // Setup the sub-axis if necessary
175 axis->setup_layout();
176 std::pair<TorchSize, TorchSize> range = {_offset, _offset + axis->storage_size()};
177 _layout.emplace(name, range);
178 _offset += axis->storage_size();
179 }
180}
181
182bool
184{
185 if (var.empty())
186 return false;
187
188 if (var.vec().size() > 1)
189 {
190 if (has_subaxis(var.vec()[0]))
191 return subaxis(var.vec()[0]).has_variable(var.slice(1));
192 else
193 return false;
194 }
195 else
196 return _variables.count(var.vec()[0]);
197}
198
199bool
201{
202 if (s.empty())
203 return false;
204
205 if (s.vec().size() > 1)
206 {
207 if (has_subaxis(s.vec()[0]))
208 return subaxis(s.vec()[0]).has_subaxis(s.slice(1));
209 else
210 return false;
211 }
212 else
213 return _subaxes.count(s.vec()[0]);
214}
215
218{
219 return storage_size(accessor.vec().begin(), accessor.vec().end());
220}
221
223LabeledAxis::storage_size(const std::vector<std::string>::const_iterator & cur,
224 const std::vector<std::string>::const_iterator & end) const
225{
226 if (cur == end - 1)
227 {
228 if (_variables.count(*cur))
229 return _variables.at(*cur);
230 else if (_subaxes.count(*cur))
231 return _subaxes.at(*cur)->storage_size();
232
233 neml_assert_dbg(false, "Trying to find the storage size of a non-existent item named ", *cur);
234 }
235
236 return subaxis(*cur).storage_size(cur + 1, end);
237}
238
241{
242 if (accessor.empty())
243 return torch::indexing::Slice();
244
245 return indices(0, accessor.vec().begin(), accessor.vec().end());
246}
247
250 const std::vector<std::string>::const_iterator & cur,
251 const std::vector<std::string>::const_iterator & end) const
252{
253 neml_assert_dbg(_layout.count(*cur), "Axis/variable named ", *cur, " does not exist.");
254 const auto & [rbegin, rend] = _layout.at(*cur);
255 if (cur == end - 1)
256 return torch::indexing::Slice(offset + rbegin, offset + rend);
257
258 return subaxis(*cur).indices(offset + rbegin, cur + 1, end);
259}
260
261std::vector<std::pair<TorchIndex, TorchIndex>>
263{
264 using namespace torch::indexing;
265
266 std::vector<std::pair<TorchIndex, TorchIndex>> indices;
267 std::vector<TorchSize> idxa;
268 std::vector<TorchSize> idxb;
270
271 if (idxa.empty())
272 return indices;
273
274 // We could be smart and merge contiguous indices
275 size_t i = 0;
276 size_t j = 1;
277 while (j < idxa.size() - 1)
278 {
279 if (idxa[j] == idxa[j + 1] && idxb[j] == idxb[j + 1])
280 j += 2;
281 else
282 {
283 indices.push_back({Slice(idxa[i], idxa[j]), Slice(idxb[i], idxb[j])});
284 i = j + 1;
285 j = i + 1;
286 }
287 }
288 indices.push_back({Slice(idxa[i], idxa[j]), Slice(idxb[i], idxb[j])});
289
290 return indices;
291}
292
293void
295 bool recursive,
296 std::vector<TorchSize> & idxa,
297 std::vector<TorchSize> & idxb,
299 TorchSize offsetb) const
300{
301 for (const auto & [name, sz] : _variables)
302 if (other.has_variable(name))
303 {
304 auto && [begina, enda] = _layout.at(name);
305 idxa.push_back(offseta + begina);
306 idxa.push_back(offseta + enda);
307 auto && [beginb, endb] = other._layout.at(name);
308 idxb.push_back(offsetb + beginb);
309 idxb.push_back(offsetb + endb);
310 }
311
312 if (recursive)
313 for (const auto & [name, axis] : _subaxes)
314 if (other.has_subaxis(name))
315 axis->common_indices(other.subaxis(name),
316 true,
317 idxa,
318 idxb,
319 offseta + _layout.at(name).first,
320 offsetb + other._layout.at(name).first);
321}
322
323std::vector<std::string>
325{
326 std::vector<std::string> names;
327 for (const auto & item : _layout)
328 names.push_back(item.first);
329 return names;
330}
331
332std::set<LabeledAxisAccessor>
334{
335 std::set<LabeledAxisAccessor> accessors;
337 return accessors;
338}
339
340void
341LabeledAxis::variable_accessors(std::set<LabeledAxisAccessor> & accessors,
343 bool recursive,
344 const LabeledAxisAccessor & subaxis) const
345{
346 for (auto & var : _variables)
347 {
349 var_accessor = var_accessor.on(cur);
350 if (subaxis.empty())
351 accessors.insert(var_accessor);
352 else if (var_accessor.slice(0, subaxis.size()) == subaxis)
353 accessors.insert(var_accessor);
354 }
355
356 if (recursive)
357 for (auto & [name, axis] : _subaxes)
358 {
359 auto next = cur.append(name);
360 axis->variable_accessors(accessors, next, recursive, subaxis);
361 }
362}
363
364const LabeledAxis &
365LabeledAxis::subaxis(const std::string & name) const
366{
368 _subaxes.count(name), "In LabeledAxis::subaxis, no subaxis matches given name ", name);
369
370 return *_subaxes.at(name);
371}
372
374LabeledAxis::subaxis(const std::string & name)
375{
377 _subaxes.count(name), "In LabeledAxis::subaxis, no subaxis matches given name ", name);
378
379 return *_subaxes.at(name);
380}
381
382bool
384{
385 // They must have the same size
386 if (_offset != other._offset)
387 return false;
388
389 // Comparing unordered maps is easy, two maps are equal if they have the same number of
390 // elements and the elements in one container are a permutation of the elements in the other
391 // container
392 if (_variables != other._variables)
393 return false;
394
395 // For subaxes, it's a little bit tricky as we need to compare the dereferenced axes.
396 for (auto & [name, axis] : _subaxes)
397 if (other._subaxes.count(name) == 0)
398 return false;
399 else if (*other._subaxes.at(name) != *axis)
400 return false;
401
402 return true;
403}
404
405std::ostream &
406operator<<(std::ostream & os, const LabeledAxis & axis)
407{
408 // Collect variable names and indices
409 size_t max_var_name_length = 0;
410 std::map<std::string, TorchIndex> vars;
411 for (auto var : axis.variable_accessors(true))
412 {
414 if (var_name.size() > max_var_name_length)
416 vars.emplace(var_name, axis.indices(var));
417 }
418
419 // Print variables with right alignment
420 for (auto var = vars.begin(); var != vars.end(); var++)
421 {
422 os << std::setw(max_var_name_length) << var->first << ": " << var->second;
423 if (std::next(var) != vars.end())
424 os << std::endl;
425 }
426
427 return os;
428}
429
430void
432 std::ostream & os, int & id, std::string axis_name, bool subgraph, bool node_handle) const
433{
434 // Preemble
435 os << (subgraph ? "subgraph " : "graph ");
436 os << "cluster_" << id++ << " ";
437 os << "{\n";
438 os << "label = \"" << axis_name << "\"\n";
439 os << "bgcolor = lightgrey\n";
440
441 // The axis should have an invisible node so that I can draw arrows
442 if (node_handle)
443 os << "\"" << axis_name << "\" [label = \"\", style = invis]\n";
444
445 // Write all the variables
446 for (const auto & [name, sz] : _variables)
447 {
448 os << "\"" << axis_name + " " + name << "\" ";
449 os << "[style = filled, color = white, shape = Square, ";
450 os << "label = \"" << name + " [" << sz << "]\"]\n";
451 }
452
453 // Write all the subaxes
454 for (const auto & [name, subaxis] : _subaxes)
455 subaxis->to_dot(os, id, axis_name + " " + name, true);
456
457 os << "}\n";
458}
459
460bool
461operator==(const LabeledAxis & a, const LabeledAxis & b)
462{
463 return a.equals(b);
464}
465
466bool
467operator!=(const LabeledAxis & a, const LabeledAxis & b)
468{
469 return !a.equals(b);
470}
471} // namespace neml2
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:44
LabeledAxisAccessor append(const LabeledAxisAccessor &axis) const
Add a new item.
Definition LabeledAxisAccessor.cxx:72
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
bool equals(const LabeledAxis &other) const
Check to see if two LabeledAxis objects are equivalent.
Definition LabeledAxis.cxx:383
std::vector< LabeledAxisAccessor > merge(LabeledAxis &other)
Merge with another LabeledAxis.
Definition LabeledAxis.cxx:122
LabeledAxis & rename(const std::string &original, const std::string &rename)
Change the label of an item.
Definition LabeledAxis.cxx:68
std::set< LabeledAxisAccessor > variable_accessors(bool recursive=false, const LabeledAxisAccessor &subaxis={}) const
Get the variable accessors.
Definition LabeledAxis.cxx:333
void to_dot(std::ostream &os, int &id, std::string name="", bool subgraph=false, bool node_handle=false) const
Write this object in dot format.
Definition LabeledAxis.cxx:431
LabeledAxis & add(const LabeledAxisAccessor &accessor)
Add a variable or subaxis.
Definition LabeledAxis.h:67
LabeledAxis & clear()
Clear everything.
Definition LabeledAxis.cxx:111
LabeledAxis & remove(const std::string &name)
Remove an item.
Definition LabeledAxis.cxx:94
const LabeledAxis & subaxis(const std::string &name) const
Get a sub-axis.
Definition LabeledAxis.cxx:365
LabeledAxis()
Empty constructor.
Definition LabeledAxis.cxx:29
std::vector< std::string > item_names() const
Get the item names.
Definition LabeledAxis.cxx:324
TorchIndex indices(const LabeledAxisAccessor &accessor) const
Get the indices of a specific item by a LabeledAxisAccessor
Definition LabeledAxis.cxx:240
const std::map< std::string, std::shared_ptr< LabeledAxis > > & subaxes() const
Get the subaxes.
Definition LabeledAxis.h:165
void setup_layout()
Definition LabeledAxis.cxx:158
TorchSize storage_size() const
Get the (total) storage size of this axis.
Definition LabeledAxis.h:142
bool has_subaxis(const LabeledAxisAccessor &s) const
Check the existence of a subaxis by its LabeledAxisAccessor.
Definition LabeledAxis.cxx:200
bool has_variable(const LabeledAxisAccessor &var) const
Does the variable of a given primitive type exist?
Definition LabeledAxis.h:126
std::vector< std::pair< TorchIndex, TorchIndex > > common_indices(const LabeledAxis &other, bool recursive=true) const
Get the common indices of two LabeledAxis objects.
Definition LabeledAxis.cxx:262
std::string stringify(const T &t)
Definition utils.h:302
Definition CrossRef.cxx:32
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:85
int64_t TorchSize
Definition types.h:33
std::ostream & operator<<(std::ostream &os, const OptionCollection &p)
Definition OptionCollection.cxx:37
bool operator==(const LabeledAxis &a, const LabeledAxis &b)
Definition LabeledAxis.cxx:461
bool operator!=(const LabeledAxis &a, const LabeledAxis &b)
Definition LabeledAxis.cxx:467
at::indexing::TensorIndex TorchIndex
Definition types.h:36