NEML2 1.4.0
Loading...
Searching...
No Matches
LabeledAxis.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include <unordered_map>
28#include <type_traits>
29
30#include "neml2/misc/types.h"
31
32#include "neml2/tensors/LabeledAxisAccessor.h"
33#include "neml2/tensors/Scalar.h"
34#include "neml2/tensors/SR2.h"
35
36namespace neml2
37{
55{
56public:
57 typedef std::unordered_map<std::string, std::pair<Size, Size>> AxisLayout;
58
59 // Custom comparator for sorting assembly indices
61 {
62 bool operator()(const indexing::TensorIndex & a, const indexing::TensorIndex & b) const
63 {
64 neml_assert(a.is_slice() && b.is_slice(), "Comparator must be used on slices");
65 neml_assert(a.slice().step().expect_int() == 1 && b.slice().step().expect_int() == 1,
66 "Slices must have step == 1");
67 return a.slice().start().expect_int() < b.slice().start().expect_int();
68 }
69 };
70
73
76
78 template <typename T>
80 {
81 // Add an empty subaxis
82 if constexpr (std::is_same_v<LabeledAxis, T>)
83 {
84 if (!accessor.empty() && !has_subaxis(accessor.slice(0, 1)))
85 {
86 _subaxes.emplace(accessor.vec()[0], std::make_shared<LabeledAxis>());
87 subaxis(accessor.vec()[0]).add<LabeledAxis>(accessor.slice(1));
88 }
89 return *this;
90 }
91 else
92 {
93 // The storage is *flat* -- will need to reshape when we return!
94 // All NEML2 primitive data types will have the member const_base_sizes
95 auto sz = utils::storage_size(T::const_base_sizes);
96 add(accessor, sz);
97 return *this;
98 }
99 }
100
103
105 void clear();
106
110 void setup_layout();
111
114 bool has_state() const { return _has_state; }
115 bool has_old_state() const { return _has_old_state; }
116 bool has_forces() const { return _has_forces; }
117 bool has_old_forces() const { return _has_old_forces; }
118 bool has_residual() const { return _has_residual; }
119 bool has_parameters() const { return _has_parameters; }
121
123 size_t nvariable(bool recursive = true) const;
124
126 size_t nsubaxis(bool recursive = false) const;
127
129 bool has_item(const LabeledAxisAccessor & name) const
130 {
131 return has_variable(name) || has_subaxis(name);
132 }
133
135 template <typename T>
137 {
138 if (!has_variable(var))
139 return false;
140
141 // also check size
142 return storage_size(var) == utils::storage_size(T::const_base_sizes);
143 }
144
146 bool has_variable(const LabeledAxisAccessor & var) const;
147
149 bool has_subaxis(const LabeledAxisAccessor & s) const;
150
152 Size storage_size(const LabeledAxisAccessor & name = {}) const;
153
155 const AxisLayout & layout() const { return _layout; }
156
158 indexing::TensorIndex indices(const LabeledAxisAccessor & accessor) const;
159
161 std::vector<std::pair<indexing::TensorIndex, indexing::TensorIndex>>
162 common_indices(const LabeledAxis & other, bool recursive = true) const;
163
165 std::vector<LabeledAxisAccessor>
166 sort_by_assembly_order(const std::set<LabeledAxisAccessor> &) const;
167
169 const std::map<std::string, Size> & variables() const { return _variables; }
170
172 std::set<LabeledAxisAccessor> variable_names(bool recursive = true) const;
173
175 const std::map<std::string, std::shared_ptr<LabeledAxis>> & subaxes() const { return _subaxes; }
176
178 std::set<LabeledAxisAccessor> subaxis_names(bool recursive = false) const;
179
181 const LabeledAxis & subaxis(const LabeledAxisAccessor & name) const;
182
185
187 bool equals(const LabeledAxis & other) const;
188
189 friend std::ostream & operator<<(std::ostream & os, const LabeledAxis & axis);
190
191private:
192 void add(LabeledAxis & axis,
193 Size sz,
195 const LabeledAxisAccessor::const_iterator & end) const;
196
199 const LabeledAxisAccessor::const_iterator & end) const;
200
203 indexing::TensorIndex indices(Size offset,
205 const LabeledAxisAccessor::const_iterator & end) const;
206
208 void indices(const LabeledAxis & other,
209 bool recursive,
210 bool inclusive,
211 std::vector<Size> & idx,
212 Size offset) const;
213
215 void common_indices(const LabeledAxis & other,
216 bool recursive,
217 std::vector<Size> & idxa,
218 std::vector<Size> & idxb,
220 Size offsetb) const;
221
223 std::map<std::string, Size> _variables;
224
226 // Each sub-axis can contain its own variables and sub-axes
227 std::map<std::string, std::shared_ptr<LabeledAxis>> _subaxes;
228
230 // After all the `LabeledAxis` objects are set up, we need to set up the layout once and only once. This is
231 // important for performance considerations, as we need to use the layout to construct many,
232 // many LabeledVector and LabeledMatrix at runtime, and so we don't want to waste time on setting
233 // up the layout over and over again.
234 AxisLayout _layout;
235
237 // The same considerations apply as for `_layout`, i.e., the _offset will be zero during the setup stage,
238 // and will have a fixed (hopefully correct) size after the layout has been set up.
239 Size _offset;
240
243 bool _has_state;
244 bool _has_old_state;
245 bool _has_forces;
246 bool _has_old_forces;
247 bool _has_residual;
248 bool _has_parameters;
250};
251
252std::ostream & operator<<(std::ostream & os, const LabeledAxis & axis);
253
254bool operator==(const LabeledAxis & a, const LabeledAxis & b);
255
256bool operator!=(const LabeledAxis & a, const LabeledAxis & b);
257} // namespace neml2
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:56
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:47
c10::SmallVector< std::string >::const_iterator const_iterator
Definition LabeledAxisAccessor.h:83
A labeled axis used to associate the layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
bool equals(const LabeledAxis &other) const
Check to see if two LabeledAxis objects are equivalent.
Definition LabeledAxis.cxx:341
size_t nvariable(bool recursive=true) const
Number of variables.
Definition LabeledAxis.cxx:122
std::vector< LabeledAxisAccessor > sort_by_assembly_order(const std::set< LabeledAxisAccessor > &) const
Sort a set of LabeledAxisAccessors by their indices.
Definition LabeledAxis.cxx:268
LabeledAxis & add(const LabeledAxisAccessor &accessor)
Add a variable or subaxis.
Definition LabeledAxis.h:79
bool has_item(const LabeledAxisAccessor &name) const
Does the item exist?
Definition LabeledAxis.h:129
bool has_forces() const
Definition LabeledAxis.h:116
indexing::TensorIndex indices(const LabeledAxisAccessor &accessor) const
Get the indices of a specific item by a LabeledAxisAccessor
Definition LabeledAxis.cxx:186
Size storage_size(const LabeledAxisAccessor &name={}) const
Get the total storage size of this axis or the storage size of an item.
Definition LabeledAxis.cxx:166
std::set< LabeledAxisAccessor > variable_names(bool recursive=true) const
Get the variable names.
Definition LabeledAxis.cxx:284
size_t nsubaxis(bool recursive=false) const
Number of subaxes.
Definition LabeledAxis.cxx:128
friend std::ostream & operator<<(std::ostream &os, const LabeledAxis &axis)
Definition LabeledAxis.cxx:367
bool has_old_forces() const
Definition LabeledAxis.h:117
LabeledAxis()
Empty constructor.
Definition LabeledAxis.cxx:29
std::unordered_map< std::string, std::pair< Size, Size > > AxisLayout
Definition LabeledAxis.h:57
bool has_residual() const
Definition LabeledAxis.h:118
const std::map< std::string, std::shared_ptr< LabeledAxis > > & subaxes() const
Get the subaxes.
Definition LabeledAxis.h:175
const std::map< std::string, Size > & variables() const
Get the variables.
Definition LabeledAxis.h:169
void setup_layout()
Definition LabeledAxis.cxx:90
std::vector< std::pair< indexing::TensorIndex, indexing::TensorIndex > > common_indices(const LabeledAxis &other, bool recursive=true) const
Get the common indices of two LabeledAxis objects.
Definition LabeledAxis.cxx:208
const AxisLayout & layout() const
Get the layout.
Definition LabeledAxis.h:155
void clear()
Clear all internal data.
Definition LabeledAxis.cxx:81
bool has_subaxis(const LabeledAxisAccessor &s) const
Check the existence of a subaxis by its LabeledAxisAccessor.
Definition LabeledAxis.cxx:150
std::set< LabeledAxisAccessor > subaxis_names(bool recursive=false) const
Get subaxes' names.
Definition LabeledAxis.cxx:302
bool has_variable(const LabeledAxisAccessor &var) const
Does the variable of a given primitive type exist?
Definition LabeledAxis.h:136
bool has_parameters() const
Definition LabeledAxis.h:119
bool has_old_state() const
Definition LabeledAxis.h:115
bool has_state() const
Definition LabeledAxis.h:114
const LabeledAxis & subaxis(const LabeledAxisAccessor &name) const
Get a sub-axis.
Definition LabeledAxis.cxx:320
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:40
Definition CrossRef.cxx:30
bool operator==(const LabeledAxis &a, const LabeledAxis &b)
Definition LabeledAxis.cxx:393
bool operator!=(const LabeledAxis &a, const LabeledAxis &b)
Definition LabeledAxis.cxx:399
std::ostream & operator<<(std::ostream &os, const EnumSelection &es)
Definition EnumSelection.cxx:31
void neml_assert(bool assertion, Args &&... args)
Definition error.h:64
Definition LabeledAxis.h:61
bool operator()(const indexing::TensorIndex &a, const indexing::TensorIndex &b) const
Definition LabeledAxis.h:62