25#include "neml2/models/SumModel.h"
26#include "neml2/tensors/SSR4.h"
42 options.
doc() =
"Calculate linear combination of multiple " +
tensor_type +
43 " tensors as \\f$ u = c_i v_i \\f$ (Einstein summation assumed), where \\f$ c_i "
44 "\\f$ are the coefficients, and \\f$ v_i \\f$ are the variables to be summed.";
46 options.
set<std::vector<VariableName>>(
"from_var");
47 options.
set(
"from_var").doc() =
tensor_type +
" tensors to be summed";
50 options.
set(
"to_var").doc() =
"The sum";
52 options.
set<std::vector<CrossRef<Scalar>>>(
"coefficients") = {};
53 options.
set(
"coefficients").doc() =
"Weights associated with each variable";
61 _to(declare_output_variable<T>(
"to_var"))
63 for (
auto fv :
options.get<std::vector<VariableName>>(
"from_var"))
70 const auto coefs_in =
options.get<std::vector<CrossRef<Scalar>>>(
"coefficients");
71 const auto N =
_from.size();
73 _coefs = std::vector<const Scalar *>(
80 "Number of coefficients must be 0, 1, or N, where N is the number of 'from_var'.");
82 for (
size_t i = 0;
i <
N;
i++)
91 const auto N = _from.size();
95 auto sum = T::zeros(_to.batch_sizes(), options());
96 for (
size_t i = 0;
i <
N;
i++)
97 sum += (*_coefs[
i]) * (*_from[
i]);
102 for (
size_t i = 0;
i <
N;
i++)
103 _to.d(*_from[
i]) = (*_coefs[
i]) * T::identity_map(options());
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:44
The base class for all constitutive models.
Definition Model.h:53
const torch::TensorOptions & options() const
This model's tensor options.
Definition Model.h:116
static OptionSet expected_options()
Definition Model.cxx:33
A custom map-like data structure. The keys are strings, and the values can be of heterogeneous types.
Definition OptionSet.h:59
const std::string & doc() const
A readonly reference to the option set's docstring.
Definition OptionSet.h:91
T & set(const std::string &)
Definition OptionSet.h:436
const T & declare_parameter(const std::string &name, const T &rawval)
Declare a parameter.
Definition ParameterStore.h:145
The (logical) scalar.
Definition Scalar.h:38
std::vector< const Scalar * > _coefs
Scaling coefficient for each term.
Definition SumModel.h:49
std::vector< const Variable< T > * > _from
The input variables (to be summed)
Definition SumModel.h:46
static OptionSet expected_options()
Definition SumModel.cxx:35
SumModel(const OptionSet &options)
Definition SumModel.cxx:59
void set_value(bool out, bool dout_din, bool d2out_din2) override
The map between input -> output, and optionally its derivatives.
Definition SumModel.cxx:89
std::string stringify(const T &t)
Definition utils.h:302
std::string demangle(const char *name)
Definition parser_utils.cxx:46
Definition CrossRef.cxx:32
const torch::TensorOptions default_tensor_options()
Definition types.cxx:30
void neml_assert(bool assertion, Args &&... args)
Definition error.h:73