NEML2 1.4.0
Loading...
Searching...
No Matches
ParameterStore.cxx
1// Copyright 2023, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#include "neml2/models/ParameterStore.h"
26#include "neml2/models/NonlinearParameter.h"
27#include "neml2/tensors/macros.h"
28#include "neml2/tensors/Variable.h"
29
30namespace neml2
31{
33 : _object(object),
34 _options(options)
35{
36}
37
40{
41 neml_assert(_object->host() == _object,
42 "named_parameters() should only be called on the host model.");
43 return _param_values;
44}
45
46void
47ParameterStore::send_parameters_to(const torch::TensorOptions & options)
48{
49 neml_assert(_object->host() == _object, "This method should only be called on the host model.");
50
51 for (auto && [name, param] : _param_values)
52 param.to(options);
53}
54
55const VariableBase *
56ParameterStore::nl_param(const std::string & name) const
57{
58 return _nl_params.count(name) ? _nl_params.at(name) : nullptr;
59}
60
61template <typename T, typename>
62const T &
63ParameterStore::declare_parameter(const std::string & name, const std::string & input_option_name)
64{
65 if (_options.contains<T>(input_option_name))
66 return declare_parameter(name, _options.get<T>(input_option_name));
67 else if (_options.contains<CrossRef<T>>(input_option_name))
68 {
69 try
70 {
71 return declare_parameter(name, T(_options.get<CrossRef<T>>(input_option_name)));
72 }
73 catch (const NEMLException & e1)
74 {
75 try
76 {
77 // Handle the case of *nonlinear* parameter.
78 // Note that nonlinear parameter should only exist inside a Model.
79 auto model = dynamic_cast<Model *>(this);
81 "Trying to declare a parameter named ",
82 name,
83 ". It is not a plain tensor value nor a cross-referenced parameter "
84 "value. Hence I am guessing you are declaring a *nonlinear* parameter. "
85 "However, nonlinear parameter should only be declared by a model, and this "
86 "object does not appear to be a model.");
87
88 auto & nl_param = Factory::get_object<NonlinearParameter<T>>(
89 "Models", _options.get<CrossRef<T>>(input_option_name).raw());
90 _nl_params[name] = &nl_param.param();
91 return nl_param.param().value();
92 }
93 catch (const NEMLException & e2)
94 {
95 std::cerr << e1.what() << std::endl;
96 std::cerr << e2.what() << std::endl;
97 }
98 }
99 }
100
101 throw NEMLException(
102 "Trying to register parameter named " + name + " from input option named " +
103 input_option_name + " of type " + utils::demangle(typeid(T).name()) +
104 ". Make sure you provided the correct parameter name, option name, and parameter type. Note "
105 "that the parameter type can either be a plain type, a cross-reference, or a nonlinear "
106 "parameter.");
107}
108
109#define PARAMETERSTORE_INTANTIATE_FIXEDDIMTENSOR(T) \
110 template const T & ParameterStore::declare_parameter<T>(const std::string &, const std::string &)
111FOR_ALL_FIXEDDIMTENSOR(PARAMETERSTORE_INTANTIATE_FIXEDDIMTENSOR);
112} // namespace neml2
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
The base class for all constitutive models.
Definition Model.h:53
The base class of all "manufacturable" objects in the NEML2 library.
Definition NEML2Object.h:39
const T * host() const
Get a readonly pointer to the host.
Definition NEML2Object.h:91
Definition error.h:33
A custom map-like data structure. The keys are strings, and the values can be of heterogeneous types.
Definition OptionSet.h:59
virtual void send_parameters_to(const torch::TensorOptions &options)
Send parameters to options.
Definition ParameterStore.cxx:47
ParameterStore(const OptionSet &options, NEML2Object *object)
Definition ParameterStore.cxx:32
const Storage< std::string, TensorValueBase > & named_parameters() const
Definition ParameterStore.h:44
const T & declare_parameter(const std::string &name, const T &rawval)
Declare a parameter.
Definition ParameterStore.h:145
const VariableBase * nl_param(const std::string &) const
Query the existence of a nonlinear parameter.
Definition ParameterStore.cxx:56
Definition Variable.h:41
std::string demangle(const char *name)
Definition parser_utils.cxx:46
Definition CrossRef.cxx:32
void neml_assert(bool assertion, Args &&... args)
Definition error.h:73