27#include "neml2/models/Interpolation.h"
96 template <
typename T2>
97 T2 mask(
const T2 &
in,
const torch::Tensor &
m)
const;
112template <
typename T2>
122#define LINEARINTERPOLATION_TYPEDEF_FIXEDDIMTENSOR(T) \
123 typedef LinearInterpolation<T> T##LinearInterpolation
124FOR_ALL_FIXEDDIMTENSOR(LINEARINTERPOLATION_TYPEDEF_FIXEDDIMTENSOR);
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
The base class for interpolated nonlinear parameter.
Definition Interpolation.h:49
Linearly interpolate the parameter along a single axis.
Definition LinearInterpolation.h:76
LinearInterpolation(const OptionSet &options)
Definition LinearInterpolation.cxx:45
static OptionSet expected_options()
Definition LinearInterpolation.cxx:37
void set_value(bool out, bool dout_din, bool d2out_din2) override
The map between input -> output, and optionally its derivatives.
Definition LinearInterpolation.cxx:63
const torch::TensorOptions & options() const
This model's tensor options.
Definition Model.h:116
A custom map-like data structure. The keys are strings, and the values can be nonhomogeneously typed.
Definition OptionSet.h:59
The (logical) scalar.
Definition Scalar.h:38
TorchShape add_shapes(S &&... shape)
Definition utils.h:294
Definition CrossRef.cxx:32
std::vector< TorchSize > TorchShape
Definition types.h:36