NEML2 1.4.0
ParameterStore Class Reference

Interface for objects that can store parameters.

Detailed Description

Interface for objects that can store parameters.

#include <ParameterStore.h>


Public Member Functions

 ParameterStore (const OptionSet &options, NEML2Object *object)
 
std::map< std::string, const VariableBase * > _nl_params
 Map from nonlinear parameter names to their corresponding variable views.
 
std::map< std::string, Model * > _nl_param_models
 Map from nonlinear parameter names to models which evaluate them.
 
const Storage< std::string, TensorValueBase > & named_parameters () const
 
Storage< std::string, TensorValueBase > & named_parameters ()
 
void set_parameter (const std::string &, const Tensor &)
 Set the value for a parameter.
 
void set_parameters (const std::map< std::string, Tensor > &)
 Set values for parameters.
 
TensorValueBase & get_parameter (const std::string &name)
 Get a writable reference to a parameter.
 
const TensorValueBase & get_parameter (const std::string &name) const
 Get a read-only reference to a parameter.
 
bool has_nl_param () const
 Whether this parameter store has any nonlinear parameters.
 
const VariableBase * nl_param (const std::string &) const
 Query the existence of a nonlinear parameter.
 
virtual std::map< std::string, const VariableBase * > named_nonlinear_parameters (bool recursive=false) const
 Get all nonlinear parameters.
 
virtual std::map< std::string, Model * > named_nonlinear_parameter_models (bool recursive=false) const
 Get the models that evaluate all nonlinear parameters.
 
virtual void send_parameters_to (const torch::TensorOptions &options)
 Send all parameters to the target (e.g. device and dtype) specified by options.
 
template<typename T , typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
const T & declare_parameter (const std::string &name, const T &rawval)
 Declare a parameter from a raw value.
 
template<typename T , typename = typename std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
const T & declare_parameter (const std::string &name, const std::string &input_option_name, bool allow_nonlinear=false)
 Declare a parameter whose value is defined by an input option.
 

Constructor & Destructor Documentation

◆ ParameterStore()

ParameterStore ( const OptionSet & options,
NEML2Object * object )

Member Function Documentation

◆ declare_parameter() [1/2]

template<typename T , typename >
const T & declare_parameter ( const std::string & name,
const std::string & input_option_name,
bool allow_nonlinear = false )
protected

Declare a parameter.

Note that all parameters are stored in the host (the object exposed to users). An object may be used multiple times within the host, so the same parameter may be declared multiple times. This is allowed, but only the first call to declare_parameter constructs the parameter value; subsequent calls only return a reference to the existing parameter.

Template Parameters
T Parameter type. See Statically shaped tensors for supported types.
Parameters
name Name of the model parameter.
input_option_name Name of the input option that defines the value of the model parameter.
allow_nonlinear Whether to allow coupling with a nonlinear parameter.
Returns
T The value of the registered model parameter.
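The role of input_option_name and allow_nonlinear can be illustrated with a small standalone sketch. It assumes, for illustration only, that an input option holds either a constant or the name of a variable, and that a variable reference is what turns a parameter into a nonlinear one. OptionValue, OptionBackedParameters, is_nonlinear, and the 0.0 placeholder are hypothetical; doubles stand in for NEML2 tensor types, and this is not NEML2's OptionSet handling.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <variant>

// Hypothetical option value: a constant, or the name of a variable that would
// make the parameter nonlinear. This is NOT NEML2's OptionSet.
using OptionValue = std::variant<double, std::string>;

// Hypothetical stand-in for a parameter store; doubles replace NEML2 tensors.
class OptionBackedParameters
{
public:
  explicit OptionBackedParameters(std::map<std::string, OptionValue> options)
    : _options(std::move(options))
  {
  }

  // Declare parameter `name` from the input option `input_option_name`.
  // A constant option is stored directly. A variable reference is accepted only
  // when allow_nonlinear is true; it is then recorded as nonlinear and a
  // placeholder 0.0 is stored (the placeholder is an assumption of this sketch).
  const double & declare_parameter(const std::string & name,
                                   const std::string & input_option_name,
                                   bool allow_nonlinear = false)
  {
    const OptionValue & opt = _options.at(input_option_name);
    if (const double * value = std::get_if<double>(&opt))
      return _params.try_emplace(name, *value).first->second;

    if (!allow_nonlinear)
      throw std::runtime_error("parameter '" + name + "' does not allow nonlinear coupling");

    _nl_params.try_emplace(name, std::get<std::string>(opt));
    return _params.try_emplace(name, 0.0).first->second;
  }

  // Whether `name` was coupled to a variable.
  bool is_nonlinear(const std::string & name) const { return _nl_params.count(name) > 0; }

private:
  std::map<std::string, OptionValue> _options;    // input option name -> option value
  std::map<std::string, double> _params;          // parameter name -> value
  std::map<std::string, std::string> _nl_params;  // parameter name -> variable name
};
```

With allow_nonlinear left at its default of false, the sketch rejects a variable reference outright, which mirrors the intent of the flag: only parameters explicitly declared as couplable may become nonlinear.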

◆ declare_parameter() [2/2]

template<typename T , typename >
const T & declare_parameter ( const std::string & name,
const T & rawval )
protected

Declare a parameter.

Note that all parameters are stored in the host (the object exposed to users). An object may be used multiple times within the host, so the same parameter may be declared multiple times. This is allowed, but only the first call to declare_parameter constructs the parameter value; subsequent calls only return a reference to the existing parameter.

Template Parameters
T Parameter type. See Statically shaped tensors for supported types.
Parameters
name Name of the parameter.
rawval Raw value of the parameter.
Returns
Reference to the parameter.
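The "first declaration wins" rule in the note above maps naturally onto std::map::try_emplace, which inserts only when the key is absent and otherwise returns an iterator to the existing entry. The sketch below is a hypothetical stand-in (HostParameters is not a NEML2 class, and doubles replace tensor types); it only shows why repeated declarations are harmless.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>

// Hypothetical host-side storage; doubles stand in for NEML2 tensor types.
class HostParameters
{
public:
  // The first declaration of `name` stores rawval; later declarations of the
  // same name leave the stored value untouched and return a reference to it.
  const double & declare_parameter(const std::string & name, double rawval)
  {
    return _params.try_emplace(name, rawval).first->second;
  }

  std::size_t size() const { return _params.size(); }

private:
  std::map<std::string, double> _params; // references stay valid across later inserts
};
```

Because std::map never invalidates references to existing elements on insertion, every caller that declared the same name ends up holding a reference to one shared value.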

◆ get_parameter() [1/2]

TensorValueBase & get_parameter ( const std::string & name)

Get a writable reference to a parameter.

◆ get_parameter() [2/2]

const TensorValueBase & get_parameter ( const std::string & name) const

Get a read-only reference to a parameter.
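The two get_parameter overloads follow the usual C++ pattern of paired const and non-const accessors: the compiler selects the const overload for const objects and the writable one otherwise. A minimal sketch, assuming a hypothetical Params class with doubles in place of TensorValueBase; throwing std::out_of_range for an unknown name is an assumption of this sketch, not documented NEML2 behavior.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical sketch of paired accessors; doubles stand in for TensorValueBase.
class Params
{
public:
  void add(const std::string & name, double value) { _values[name] = value; }

  // Writable reference: callers may modify the stored value in place.
  // std::map::at throws std::out_of_range for unknown names (sketch assumption).
  double & get_parameter(const std::string & name) { return _values.at(name); }

  // Read-only reference: chosen automatically when the object is const.
  const double & get_parameter(const std::string & name) const { return _values.at(name); }

private:
  std::map<std::string, double> _values;
};
```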

◆ has_nl_param()

bool has_nl_param ( ) const
inline

Whether this parameter store has any nonlinear parameters.

◆ named_nonlinear_parameter_models()

std::map< std::string, Model * > named_nonlinear_parameter_models ( bool recursive = false) const
virtual

Get the models that evaluate all nonlinear parameters.

Reimplemented in ComposedModel.

◆ named_nonlinear_parameters()

std::map< std::string, const VariableBase * > named_nonlinear_parameters ( bool recursive = false) const
virtual

Get all nonlinear parameters.

Reimplemented in ComposedModel.
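One plausible reading of the recursive flag and the ComposedModel reimplementation is sketched below, under the assumption that a composed object merges its own nonlinear parameters with those of its sub-objects when recursive is true. Store, Composed, nl_params, and children are hypothetical names, strings stand in for const VariableBase *, and NEML2's actual merging rules (for example name prefixes or conflict handling) may differ.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical base: strings stand in for const VariableBase * views.
struct Store
{
  virtual ~Store() = default;

  std::map<std::string, std::string> nl_params; // parameter name -> variable name

  virtual std::map<std::string, std::string> named_nonlinear_parameters(bool recursive = false) const
  {
    (void)recursive; // a leaf has no children, so recursion changes nothing
    return nl_params;
  }
};

// Hypothetical composed object that also reports its children's entries.
struct Composed : Store
{
  std::vector<std::shared_ptr<Store>> children;

  std::map<std::string, std::string> named_nonlinear_parameters(bool recursive = false) const override
  {
    auto all = Store::named_nonlinear_parameters(recursive);
    if (recursive)
      for (const auto & child : children)
        for (const auto & [name, var] : child->named_nonlinear_parameters(true))
          all.emplace(name, var); // keeps the first entry if a name repeats
    return all;
  }
};
```

Making the collector virtual lets callers ask any store for its nonlinear parameters without knowing whether it is a leaf or a composition.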

◆ named_parameters() [1/2]

Storage< std::string, TensorValueBase > & named_parameters ( )

◆ named_parameters() [2/2]

const Storage< std::string, TensorValueBase > & named_parameters ( ) const
inline
Returns
the parameter storage

◆ nl_param()

const VariableBase * nl_param ( const std::string & name) const

Query the existence of a nonlinear parameter.

Returns
const VariableBase* Pointer to the VariableBase if the parameter associated with the given parameter name is nonlinear. Returns nullptr otherwise.
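The nullptr convention above lets nl_param double as an existence query. The sketch below assumes it is a lookup into a map like _nl_params, and that has_nl_param simply checks whether that map is non-empty; both are assumptions about the inline implementation. Variable, NonlinearLookup, and couple are hypothetical stand-ins, not NEML2 types.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for VariableBase; NOT the NEML2 class.
struct Variable
{
  std::string name;
};

class NonlinearLookup
{
public:
  // Record that parameter `param` is evaluated by variable `var`.
  void couple(const std::string & param, const Variable * var) { _nl_params[param] = var; }

  // Whether any parameter is coupled to a variable (assumed implementation).
  bool has_nl_param() const { return !_nl_params.empty(); }

  // Pointer to the coupled variable, or nullptr if the parameter is not nonlinear.
  const Variable * nl_param(const std::string & name) const
  {
    auto it = _nl_params.find(name);
    return it == _nl_params.end() ? nullptr : it->second;
  }

private:
  std::map<std::string, const Variable *> _nl_params;
};
```

A caller can then branch with `if (auto var = store.nl_param("E"))` and use var only when the parameter is actually nonlinear.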

◆ send_parameters_to()

void send_parameters_to ( const torch::TensorOptions & options)
protectedvirtual

Send all parameters to the target (e.g. device and dtype) specified by options.

Parameters
options The target tensor options

◆ set_parameter()

void set_parameter ( const std::string & name,
const Tensor & value )

Set the value for a parameter.

◆ set_parameters()

void set_parameters ( const std::map< std::string, Tensor > & param_values)

Set values for parameters.
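A natural way to relate the two setters is for set_parameters to apply each entry of the map through set_parameter. The sketch below assumes that design, with a hypothetical Settable class, doubles in place of Tensor, and std::out_of_range for names that were never declared; none of these details are confirmed by the reference above.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical sketch; doubles stand in for Tensor values.
class Settable
{
public:
  void declare(const std::string & name, double value) { _params.emplace(name, value); }

  // Overwrite one existing parameter; std::map::at rejects unknown names
  // with std::out_of_range (an assumption of this sketch).
  void set_parameter(const std::string & name, double value) { _params.at(name) = value; }

  // Apply a batch of updates by delegating to set_parameter.
  void set_parameters(const std::map<std::string, double> & values)
  {
    for (const auto & [name, value] : values)
      set_parameter(name, value);
  }

  double value(const std::string & name) const { return _params.at(name); }

private:
  std::map<std::string, double> _params;
};
```

Routing the batch setter through the single-parameter setter keeps any validation in one place.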

Member Data Documentation

◆ _nl_param_models

std::map<std::string, Model *> _nl_param_models
protected

Map from nonlinear parameter names to models which evaluate them.

◆ _nl_params

std::map<std::string, const VariableBase *> _nl_params
protected

Map from nonlinear parameter names to their corresponding variable views.