NEML2 1.4.0
Loading...
Searching...
No Matches
LabeledAxis.h
1// Copyright 2023, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include <unordered_map>
28#include <type_traits>
29
30#include "neml2/misc/types.h"
31
32#include "neml2/tensors/LabeledAxisAccessor.h"
33#include "neml2/tensors/Scalar.h"
34#include "neml2/tensors/SR2.h"
35
36namespace neml2
37{
55{
56public:
57 typedef std::unordered_map<std::string, std::pair<TorchSize, TorchSize>> AxisLayout;
58
61
64
66 template <typename T>
68 {
69 // Add an empty subaxis
70 if constexpr (std::is_same_v<LabeledAxis, T>)
71 {
72 if (!accessor.empty() && !has_subaxis(accessor.slice(0, 1)))
73 {
74 _subaxes.emplace(accessor.vec()[0], std::make_shared<LabeledAxis>());
75 subaxis(accessor.vec()[0]).add<LabeledAxis>(accessor.slice(1));
76 }
77 return *this;
78 }
79 else
80 {
81 // The storage is *flat* -- will need to reshape when we return!
82 // All NEML2 primitive data types will have the member const_base_sizes
83 auto sz = utils::storage_size(T::const_base_sizes);
84 add(accessor, sz);
85 return *this;
86 }
87 }
88
91
93 LabeledAxis & rename(const std::string & original, const std::string & rename);
94
96 LabeledAxis & remove(const std::string & name);
97
100
102 std::vector<LabeledAxisAccessor> merge(LabeledAxis & other);
103
107 void setup_layout();
108
110 size_t nitem() const { return nvariable() + nsubaxis(); }
111
113 size_t nvariable() const { return _variables.size(); }
114
116 size_t nsubaxis() const { return _subaxes.size(); }
117
119 bool has_item(const LabeledAxisAccessor & name) const
120 {
121 return has_variable(name) || has_subaxis(name);
122 }
123
125 template <typename T>
127 {
128 if (!has_variable(var))
129 return false;
130
131 // also check size
132 return storage_size(var) == utils::storage_size(T::const_base_sizes);
133 }
134
136 bool has_variable(const LabeledAxisAccessor & var) const;
137
139 bool has_subaxis(const LabeledAxisAccessor & s) const;
140
142 TorchSize storage_size() const { return _offset; }
145
147 const AxisLayout & layout() const { return _layout; }
148
152 TorchIndex indices(const LabeledAxis & other, bool recursive = true, bool inclusive = true) const;
153
155 std::vector<std::pair<TorchIndex, TorchIndex>> common_indices(const LabeledAxis & other,
156 bool recursive = true) const;
157
159 std::vector<std::string> item_names() const;
160
162 const std::map<std::string, TorchSize> & variables() const { return _variables; }
163
165 const std::map<std::string, std::shared_ptr<LabeledAxis>> & subaxes() const { return _subaxes; }
166
168 std::set<LabeledAxisAccessor> variable_accessors(bool recursive = false,
169 const LabeledAxisAccessor & subaxis = {}) const;
170
172 const LabeledAxis & subaxis(const std::string & name) const;
174 LabeledAxis & subaxis(const std::string & name);
175
177 bool equals(const LabeledAxis & other) const;
178
179 friend std::ostream & operator<<(std::ostream & os, const LabeledAxis & info);
180
182 void to_dot(std::ostream & os,
183 int & id,
184 std::string name = "",
185 bool subgraph = false,
186 bool node_handle = false) const;
187
188private:
189 void add(LabeledAxis & axis,
190 TorchSize sz,
191 const std::vector<std::string>::const_iterator & cur,
192 const std::vector<std::string>::const_iterator & end) const;
193
194 void merge(LabeledAxis & other,
195 std::vector<std::string> subaxes,
196 std::vector<LabeledAxisAccessor> & merged_vars);
197
199 TorchSize storage_size(const std::vector<std::string>::const_iterator & cur,
200 const std::vector<std::string>::const_iterator & end) const;
201
205 const std::vector<std::string>::const_iterator & cur,
206 const std::vector<std::string>::const_iterator & end) const;
207
209 void indices(const LabeledAxis & other,
210 bool recursive,
211 bool inclusive,
212 std::vector<TorchSize> & idx,
213 TorchSize offset) const;
214
216 void common_indices(const LabeledAxis & other,
217 bool recursive,
218 std::vector<TorchSize> & idxa,
219 std::vector<TorchSize> & idxb,
220 TorchSize offseta,
221 TorchSize offsetb) const;
222
223 void variable_accessors(std::set<LabeledAxisAccessor> & accessors,
224 LabeledAxisAccessor cur,
225 bool recursive,
226 const LabeledAxisAccessor & subaxis) const;
227
229 std::map<std::string, TorchSize> _variables;
230
232 // Each sub-axis can contain its own variables and sub-axes
233 std::map<std::string, std::shared_ptr<LabeledAxis>> _subaxes;
234
236 // After all the `LabeledAxis`s are setup, we need to setup the layout once and only once. This is
237 // important for performance considerations, as we need to use the layout to construct many,
238 // many LabeledVector and LabeledMatrix at runtime, and so we don't want to waste time on setting
239 // up the layout over and over again.
240 AxisLayout _layout;
241
243 // Similar considerations as `_layout`, i.e., the _offset will be zero during the setup stage,
244 // and will have a fixed (hopefully correct) size after the layout have been setup.
245 TorchSize _offset;
246};
247
248std::ostream & operator<<(std::ostream & os, const LabeledAxis & info);
249
250bool operator==(const LabeledAxis & a, const LabeledAxis & b);
251
252bool operator!=(const LabeledAxis & a, const LabeledAxis & b);
253} // namespace neml2
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:44
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
size_t nvariable() const
Number of variables.
Definition LabeledAxis.h:113
bool equals(const LabeledAxis &other) const
Check to see if two LabeledAxis objects are equivalent.
Definition LabeledAxis.cxx:383
std::vector< LabeledAxisAccessor > merge(LabeledAxis &other)
Merge with another LabeledAxis.
Definition LabeledAxis.cxx:122
LabeledAxis & rename(const std::string &original, const std::string &rename)
Change the label of an item.
Definition LabeledAxis.cxx:68
std::set< LabeledAxisAccessor > variable_accessors(bool recursive=false, const LabeledAxisAccessor &subaxis={}) const
Get the variable accessors.
Definition LabeledAxis.cxx:333
void to_dot(std::ostream &os, int &id, std::string name="", bool subgraph=false, bool node_handle=false) const
Write this object in dot format.
Definition LabeledAxis.cxx:431
LabeledAxis & add(const LabeledAxisAccessor &accessor)
Add a variable or subaxis.
Definition LabeledAxis.h:67
size_t nsubaxis() const
Number of subaxes.
Definition LabeledAxis.h:116
bool has_item(const LabeledAxisAccessor &name) const
Does the item exist?
Definition LabeledAxis.h:119
LabeledAxis & clear()
Clear everything.
Definition LabeledAxis.cxx:111
const std::map< std::string, TorchSize > & variables() const
Get the variables.
Definition LabeledAxis.h:162
LabeledAxis & remove(const std::string &name)
Remove an item.
Definition LabeledAxis.cxx:94
const LabeledAxis & subaxis(const std::string &name) const
Get a sub-axis.
Definition LabeledAxis.cxx:365
std::unordered_map< std::string, std::pair< TorchSize, TorchSize > > AxisLayout
Definition LabeledAxis.h:57
LabeledAxis()
Empty constructor.
Definition LabeledAxis.cxx:29
std::vector< std::string > item_names() const
Get the item names.
Definition LabeledAxis.cxx:324
TorchIndex indices(const LabeledAxisAccessor &accessor) const
Get the indices of a specific item by its LabeledAxisAccessor.
Definition LabeledAxis.cxx:240
const std::map< std::string, std::shared_ptr< LabeledAxis > > & subaxes() const
Get the subaxes.
Definition LabeledAxis.h:165
void setup_layout()
Definition LabeledAxis.cxx:158
friend std::ostream & operator<<(std::ostream &os, const LabeledAxis &info)
Definition LabeledAxis.cxx:406
TorchIndex indices(const LabeledAxis &other, bool recursive=true, bool inclusive=true) const
Get the indices using another LabeledAxis.
const AxisLayout & layout() const
Get the layout.
Definition LabeledAxis.h:147
TorchSize storage_size() const
Get the (total) storage size of this axis.
Definition LabeledAxis.h:142
bool has_subaxis(const LabeledAxisAccessor &s) const
Check the existence of a subaxis by its LabeledAxisAccessor.
Definition LabeledAxis.cxx:200
bool has_variable(const LabeledAxisAccessor &var) const
Does the variable of a given primitive type exist?
Definition LabeledAxis.h:126
size_t nitem() const
Number of items.
Definition LabeledAxis.h:110
std::vector< std::pair< TorchIndex, TorchIndex > > common_indices(const LabeledAxis &other, bool recursive=true) const
Get the common indices of two LabeledAxis objects.
Definition LabeledAxis.cxx:262
TorchSize storage_size(TorchShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:32
Definition CrossRef.cxx:32
int64_t TorchSize
Definition types.h:33
std::ostream & operator<<(std::ostream &os, const OptionCollection &p)
Definition OptionCollection.cxx:37
bool operator==(const LabeledAxis &a, const LabeledAxis &b)
Definition LabeledAxis.cxx:461
bool operator!=(const LabeledAxis &a, const LabeledAxis &b)
Definition LabeledAxis.cxx:467
at::indexing::TensorIndex TorchIndex
Definition types.h:36