27#include <unordered_map>
30#include "neml2/misc/types.h"
32#include "neml2/tensors/LabeledAxisAccessor.h"
33#include "neml2/tensors/Scalar.h"
34#include "neml2/tensors/SR2.h"
57 typedef std::unordered_map<std::string, std::pair<Size, Size>>
AxisLayout;
62 bool operator()(
const indexing::TensorIndex & a,
const indexing::TensorIndex & b)
const
64 neml_assert(a.is_slice() && b.is_slice(),
"Comparator must be used on slices");
65 neml_assert(a.slice().step().expect_int() == 1 && b.slice().step().expect_int() == 1,
66 "Slices must have step == 1");
67 return a.slice().start().expect_int() < b.slice().start().expect_int();
82 if constexpr (std::is_same_v<LabeledAxis, T>)
86 _subaxes.emplace(
accessor.vec()[0], std::make_shared<LabeledAxis>());
135 template <
typename T>
161 std::vector<std::pair<indexing::TensorIndex, indexing::TensorIndex>>
165 std::vector<LabeledAxisAccessor>
169 const std::map<std::string, Size> &
variables()
const {
return _variables; }
175 const std::map<std::string, std::shared_ptr<LabeledAxis>> &
subaxes()
const {
return _subaxes; }
211 std::vector<Size> &
idx,
217 std::vector<Size> &
idxa,
218 std::vector<Size> &
idxb,
223 std::map<std::string, Size> _variables;
227 std::map<std::string, std::shared_ptr<LabeledAxis>> _subaxes;
246 bool _has_old_forces;
248 bool _has_parameters;
252std::ostream &
operator<<(std::ostream & os,
const LabeledAxis & axis);
254bool operator==(
const LabeledAxis & a,
const LabeledAxis & b);
256bool operator!=(
const LabeledAxis & a,
const LabeledAxis & b);
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:56
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:47
c10::SmallVector< std::string >::const_iterator const_iterator
Definition LabeledAxisAccessor.h:83
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
bool equals(const LabeledAxis &other) const
Check to see if two LabeledAxis objects are equivalent.
Definition LabeledAxis.cxx:341
size_t nvariable(bool recursive=true) const
Number of variables.
Definition LabeledAxis.cxx:122
std::vector< LabeledAxisAccessor > sort_by_assembly_order(const std::set< LabeledAxisAccessor > &) const
Sort a set of LabeledAxisAccessors by their indices.
Definition LabeledAxis.cxx:268
LabeledAxis & add(const LabeledAxisAccessor &accessor)
Add a variable or subaxis.
Definition LabeledAxis.h:79
bool has_item(const LabeledAxisAccessor &name) const
Does the item exist?
Definition LabeledAxis.h:129
bool has_forces() const
Definition LabeledAxis.h:116
indexing::TensorIndex indices(const LabeledAxisAccessor &accessor) const
Get the indices of a specific item by a LabeledAxisAccessor
Definition LabeledAxis.cxx:186
Size storage_size(const LabeledAxisAccessor &name={}) const
Get the total storage size of this axis or the storage size of an item.
Definition LabeledAxis.cxx:166
std::set< LabeledAxisAccessor > variable_names(bool recursive=true) const
Get the variable names.
Definition LabeledAxis.cxx:284
size_t nsubaxis(bool recursive=false) const
Number of subaxes.
Definition LabeledAxis.cxx:128
friend std::ostream & operator<<(std::ostream &os, const LabeledAxis &axis)
Definition LabeledAxis.cxx:367
bool has_old_forces() const
Definition LabeledAxis.h:117
LabeledAxis()
Empty constructor.
Definition LabeledAxis.cxx:29
std::unordered_map< std::string, std::pair< Size, Size > > AxisLayout
Definition LabeledAxis.h:57
bool has_residual() const
Definition LabeledAxis.h:118
const std::map< std::string, std::shared_ptr< LabeledAxis > > & subaxes() const
Get the subaxes.
Definition LabeledAxis.h:175
const std::map< std::string, Size > & variables() const
Get the variables.
Definition LabeledAxis.h:169
void setup_layout()
Definition LabeledAxis.cxx:90
std::vector< std::pair< indexing::TensorIndex, indexing::TensorIndex > > common_indices(const LabeledAxis &other, bool recursive=true) const
Get the common indices of two LabeledAxiss.
Definition LabeledAxis.cxx:208
const AxisLayout & layout() const
Get the layout.
Definition LabeledAxis.h:155
void clear()
Clear all internal data.
Definition LabeledAxis.cxx:81
bool has_subaxis(const LabeledAxisAccessor &s) const
Check the existence of a subaxis by its LabeledAxisAccessor.
Definition LabeledAxis.cxx:150
std::set< LabeledAxisAccessor > subaxis_names(bool recursive=false) const
Get subaxes' names.
Definition LabeledAxis.cxx:302
bool has_variable(const LabeledAxisAccessor &var) const
Does the variable of a given primitive type exist?
Definition LabeledAxis.h:136
bool has_parameters() const
Definition LabeledAxis.h:119
bool has_old_state() const
Definition LabeledAxis.h:115
bool has_state() const
Definition LabeledAxis.h:114
const LabeledAxis & subaxis(const LabeledAxisAccessor &name) const
Get a sub-axis.
Definition LabeledAxis.cxx:320
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:40
Definition CrossRef.cxx:30
bool operator==(const LabeledAxis &a, const LabeledAxis &b)
Definition LabeledAxis.cxx:393
bool operator!=(const LabeledAxis &a, const LabeledAxis &b)
Definition LabeledAxis.cxx:399
std::ostream & operator<<(std::ostream &os, const EnumSelection &es)
Definition EnumSelection.cxx:31
void neml_assert(bool assertion, Args &&... args)
Definition error.h:64
Definition LabeledAxis.h:61
bool operator()(const indexing::TensorIndex &a, const indexing::TensorIndex &b) const
Definition LabeledAxis.h:62