25#include "neml2/models/LinearInterpolation.h"
27using namespace torch::indexing;
// Register one LinearInterpolation alias per tensor type T. The C++ name is token-pasted as
// T##LinearInterpolation, and the registry string is the stringized T followed by "LinearInterpolation"
// (e.g. T = Scalar registers "ScalarLinearInterpolation").
// NOTE(review): assumes an alias type named <T>LinearInterpolation is declared elsewhere (presumably the
// header), and that FOR_ALL_FIXEDDIMTENSOR expands this macro once per fixed-dimension tensor type — confirm.
31#define LINEARINTERPOLATION_REGISTER(T) \
32 register_NEML2_object_alias(T##LinearInterpolation, #T "LinearInterpolation")
33FOR_ALL_FIXEDDIMTENSOR(LINEARINTERPOLATION_REGISTER);
40 options.
doc() +=
" This object performs a _linear interpolation_.";
48 utils::broadcast_sizes(
this->_X.batch_sizes().slice(0,
this->_X.batch_dim() - 1),
49 this->_Y.batch_sizes().slice(0,
this->_Y.batch_dim() - 1))),
54 _Y0(this->
template declare_buffer<T>(
"Y0", this->_Y.batch_index({Ellipsis, Slice(None, -1)}))),
55 _slope(this->
template declare_buffer<T>(
"S",
56 math::diff(this->_Y, 1, this->_Y.batch_dim() - 1) /
57 math::diff(this->_X, 1, this->_X.batch_dim() - 1)))
65 const auto x =
Scalar(this->_x);
66 const auto loc = torch::logical_and(torch::gt(
x.batch_unsqueeze(-1), _X0),
67 torch::le(
x.batch_unsqueeze(-1), _X1));
78 this->_p.d(this->_x) =
si;
// Explicitly instantiate LinearInterpolation<T> here. This emits the template member definitions from this
// translation unit for each type T that FOR_ALL_FIXEDDIMTENSOR supplies (presumably every fixed-dimension
// tensor type — confirm against that macro's definition).
86#define LINEARINTERPOLATION_INSTANTIATE_FIXEDDIMTENSOR(T) template class LinearInterpolation<T>
87FOR_ALL_FIXEDDIMTENSOR(LINEARINTERPOLATION_INSTANTIATE_FIXEDDIMTENSOR);
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
The base class for interpolated nonlinear parameters.
Definition Interpolation.h:49
static OptionSet expected_options()
Definition Interpolation.h:68
Linearly interpolate the parameter along a single axis.
Definition LinearInterpolation.h:76
LinearInterpolation(const OptionSet &options)
Definition LinearInterpolation.cxx:45
static OptionSet expected_options()
Definition LinearInterpolation.cxx:37
A custom map-like data structure. The keys are strings, and the values can be of different (heterogeneous) types.
Definition OptionSet.h:59
const std::string & doc() const
A read-only reference to the option set's docstring.
Definition OptionSet.h:91
The (logical) scalar.
Definition Scalar.h:38
Definition CrossRef.cxx:32