NEML2 1.4.0
Interface for an object that can store parameters.
#include <ParameterStore.h>
Public Member Functions
  ParameterStore (const OptionSet &options, NEML2Object *object)
  const Storage<std::string, TensorValueBase> & named_parameters () const
  Storage<std::string, TensorValueBase> & named_parameters ()
  void set_parameter (const std::string &, const Tensor &)
      Set the value for a parameter.
  void set_parameters (const std::map<std::string, Tensor> &)
      Set values for parameters.
  TensorValueBase & get_parameter (const std::string &name)
      Get a writable reference to a parameter.
  const TensorValueBase & get_parameter (const std::string &name) const
      Get a read-only reference to a parameter.
  bool has_nl_param () const
      Whether this parameter store has any nonlinear parameters.
  const VariableBase * nl_param (const std::string &) const
      Query the existence of a nonlinear parameter.
  virtual std::map<std::string, const VariableBase *> named_nonlinear_parameters (bool recursive = false) const
      Get all nonlinear parameters.
  virtual std::map<std::string, Model *> named_nonlinear_parameter_models (bool recursive = false) const
      Get the models that evaluate all nonlinear parameters.

Protected Member Functions
  virtual void send_parameters_to (const torch::TensorOptions &options)
      Send the parameters to the given tensor options.
  template<typename T , typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
  const T & declare_parameter (const std::string &name, const T &rawval)
      Declare a parameter.
  template<typename T , typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
  const T & declare_parameter (const std::string &name, const std::string &input_option_name, bool allow_nonlinear = false)
      Declare a parameter.

Protected Attributes
  std::map<std::string, const VariableBase *> _nl_params
      Map from nonlinear parameter names to their corresponding variable views.
  std::map<std::string, Model *> _nl_param_models
      Map from nonlinear parameter names to the models that evaluate them.
Constructor & Destructor Documentation
ParameterStore (const OptionSet &options, NEML2Object *object)
Member Function Documentation
template<typename T , typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
const T & declare_parameter (const std::string &name, const std::string &input_option_name, bool allow_nonlinear = false)
protected
Declare a parameter.
Note that all parameters are stored in the host (the object exposed to users). An object may be used multiple times in the host, so the same parameter may be declared multiple times. This is allowed: only the first call to declare_parameter constructs the parameter value, and subsequent calls only return a reference to the existing parameter.
Template Parameters
  T  Parameter type. See Statically shaped tensors for supported types.
Parameters
  name  Name of the model parameter.
  input_option_name  Name of the input option that defines the value of the model parameter.
  allow_nonlinear  Whether to allow coupling with a nonlinear parameter.
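The interplay between `input_option_name` and `allow_nonlinear` can be pictured with a small, self-contained C++ model. This is a hypothetical sketch, not NEML2's implementation: `OptionValue`, `MiniParameterStore`, and the convention that a string-valued option names the variable supplying a nonlinear parameter are all assumptions made for illustration, and `double` stands in for the tensor type `T`.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <variant>

// Hypothetical option value: a constant, or the name of a variable that supplies the value.
using OptionValue = std::variant<double, std::string>;

// Hypothetical stand-in for a parameter store (NOT the NEML2 implementation).
// double replaces the tensor type T to keep the sketch self-contained.
class MiniParameterStore
{
public:
  explicit MiniParameterStore(std::map<std::string, OptionValue> options)
    : _options(std::move(options))
  {
  }

  // Declare parameter `name` from input option `input_option_name`.
  // Assumption: a string-valued option requests coupling to a nonlinear parameter.
  double declare_parameter(const std::string & name,
                           const std::string & input_option_name,
                           bool allow_nonlinear = false)
  {
    const OptionValue & opt = _options.at(input_option_name);
    if (const auto * constant = std::get_if<double>(&opt))
      return _params[name] = *constant; // plain (constant) parameter

    if (!allow_nonlinear)
      throw std::runtime_error("parameter '" + name + "' may not be nonlinear");

    _nl_params[name] = std::get<std::string>(opt); // record the providing variable
    return _params[name] = 0.0;                    // placeholder until that variable is evaluated
  }

  bool has_nl_param() const { return !_nl_params.empty(); }

private:
  std::map<std::string, OptionValue> _options;
  std::map<std::string, double> _params;
  std::map<std::string, std::string> _nl_params;
};
```

In this sketch, a request to couple a parameter whose declaration did not allow it is rejected with an exception; how NEML2 itself reports that case is not described in this reference.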
template<typename T , typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
const T & declare_parameter (const std::string &name, const T &rawval)
protected
Declare a parameter.
Note that all parameters are stored in the host (the object exposed to users). An object may be used multiple times in the host, so the same parameter may be declared multiple times. This is allowed: only the first call to declare_parameter constructs the parameter value, and subsequent calls only return a reference to the existing parameter.
Template Parameters
  T  Parameter type. See Statically shaped tensors for supported types.
Parameters
  name  Parameter name.
  rawval  Parameter value.
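The rule that only the first declaration constructs the value, while later declarations return the existing one, can be sketched with `std::map::try_emplace`, which inserts only when the key is absent. `HostParameters` is a hypothetical class and `double` stands in for the tensor type; this illustrates the described semantics and is not NEML2's code.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>

// Hypothetical host-side storage illustrating "first declaration wins".
class HostParameters
{
public:
  // The first call for `name` stores `rawval`; later calls leave the stored value
  // untouched and return a reference to it. std::map never relocates its elements,
  // so the returned reference stays valid as more parameters are declared.
  const double & declare_parameter(const std::string & name, double rawval)
  {
    return _params.try_emplace(name, rawval).first->second;
  }

  std::size_t size() const { return _params.size(); }

private:
  std::map<std::string, double> _params;
};
```

Because `try_emplace` does not touch an existing entry, a second `declare_parameter("E", 999.0)` returns the same reference as the first call and the value remains the one given first.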
TensorValueBase & get_parameter (const std::string &name)
Get a writable reference to a parameter.
const TensorValueBase & get_parameter (const std::string &name) const
Get a read-only reference to a parameter.
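A const/non-const overload pair such as `get_parameter` is often written so that the read-only version delegates to the writable one. The sketch below shows that common C++ idiom; `ScalarValue` (standing in for `TensorValueBase`), `Store`, and throwing `std::out_of_range` for unknown names are assumptions, not documented NEML2 behavior.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for TensorValueBase: a single scalar.
struct ScalarValue
{
  double value = 0.0;
};

class Store
{
public:
  void add(const std::string & name, double v) { _params[name].value = v; }

  // Writable access: callers may modify the stored parameter in place.
  ScalarValue & get_parameter(const std::string & name)
  {
    auto it = _params.find(name);
    if (it == _params.end())
      throw std::out_of_range("parameter '" + name + "' does not exist");
    return it->second;
  }

  // Read-only access: reuses the lookup above. The const_cast is safe because
  // the non-const overload never modifies the store itself.
  const ScalarValue & get_parameter(const std::string & name) const
  {
    return const_cast<Store *>(this)->get_parameter(name);
  }

private:
  std::map<std::string, ScalarValue> _params;
};
```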
bool has_nl_param () const
inline
Whether this parameter store has any nonlinear parameters.
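std::map<std::string, Model *> named_nonlinear_parameter_models (bool recursive = false) const
virtual
Get the models that evaluate all nonlinear parameters.
Reimplemented in ComposedModel.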
std::map<std::string, const VariableBase *> named_nonlinear_parameters (bool recursive = false) const
virtual
Get all nonlinear parameters.
Reimplemented in ComposedModel.
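The `recursive` flag, together with the note that `ComposedModel` reimplements these functions, suggests a traversal that also gathers entries from sub-models. The following hypothetical sketch shows one such traversal; `Node`, the string-valued map entries, and the `"<submodel>.<parameter>"` naming scheme are assumptions for illustration only.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical sketch, not NEML2 code: a node owns its own nonlinear parameters
// and, like a composed model, may hold sub-models.
struct Node
{
  std::string name;
  std::map<std::string, std::string> nl_params; // parameter name -> providing variable
  std::vector<const Node *> submodels;

  // Non-recursive: only this node's own entries.
  // Recursive: also collect sub-model entries, prefixed with "<submodel>." (assumed scheme).
  std::map<std::string, std::string> named_nonlinear_parameters(bool recursive = false) const
  {
    std::map<std::string, std::string> result = nl_params;
    if (!recursive)
      return result;
    for (const Node * sub : submodels)
      for (const auto & [pname, var] : sub->named_nonlinear_parameters(true))
        result.emplace(sub->name + "." + pname, var);
    return result;
  }
};
```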
Storage<std::string, TensorValueBase> & named_parameters ()
const Storage<std::string, TensorValueBase> & named_parameters () const
inline
const VariableBase * nl_param (const std::string &name) const
Query the existence of a nonlinear parameter.
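`has_nl_param` and `nl_param` can be understood as queries over a name-to-variable map such as `_nl_params`. In this minimal sketch, `Variable` stands in for `VariableBase`, `NonlinearParams` and its `couple` method are hypothetical, and returning `nullptr` for a parameter that is not nonlinear is an assumed convention.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for VariableBase.
struct Variable
{
  std::string name;
};

class NonlinearParams
{
public:
  // Hypothetical helper: record that `param` is provided by variable `var`.
  void couple(const std::string & param, const Variable * var) { _nl_params[param] = var; }

  // True if at least one nonlinear parameter has been recorded.
  bool has_nl_param() const { return !_nl_params.empty(); }

  // Returns the coupled variable, or nullptr if `name` is not a nonlinear parameter
  // (the nullptr convention is an assumption of this sketch).
  const Variable * nl_param(const std::string & name) const
  {
    auto it = _nl_params.find(name);
    return it == _nl_params.end() ? nullptr : it->second;
  }

private:
  std::map<std::string, const Variable *> _nl_params;
};
```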
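void send_parameters_to (const torch::TensorOptions &options)
virtual
Send the parameters to the given tensor options (for example, a different device or dtype).
Parameters
  options  The target options.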
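void set_parameter (const std::string &, const Tensor &)
Set the value for a parameter.
void set_parameters (const std::map<std::string, Tensor> &)
Set values for parameters.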
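Setting several parameters at once naturally reduces to repeated single-parameter updates. The hypothetical `SettableStore` below uses `double` in place of `Tensor` and rejects unknown names via `std::map::at`; both choices are assumptions of this sketch, not statements about how NEML2 implements `set_parameters`.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical sketch with double in place of Tensor.
class SettableStore
{
public:
  explicit SettableStore(std::map<std::string, double> params) : _params(std::move(params)) {}

  // Set the value of one existing parameter; std::map::at throws
  // std::out_of_range for unknown names.
  void set_parameter(const std::string & name, double value) { _params.at(name) = value; }

  // Set several parameters by delegating to set_parameter for each entry.
  void set_parameters(const std::map<std::string, double> & values)
  {
    for (const auto & [name, value] : values)
      set_parameter(name, value);
  }

  double get(const std::string & name) const { return _params.at(name); }

private:
  std::map<std::string, double> _params;
};
```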
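Member Data Documentation
std::map<std::string, Model *> _nl_param_models
protected
Map from nonlinear parameter names to the models that evaluate them.
std::map<std::string, const VariableBase *> _nl_params
protected
Map from nonlinear parameter names to their corresponding variable views.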