NEML2 1.4.0
Loading...
Searching...
No Matches
LinearInterpolation.cxx
// Copyright 2023, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
24
25#include "neml2/models/LinearInterpolation.h"
26
27using namespace torch::indexing;
28
29namespace neml2
30{
31#define LINEARINTERPOLATION_REGISTER(T) \
32 register_NEML2_object_alias(T##LinearInterpolation, #T "LinearInterpolation")
33FOR_ALL_FIXEDDIMTENSOR(LINEARINTERPOLATION_REGISTER);
34
35template <typename T>
38{
40 options.doc() += " This object performs a _linear interpolation_.";
41 return options;
42}
43
44template <typename T>
46 : Interpolation<T>(options),
47 _interp_batch_sizes(
48 utils::broadcast_sizes(this->_X.batch_sizes().slice(0, this->_X.batch_dim() - 1),
49 this->_Y.batch_sizes().slice(0, this->_Y.batch_dim() - 1))),
50 _X0(this->template declare_buffer<Scalar>("X0",
51 this->_X.batch_index({Ellipsis, Slice(None, -1)}))),
52 _X1(this->template declare_buffer<Scalar>("X1",
53 this->_X.batch_index({Ellipsis, Slice(1, None)}))),
54 _Y0(this->template declare_buffer<T>("Y0", this->_Y.batch_index({Ellipsis, Slice(None, -1)}))),
55 _slope(this->template declare_buffer<T>("S",
56 math::diff(this->_Y, 1, this->_Y.batch_dim() - 1) /
57 math::diff(this->_X, 1, this->_X.batch_dim() - 1)))
58{
59}
60
61template <typename T>
62void
64{
65 const auto x = Scalar(this->_x);
66 const auto loc = torch::logical_and(torch::gt(x.batch_unsqueeze(-1), _X0),
67 torch::le(x.batch_unsqueeze(-1), _X1));
68 const auto si = mask<T>(_slope, loc);
69
70 if (out)
71 {
72 const auto X0i = mask<Scalar>(_X0, loc);
73 const auto Y0i = mask<T>(_Y0, loc);
74 this->_p = Y0i + si * (x - X0i);
75 }
76
77 if (dout_din)
78 this->_p.d(this->_x) = si;
79
80 if (d2out_din2)
81 {
82 // zero
83 }
84}
85
86#define LINEARINTERPOLATION_INSTANTIATE_FIXEDDIMTENSOR(T) template class LinearInterpolation<T>
87FOR_ALL_FIXEDDIMTENSOR(LINEARINTERPOLATION_INSTANTIATE_FIXEDDIMTENSOR);
88} // namespace neml2
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
The base class for interpolated nonlinear parameter.
Definition Interpolation.h:49
static OptionSet expected_options()
Definition Interpolation.h:68
Linearly interpolate the parameter along a single axis.
Definition LinearInterpolation.h:76
LinearInterpolation(const OptionSet &options)
Definition LinearInterpolation.cxx:45
static OptionSet expected_options()
Definition LinearInterpolation.cxx:37
A custom map-like data structure. The keys are strings, and the values can be nonhomogeneously typed.
Definition OptionSet.h:59
const std::string & doc() const
A readonly reference to the option set's docstring.
Definition OptionSet.h:91
The (logical) scalar.
Definition Scalar.h:38
Definition CrossRef.cxx:32