27#include <unordered_map>
30#include "neml2/misc/types.h"
32#include "neml2/tensors/LabeledAxisAccessor.h"
33#include "neml2/tensors/Scalar.h"
34#include "neml2/tensors/SR2.h"
57 typedef std::unordered_map<std::string, std::pair<TorchSize, TorchSize>>
AxisLayout;
70 if constexpr (std::is_same_v<LabeledAxis, T>)
74 _subaxes.emplace(
accessor.vec()[0], std::make_shared<LabeledAxis>());
116 size_t nsubaxis()
const {
return _subaxes.size(); }
125 template <
typename T>
162 const std::map<std::string, TorchSize> &
variables()
const {
return _variables; }
165 const std::map<std::string, std::shared_ptr<LabeledAxis>> &
subaxes()
const {
return _subaxes; }
182 void to_dot(std::ostream & os,
184 std::string name =
"",
185 bool subgraph =
false,
186 bool node_handle =
false)
const;
191 const std::vector<std::string>::const_iterator & cur,
192 const std::vector<std::string>::const_iterator & end)
const;
195 std::vector<std::string>
subaxes,
196 std::vector<LabeledAxisAccessor> & merged_vars);
200 const std::vector<std::string>::const_iterator & end)
const;
205 const std::vector<std::string>::const_iterator & cur,
206 const std::vector<std::string>::const_iterator & end)
const;
212 std::vector<TorchSize> & idx,
218 std::vector<TorchSize> & idxa,
219 std::vector<TorchSize> & idxb,
224 LabeledAxisAccessor cur,
226 const LabeledAxisAccessor &
subaxis)
const;
229 std::map<std::string, TorchSize> _variables;
233 std::map<std::string, std::shared_ptr<LabeledAxis>> _subaxes;
248std::ostream &
operator<<(std::ostream & os,
const LabeledAxis & info);
250bool operator==(
const LabeledAxis & a,
const LabeledAxis & b);
252bool operator!=(
const LabeledAxis & a,
const LabeledAxis & b);
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:44
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
size_t nvariable() const
Number of variables.
Definition LabeledAxis.h:113
bool equals(const LabeledAxis &other) const
Check to see if two LabeledAxis objects are equivalent.
Definition LabeledAxis.cxx:383
std::vector< LabeledAxisAccessor > merge(LabeledAxis &other)
Merge with another LabeledAxis.
Definition LabeledAxis.cxx:122
LabeledAxis & rename(const std::string &original, const std::string &rename)
Change the label of an item.
Definition LabeledAxis.cxx:68
std::set< LabeledAxisAccessor > variable_accessors(bool recursive=false, const LabeledAxisAccessor &subaxis={}) const
Get the variable accessors.
Definition LabeledAxis.cxx:333
void to_dot(std::ostream &os, int &id, std::string name="", bool subgraph=false, bool node_handle=false) const
Write this object in dot format.
Definition LabeledAxis.cxx:431
LabeledAxis & add(const LabeledAxisAccessor &accessor)
Add a variable or subaxis.
Definition LabeledAxis.h:67
size_t nsubaxis() const
Number of subaxes.
Definition LabeledAxis.h:116
bool has_item(const LabeledAxisAccessor &name) const
Does the item exist?
Definition LabeledAxis.h:119
LabeledAxis & clear()
Clear everything.
Definition LabeledAxis.cxx:111
const std::map< std::string, TorchSize > & variables() const
Get the variables.
Definition LabeledAxis.h:162
LabeledAxis & remove(const std::string &name)
Remove an item.
Definition LabeledAxis.cxx:94
const LabeledAxis & subaxis(const std::string &name) const
Get a sub-axis.
Definition LabeledAxis.cxx:365
std::unordered_map< std::string, std::pair< TorchSize, TorchSize > > AxisLayout
Definition LabeledAxis.h:57
LabeledAxis()
Empty constructor.
Definition LabeledAxis.cxx:29
std::vector< std::string > item_names() const
Get the item names.
Definition LabeledAxis.cxx:324
TorchIndex indices(const LabeledAxisAccessor &accessor) const
Get the indices of a specific item by a LabeledAxisAccessor.
Definition LabeledAxis.cxx:240
const std::map< std::string, std::shared_ptr< LabeledAxis > > & subaxes() const
Get the subaxes.
Definition LabeledAxis.h:165
void setup_layout()
Definition LabeledAxis.cxx:158
friend std::ostream & operator<<(std::ostream &os, const LabeledAxis &info)
Definition LabeledAxis.cxx:406
TorchIndex indices(const LabeledAxis &other, bool recursive=true, bool inclusive=true) const
Get the indices using another LabeledAxis.
const AxisLayout & layout() const
Get the layout.
Definition LabeledAxis.h:147
TorchSize storage_size() const
Get the (total) storage size of this axis.
Definition LabeledAxis.h:142
bool has_subaxis(const LabeledAxisAccessor &s) const
Check the existence of a subaxis by its LabeledAxisAccessor.
Definition LabeledAxis.cxx:200
bool has_variable(const LabeledAxisAccessor &var) const
Does the variable of a given primitive type exist?
Definition LabeledAxis.h:126
size_t nitem() const
Number of items.
Definition LabeledAxis.h:110
std::vector< std::pair< TorchIndex, TorchIndex > > common_indices(const LabeledAxis &other, bool recursive=true) const
Get the common indices of two LabeledAxis objects.
Definition LabeledAxis.cxx:262
TorchSize storage_size(TorchShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:32
Definition CrossRef.cxx:32
int64_t TorchSize
Definition types.h:35
std::ostream & operator<<(std::ostream &os, const OptionCollection &p)
Definition OptionCollection.cxx:37
bool operator==(const LabeledAxis &a, const LabeledAxis &b)
Definition LabeledAxis.cxx:461
bool operator!=(const LabeledAxis &a, const LabeledAxis &b)
Definition LabeledAxis.cxx:467
at::indexing::TensorIndex TorchIndex
Definition types.h:38