25#include "neml2/tensors/LabeledAxis.h"
35 : _variables(
other._variables),
36 _subaxes(
other._subaxes),
37 _layout(
other._layout),
38 _offset(
other._offset)
52 const std::vector<std::string>::const_iterator &
cur,
53 const std::vector<std::string>::const_iterator & end)
const
58 axis._variables.emplace(*
cur,
sz);
72 if (
var != _variables.end())
74 auto sz =
var->second;
75 _variables.erase(
var);
86 _subaxes.emplace(
rename, axis);
97 auto count = _variables.erase(name);
102 count += _subaxes.erase(name);
121std::vector<LabeledAxisAccessor>
131 std::vector<std::string> subaxes,
135 for (
const auto & [name,
sz] :
other._variables)
138 _variables.emplace(name,
sz);
145 for (
auto & [name,
subaxis] : other._subaxes)
147 auto found = _subaxes.find(name);
148 if (found == _subaxes.end())
149 _subaxes.emplace(name, std::make_shared<LabeledAxis>());
152 new_subaxes.push_back(name);
153 _subaxes[name]->merge(*
subaxis, new_subaxes, merged_vars);
164 for (
auto & [name,
sz] : _variables)
166 std::pair<TorchSize, TorchSize>
range = {_offset, _offset +
sz};
167 _layout.emplace(name,
range);
172 for (
auto & [name, axis] : _subaxes)
176 std::pair<TorchSize, TorchSize>
range = {_offset, _offset + axis->
storage_size()};
177 _layout.emplace(name,
range);
188 if (
var.vec().size() > 1)
196 return _variables.count(
var.vec()[0]);
205 if (
s.vec().size() > 1)
208 return subaxis(
s.vec()[0]).has_subaxis(
s.slice(1));
213 return _subaxes.count(
s.vec()[0]);
224 const std::vector<std::string>::const_iterator & end)
const
228 if (_variables.count(*
cur))
229 return _variables.at(*
cur);
230 else if (_subaxes.count(*
cur))
231 return _subaxes.at(*cur)->storage_size();
233 neml_assert_dbg(
false,
"Trying to find the storage size of a non-existent item named ", *
cur);
236 return subaxis(*cur).storage_size(cur + 1, end);
243 return torch::indexing::Slice();
250 const std::vector<std::string>::const_iterator &
cur,
251 const std::vector<std::string>::const_iterator & end)
const
261std::vector<std::pair<TorchIndex, TorchIndex>>
264 using namespace torch::indexing;
266 std::vector<std::pair<TorchIndex, TorchIndex>>
indices;
267 std::vector<TorchSize>
idxa;
268 std::vector<TorchSize>
idxb;
277 while (
j <
idxa.size() - 1)
296 std::vector<TorchSize> &
idxa,
297 std::vector<TorchSize> &
idxb,
301 for (
const auto & [name,
sz] : _variables)
313 for (
const auto & [name, axis] : _subaxes)
319 offseta + _layout.at(name).first,
320 offsetb + other._layout.at(name).first);
323std::vector<std::string>
326 std::vector<std::string>
names;
327 for (
const auto &
item : _layout)
332std::set<LabeledAxisAccessor>
346 for (
auto &
var : _variables)
349 var_accessor = var_accessor.on(cur);
351 accessors.insert(var_accessor);
353 accessors.insert(var_accessor);
357 for (
auto & [name, axis] : _subaxes)
359 auto next = cur.
append(name);
368 _subaxes.count(name),
"In LabeledAxis::subaxis, no subaxis matches given name ", name);
370 return *_subaxes.at(name);
377 _subaxes.count(name),
"In LabeledAxis::subaxis, no subaxis matches given name ", name);
379 return *_subaxes.at(name);
386 if (_offset !=
other._offset)
392 if (_variables !=
other._variables)
396 for (
auto & [name, axis] : _subaxes)
397 if (
other._subaxes.count(name) == 0)
399 else if (*
other._subaxes.at(name) != *axis)
410 std::map<std::string, TorchIndex>
vars;
423 if (std::next(
var) !=
vars.end())
436 os <<
"cluster_" <<
id++ <<
" ";
439 os <<
"bgcolor = lightgrey\n";
443 os <<
"\"" <<
axis_name <<
"\" [label = \"\", style = invis]\n";
446 for (
const auto & [name,
sz] : _variables)
449 os <<
"[style = filled, color = white, shape = Square, ";
450 os <<
"label = \"" << name +
" [" <<
sz <<
"]\"]\n";
454 for (
const auto & [name,
subaxis] : _subaxes)
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:44
LabeledAxisAccessor append(const LabeledAxisAccessor &axis) const
Add a new item.
Definition LabeledAxisAccessor.cxx:72
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
bool equals(const LabeledAxis &other) const
Check to see if two LabeledAxis objects are equivalent.
Definition LabeledAxis.cxx:383
std::vector< LabeledAxisAccessor > merge(LabeledAxis &other)
Merge with another LabeledAxis.
Definition LabeledAxis.cxx:122
LabeledAxis & rename(const std::string &original, const std::string &rename)
Change the label of an item.
Definition LabeledAxis.cxx:68
std::set< LabeledAxisAccessor > variable_accessors(bool recursive=false, const LabeledAxisAccessor &subaxis={}) const
Get the variable accessors.
Definition LabeledAxis.cxx:333
void to_dot(std::ostream &os, int &id, std::string name="", bool subgraph=false, bool node_handle=false) const
Write this object in dot format.
Definition LabeledAxis.cxx:431
LabeledAxis & add(const LabeledAxisAccessor &accessor)
Add a variable or subaxis.
Definition LabeledAxis.h:67
LabeledAxis & clear()
Clear everything.
Definition LabeledAxis.cxx:111
LabeledAxis & remove(const std::string &name)
Remove an item.
Definition LabeledAxis.cxx:94
const LabeledAxis & subaxis(const std::string &name) const
Get a sub-axis.
Definition LabeledAxis.cxx:365
LabeledAxis()
Empty constructor.
Definition LabeledAxis.cxx:29
std::vector< std::string > item_names() const
Get the item names.
Definition LabeledAxis.cxx:324
TorchIndex indices(const LabeledAxisAccessor &accessor) const
Get the indices of a specific item by a LabeledAxisAccessor
Definition LabeledAxis.cxx:240
const std::map< std::string, std::shared_ptr< LabeledAxis > > & subaxes() const
Get the subaxes.
Definition LabeledAxis.h:165
void setup_layout()
Definition LabeledAxis.cxx:158
TorchSize storage_size() const
Get the (total) storage size of this axis.
Definition LabeledAxis.h:142
bool has_subaxis(const LabeledAxisAccessor &s) const
Check the existence of a subaxis by its LabeledAxisAccessor.
Definition LabeledAxis.cxx:200
bool has_variable(const LabeledAxisAccessor &var) const
Does the variable of a given primitive type exist?
Definition LabeledAxis.h:126
std::vector< std::pair< TorchIndex, TorchIndex > > common_indices(const LabeledAxis &other, bool recursive=true) const
Get the common indices of two LabeledAxiss.
Definition LabeledAxis.cxx:262
std::string stringify(const T &t)
Definition utils.h:302
Definition CrossRef.cxx:32
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:85
int64_t TorchSize
Definition types.h:35
std::ostream & operator<<(std::ostream &os, const OptionCollection &p)
Definition OptionCollection.cxx:37
bool operator==(const LabeledAxis &a, const LabeledAxis &b)
Definition LabeledAxis.cxx:461
bool operator!=(const LabeledAxis &a, const LabeledAxis &b)
Definition LabeledAxis.cxx:467
at::indexing::TensorIndex TorchIndex
Definition types.h:38