NEML2 1.4.0
ParameterStore Class Reference

Interface for an object that can store parameters.

Detailed Description


#include <ParameterStore.h>


Public Member Functions

 ParameterStore (const OptionSet &options, NEML2Object *object)
 
const Storage< std::string, TensorValueBase > & named_parameters () const
 
Storage< std::string, TensorValueBase > & named_parameters ()
 
template<typename T , typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<T>, T>>>
T & get_parameter (const std::string &name)
 Get a writable reference to a parameter.
 
bool has_nl_param () const
 Whether this parameter store has any nonlinear parameter.
 
const std::map< std::string, const VariableBase * > & nl_params () const
 Get all nonlinear parameters.
 
const VariableBase * nl_param (const std::string &) const
 Query the existence of a nonlinear parameter.
 
Protected Member Functions

virtual void send_parameters_to (const torch::TensorOptions &options)
 Send all parameters to the given tensor options (for example, a target dtype and device).
 
template<typename T , typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<T>, T>>>
const T & declare_parameter (const std::string &name, const T &rawval)
 Declare a parameter.
 
template<typename T , typename = typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<T>, T>>>
const T & declare_parameter (const std::string &name, const std::string &input_option_name)
 Declare a parameter.
 

Constructor & Destructor Documentation

◆ ParameterStore()

ParameterStore (const OptionSet & options, NEML2Object * object)

Member Function Documentation

◆ declare_parameter() [1/2]

template<typename T , typename >
const T & declare_parameter (const std::string & name, const std::string & input_option_name)
protected

Declare a parameter.

Note that all parameters are stored in the host (the object exposed to users). An object may be used multiple times within the host, so the same parameter may be declared multiple times. This is allowed: only the first call to declare_parameter constructs the parameter value, and subsequent calls simply return a reference to the existing parameter.

Template Parameters
T: Parameter type. See Statically shaped tensors for supported types.
Parameters
name: Name of the model parameter.
input_option_name: Name of the input option that defines the value of the model parameter.
Returns
Reference to the value of the registered model parameter.
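The following is a minimal sketch of the declare-once behavior for the input-option overload. It assumes a heavily simplified model in which parameters are plain doubles rather than tensors and the option set is a std::map; MiniOptionSet and MiniParameterStore are hypothetical names, not NEML2 API.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical stand-in for an option set: option name -> numeric value.
using MiniOptionSet = std::map<std::string, double>;

// Hypothetical, simplified store: parameters are doubles instead of tensors.
class MiniParameterStore
{
public:
  explicit MiniParameterStore(MiniOptionSet options)
    : _options(std::move(options))
  {
  }

  // Declare a parameter whose value is read from the input option `input_option_name`.
  // Only the first declaration of `name` reads the option; later calls return the
  // existing entry unchanged.
  const double & declare_parameter(const std::string & name, const std::string & input_option_name)
  {
    auto existing = _params.find(name);
    if (existing != _params.end())
      return existing->second;

    auto opt = _options.find(input_option_name);
    if (opt == _options.end())
      throw std::runtime_error("missing input option: " + input_option_name);

    // std::map never invalidates references to its elements on insertion,
    // so the returned reference stays valid as more parameters are declared.
    return _params.emplace(name, opt->second).first->second;
  }

private:
  MiniOptionSet _options;
  std::map<std::string, double> _params;
};
```

With options {E: 200, nu: 0.3}, declaring "E" from option "E" yields 200; a second declaration of "E" from option "nu" returns the same reference and still reads 200, because the option is consulted only on the first call.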

◆ declare_parameter() [2/2]

template<typename T , typename >
const T & declare_parameter (const std::string & name, const T & rawval)
protected

Declare a parameter.

Note that all parameters are stored in the host (the object exposed to users). An object may be used multiple times within the host, so the same parameter may be declared multiple times. This is allowed: only the first call to declare_parameter constructs the parameter value, and subsequent calls simply return a reference to the existing parameter.

Template Parameters
T: Parameter type. See Statically shaped tensors for supported types.
Parameters
name: Name of the model parameter.
rawval: Raw value of the model parameter.
Returns
Reference to the registered model parameter.
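A comparable sketch for the raw-value overload, assuming two hypothetical sub-objects that register a parameter with the same name in one shared host. HostStore and SubModel are illustrative names only, and the real store holds tensors rather than doubles.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>

// Hypothetical host store; the real ParameterStore keeps tensors, not doubles.
class HostStore
{
public:
  // Declare a parameter from a raw value. The first call creates the entry;
  // later calls with the same name ignore `rawval` and return the existing entry.
  const double & declare_parameter(const std::string & name, double rawval)
  {
    // try_emplace (C++17) leaves an existing key untouched.
    return _params.try_emplace(name, rawval).first->second;
  }

  std::size_t size() const { return _params.size(); }

private:
  std::map<std::string, double> _params;
};

// Hypothetical sub-object that registers its parameter in a shared host.
struct SubModel
{
  SubModel(HostStore & host, double raw_K)
    : K(host.declare_parameter("K", raw_K))
  {
  }

  const double & K;
};
```

Here a second SubModel constructed with 5.0 after one constructed with 1.0 does not change anything: both K references alias the single entry created with 1.0. std::map::try_emplace is a convenient way to express first-call-wins semantics because it never overwrites an existing key.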

◆ get_parameter()

template<typename T , typename >
T & get_parameter ( const std::string & name)

Get a writable reference to a parameter.
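The template constraint in the signature, std::is_base_of_v<BatchTensorBase<T>, T>, only admits types that derive from BatchTensorBase instantiated on themselves (the CRTP pattern). The sketch below imitates that pattern with simplified stand-ins named after the NEML2 types (TensorValueBase, BatchTensorBase) plus hypothetical ones (Scalar, TensorValue, SketchStore, insert_parameter), and assumes a dynamic_cast-based type check on retrieval; it is not the actual NEML2 implementation.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>

// Simplified stand-ins named after the NEML2 types; the real ones wrap torch tensors.
struct TensorValueBase
{
  virtual ~TensorValueBase() = default;
};

// CRTP base: T qualifies as a tensor type only if it derives from BatchTensorBase<T>.
template <typename Derived>
struct BatchTensorBase
{
};

struct Scalar : BatchTensorBase<Scalar>
{
  double value = 0.0;
};

// Type-erased holder so values of different types can share one map.
template <typename T>
struct TensorValue : TensorValueBase
{
  explicit TensorValue(T v) : value(std::move(v)) {}
  T value;
};

// Hypothetical store used only for this sketch.
class SketchStore
{
public:
  // Hypothetical helper that inserts (or replaces) a parameter.
  template <typename T, typename = std::enable_if_t<std::is_base_of_v<BatchTensorBase<T>, T>>>
  void insert_parameter(const std::string & name, T v)
  {
    _params[name] = std::make_unique<TensorValue<T>>(std::move(v));
  }

  // Same constraint as the documented signature; returns a writable reference.
  template <typename T, typename = std::enable_if_t<std::is_base_of_v<BatchTensorBase<T>, T>>>
  T & get_parameter(const std::string & name)
  {
    auto it = _params.find(name);
    if (it == _params.end())
      throw std::runtime_error("no parameter named " + name);
    // Check that the stored value really holds a T.
    auto * tv = dynamic_cast<TensorValue<T> *>(it->second.get());
    if (tv == nullptr)
      throw std::runtime_error("parameter " + name + " has a different type");
    return tv->value;
  }

private:
  std::map<std::string, std::unique_ptr<TensorValueBase>> _params;
};
```

Because the returned reference is writable, assigning through it updates the stored value in place. A call such as get_parameter<int>("k") would not compile, since int does not derive from BatchTensorBase<int> and the enable_if default argument removes the overload.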

◆ has_nl_param()

bool has_nl_param ( ) const
inline

Whether this parameter store has any nonlinear parameter.

◆ named_parameters() [1/2]

Storage< std::string, TensorValueBase > & named_parameters ( )

◆ named_parameters() [2/2]

const Storage< std::string, TensorValueBase > & named_parameters ( ) const
inline
Returns
the parameter storage

◆ nl_param()

const VariableBase * nl_param ( const std::string & name) const

Query the existence of a nonlinear parameter.

Returns
const VariableBase* Pointer to the VariableBase if the parameter associated with the given parameter name is nonlinear. Returns nullptr otherwise.
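A minimal sketch of how the nonlinear-parameter queries could fit together. It assumes, as the pointer return type suggests, that a nonlinear parameter is one whose value is supplied by a variable, and that has_nl_param() is equivalent to nl_params() being non-empty. VariableBase here is a stripped-down stand-in, and NLParamRegistry and mark_nonlinear are hypothetical.

```cpp
#include <cassert>
#include <map>
#include <string>

// Stripped-down stand-in for NEML2's VariableBase.
struct VariableBase
{
  std::string name;
};

// Hypothetical registry of parameters whose values are supplied by variables.
class NLParamRegistry
{
public:
  // Hypothetical helper: record that parameter `pname` is driven by `var`.
  void mark_nonlinear(const std::string & pname, const VariableBase * var)
  {
    assert(var != nullptr);
    _nl_params[pname] = var;
  }

  // Assumed equivalent to "the nonlinear-parameter map is non-empty".
  bool has_nl_param() const { return !_nl_params.empty(); }

  const std::map<std::string, const VariableBase *> & nl_params() const { return _nl_params; }

  // Pointer to the driving variable if `pname` is nonlinear, nullptr otherwise.
  const VariableBase * nl_param(const std::string & pname) const
  {
    auto it = _nl_params.find(pname);
    return it == _nl_params.end() ? nullptr : it->second;
  }

private:
  std::map<std::string, const VariableBase *> _nl_params;
};
```

Returning nullptr rather than throwing lets callers use nl_param both as an existence query and as an accessor in a single lookup.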

◆ nl_params()

const std::map< std::string, const VariableBase * > & nl_params ( ) const
inline

Get all nonlinear parameters.

◆ send_parameters_to()

void send_parameters_to ( const torch::TensorOptions & options)
protectedvirtual

Send all parameters to the given tensor options (for example, a target dtype and device).

Parameters
options: The target tensor options.
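Because send_parameters_to is virtual, a derived store can extend it to convert additional state. The sketch below assumes stand-in types (Dtype, Device, Options, FakeTensor, DeviceAwareStore) in place of torch::TensorOptions and torch::Tensor so that it runs without libtorch; with libtorch, the per-parameter conversion would typically be a call to Tensor::to(options).

```cpp
#include <cassert>
#include <map>
#include <string>

// Stand-ins for torch::Dtype, torch::Device and torch::TensorOptions.
enum class Dtype { Float32, Float64 };
enum class Device { CPU, CUDA };
struct Options
{
  Dtype dtype = Dtype::Float64;
  Device device = Device::CPU;
};

// Stand-in tensor that only records its dtype and device.
struct FakeTensor
{
  Options opts;
  FakeTensor to(const Options & o) const { return FakeTensor{o}; }
};

// Hypothetical store illustrating a virtual send_parameters_to.
class DeviceAwareStore
{
public:
  virtual ~DeviceAwareStore() = default;

  // Access (creating if needed) a parameter by name.
  FakeTensor & param(const std::string & name) { return _params[name]; }

  // Convert every stored parameter in place. Virtual, so a derived store can
  // convert additional state it owns and then call this base version.
  virtual void send_parameters_to(const Options & options)
  {
    for (auto & entry : _params)
      entry.second = entry.second.to(options);
  }

private:
  std::map<std::string, FakeTensor> _params;
};
```

After calling send_parameters_to with Float32 on CUDA, every stored stand-in tensor reports that dtype and device, while newly created parameters still start with the default options.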