25#include "neml2/drivers/TransientDriver.h"
26#include "neml2/models/ComposedModel.h"
27#include "neml2/models/ImplicitUpdate.h"
29namespace fs = std::filesystem;
38 options.
set<std::string>(
"model");
39 options.
set(
"model").doc() =
"The material model to be updated by the driver";
42 options.
set(
"times").doc() =
43 "Time steps to perform the material update. The times tensor must have exactly 2 dimensions. "
44 "The first dimension represents time steps, and the second dimension represents batches "
45 "(i.e., how many material models to update simultaneously).";
48 options.
set(
"time").doc() =
"Time";
50 options.
set<std::string>(
"predictor") =
"PREVIOUS_STATE";
51 options.
set(
"predictor").doc() =
52 "Predictor used to set the initial guess for each time step. Options are PREVIOUS_STATE, "
53 "LINEAR_EXTRAPOLATION, CP_PREVIOUS_STATE, and CP_LINEAR_EXTRAPOLATION. The options prefixed "
54 "with 'CP_' are specifically designed for crystal plasticity models.";
56 options.
set<
Real>(
"cp_elastic_scale") = 1.0;
57 options.
set(
"cp_elastic_scale").doc() =
"Elastic step scale factor used in the 'CP_' predictors";
59 options.
set<std::string>(
"save_as");
60 options.
set(
"save_as").doc() =
61 "File path (absolute or relative to the working directory) to store the results";
63 options.
set<
bool>(
"show_parameters") =
false;
64 options.
set(
"show_parameters").doc() =
"Whether to show model parameters at the beginning";
66 options.
set<
bool>(
"show_input_axis") =
false;
67 options.
set(
"show_input_axis").doc() =
"Whether to show model input axis at the beginning";
69 options.
set<
bool>(
"show_output_axis") =
false;
70 options.
set(
"show_output_axis").doc() =
"Whether to show model output axis at the beginning";
72 options.
set<std::string>(
"device") =
"cpu";
73 options.
set(
"device").doc() =
74 "Device on which to evaluate the material model. The string supplied must follow the "
75 "following schema: (cpu|cuda)[:<device-index>] where cpu or cuda specifies the device type, "
76 "and :<device-index> optionally specifies a device index. For example, device='cpu' sets the "
77 "target compute device to be CPU, and device='cuda:1' sets the target compute device to be "
78 "CUDA with device ID 1.";
80 options.
set<std::vector<VariableName>>(
"ic_scalar_names");
81 options.
set(
"ic_scalar_names").doc() =
"Apply initial conditions to these Scalar variables";
83 options.
set<std::vector<CrossRef<Scalar>>>(
"ic_scalar_values");
84 options.
set(
"ic_scalar_values").doc() =
"Initial condition values for the Scalar variables";
86 options.
set<std::vector<VariableName>>(
"ic_rot_names");
87 options.
set(
"ic_rot_names").doc() =
"Apply initial conditions to these Rot variables";
89 options.
set<std::vector<CrossRef<Rot>>>(
"ic_rot_values");
90 options.
set(
"ic_rot_values").doc() =
"Initial condition values for the Rot variables";
92 options.
set<std::vector<VariableName>>(
"ic_sr2_names");
93 options.
set(
"ic_sr2_names").doc() =
"Apply initial conditions to these SR2 variables";
95 options.
set<std::vector<CrossRef<SR2>>>(
"ic_sr2_values");
96 options.
set(
"ic_sr2_values").doc() =
"Initial condition values for the SR2 variables";
104 _device(options.get<std::
string>(
"device")),
105 _time(options.get<
CrossRef<torch::Tensor>>(
"times"), 2),
108 _nsteps(_time.batch_sizes()[0]),
109 _nbatch(_time.batch_sizes()[1]),
110 _in(_model.input_storage()),
111 _out(_model.output_storage()),
112 _predictor(options.get<std::
string>(
"predictor")),
113 _save_as(options.get<std::
string>(
"save_as")),
114 _show_params(options.get<
bool>(
"show_parameters")),
115 _show_input(options.get<
bool>(
"show_input_axis")),
116 _show_output(options.get<
bool>(
"show_output_axis")),
119 _ic_scalar_names(options.get<std::vector<VariableName>>(
"ic_scalar_names")),
120 _ic_scalar_values(options.get<std::vector<CrossRef<Scalar>>>(
"ic_scalar_values")),
121 _ic_rot_names(options.get<std::vector<VariableName>>(
"ic_rot_names")),
122 _ic_rot_values(options.get<std::vector<CrossRef<Rot>>>(
"ic_rot_values")),
123 _ic_sr2_names(options.get<std::vector<VariableName>>(
"ic_sr2_names")),
124 _ic_sr2_values(options.get<std::vector<CrossRef<SR2>>>(
"ic_sr2_values")),
125 _cp_elastic_scale(options.get<
Real>(
"cp_elastic_scale"))
127 _model.reinit({_nbatch}, 0, _device);
129 _time = _time.to(_device);
130 _result_in = _result_in.to(_device);
131 _result_out = _result_out.to(_device);
145 "Input time should have dimension 2 but instead has dimension ",
155 std::cout <<
pname << std::endl;
200 std::cout << std::endl;
210 if (
_in.
axis(0).has_subaxis(
"old_state") &&
_out.
axis(0).has_subaxis(
"state"))
213 if (
_in.
axis(0).has_subaxis(
"old_forces") &&
_in.
axis(0).has_subaxis(
"forces"))
243 if (
_in.
axis(0).has_subaxis(
"state") &&
_in.
axis(0).has_subaxis(
"old_state"))
247 else if (
predictor ==
"LINEAR_EXTRAPOLATION")
274 SR2 D =
_in.
get<
SR2>(std::vector<std::string>{
"forces",
"deformation_rate"});
312 auto res_in = std::make_shared<torch::nn::Module>();
317 auto res_out = std::make_shared<torch::nn::Module>();
322 torch::nn::ModuleDict
res;
332 std::cout <<
"Saving results..." << std::endl;
335 auto cwd = fs::current_path();
338 if (
out.extension() ==
".pt")
342 neml_assert(
false,
"Unsupported output format: ",
out.extension());
347 std::cout <<
"Results saved to " <<
save_as_path() << std::endl;
352TransientDriver::output_pt(
const std::filesystem::path &
out)
const
Derived batch_index(TorchSlice indices) const
Get a batch.
Definition BatchTensorBase.cxx:184
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
The Driver drives the execution of a NEML2 Model.
Definition Driver.h:40
bool _verbose
Whether to print out additional (debugging) information during the execution.
Definition Driver.h:59
virtual void check_integrity() const
Check the integrity of the set up.
Definition Driver.h:56
static OptionSet expected_options()
Definition Driver.cxx:30
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:44
const LabeledAxis & axis(TorchSize i=0) const
Get a specific labeled axis.
Definition LabeledTensor.h:130
static LabeledVector zeros(TorchShapeRef batch_shape, const std::vector< const LabeledAxis * > &axes, const torch::TensorOptions &options=default_tensor_options())
Setup new storage with zeros.
Definition LabeledTensor.cxx:109
Derived to(const torch::TensorOptions &options) const
Change tensor options.
Definition LabeledTensor.cxx:242
void set(const BatchTensorBase< T > &value, S &&... names)
Set and interpret the input as an object.
Definition LabeledTensor.h:193
void batch_index_put(TorchSlice indices, const torch::Tensor &other)
Set the entries selected by an index on the batch dimensions to a value.
Definition LabeledTensor.cxx:214
variable_type< T >::type get(S &&... names) const
Get and interpret the view as an object.
Definition LabeledTensor.h:176
A single-batched, logically 1D LabeledTensor.
Definition LabeledVector.h:38
void fill(const LabeledVector &other, bool recursive=true)
Definition LabeledVector.cxx:45
LabeledVector slice(const std::string &name) const
Slice the logically 1D tensor by a single sub-axis.
Definition LabeledVector.cxx:31
The base class for all constitutive models.
Definition Model.h:53
virtual LabeledVector value(const LabeledVector &in)
Convenient shortcut to construct and return the model value.
Definition Model.cxx:348
virtual std::vector< Diagnosis > preflight() const
Check for common problems.
Definition Model.cxx:70
const std::string & name() const
A readonly reference to the object's name.
Definition NEML2Object.h:65
A custom map-like data structure. The keys are strings, and the values can be heterogeneously typed.
Definition OptionSet.h:59
T & set(const std::string &)
Definition OptionSet.h:436
const Storage< std::string, TensorValueBase > & named_parameters() const
Definition ParameterStore.h:44
The (logical) symmetric second order tensor.
Definition SR2.h:46
The (logical) scalar.
Definition Scalar.h:38
virtual bool solve()
Solve the initial value problem.
Definition TransientDriver.cxx:173
virtual void advance_step()
Advance in time: the state becomes old state, and forces become old forces.
Definition TransientDriver.cxx:208
LabeledVector _result_in
Inputs from all time steps.
Definition TransientDriver.h:123
const bool _show_input
Set to true to show model's input axis at the beginning.
Definition TransientDriver.h:118
Model & _model
The model which the driver uses to perform constitutive updates.
Definition TransientDriver.h:92
const bool _show_output
Set to true to show model's output axis at the beginning.
Definition TransientDriver.h:120
std::vector< VariableName > _ic_rot_names
Names for the Rot initial conditions.
Definition TransientDriver.h:132
TransientDriver(const OptionSet &options)
Construct a new TransientDriver object.
Definition TransientDriver.cxx:101
virtual void update_forces()
Update the driving forces for the current time step.
Definition TransientDriver.cxx:218
LabeledVector _result_out
Outputs from all time steps.
Definition TransientDriver.h:125
bool run() override
Let the driver run, return true upon successful completion, and return false otherwise.
Definition TransientDriver.cxx:150
const bool _show_params
Set to true to list all the model parameters at the beginning.
Definition TransientDriver.h:116
Scalar _time
The current time.
Definition TransientDriver.h:97
virtual std::string save_as_path() const
The destination file/path to save the results.
Definition TransientDriver.cxx:300
virtual void apply_predictor()
Apply the predictor to calculate the initial guess for the current time step.
Definition TransientDriver.cxx:233
virtual void store_input()
Save the input of the current time step.
Definition TransientDriver.cxx:288
std::vector< CrossRef< SR2 > > _ic_sr2_values
Values for the SR2 initial conditions.
Definition TransientDriver.h:138
virtual torch::nn::ModuleDict result() const
The results (input and output) from all time steps.
Definition TransientDriver.cxx:306
std::vector< CrossRef< Scalar > > _ic_scalar_values
Values for the scalar initial conditions.
Definition TransientDriver.h:130
std::vector< CrossRef< Rot > > _ic_rot_values
Values for the Rot initial conditions.
Definition TransientDriver.h:134
virtual void check_integrity() const override
Check the integrity of the set up.
Definition TransientDriver.cxx:135
std::vector< VariableName > _ic_sr2_names
Names for the SR2 initial conditions.
Definition TransientDriver.h:136
std::string _save_as
The destination file name or file path.
Definition TransientDriver.h:114
TorchSize _nbatch
The batch size.
Definition TransientDriver.h:105
LabeledVector & _in
The input to the constitutive model.
Definition TransientDriver.h:107
std::vector< VariableName > _ic_scalar_names
Names for the Scalar initial conditions.
Definition TransientDriver.h:128
LabeledVector & _out
The output of the constitutive model.
Definition TransientDriver.h:109
TorchSize _nsteps
Total number of steps.
Definition TransientDriver.h:103
virtual void output() const
Save the results into the destination file/path.
Definition TransientDriver.cxx:328
VariableName _time_name
VariableName for the time.
Definition TransientDriver.h:101
static OptionSet expected_options()
Definition TransientDriver.cxx:34
TorchSize _step_count
The current step count.
Definition TransientDriver.h:99
std::string _predictor
The predictor used to set the initial guess.
Definition TransientDriver.h:112
virtual void store_output()
Save the output of the current time step.
Definition TransientDriver.cxx:294
virtual void apply_ic()
Apply the initial conditions.
Definition TransientDriver.cxx:225
Real _cp_elastic_scale
Elastic step scale factor used in the 'CP_' predictors.
Definition TransientDriver.h:141
virtual void solve_step()
Perform the constitutive update for the current time step.
Definition TransientDriver.cxx:282
LabeledAxis & output_axis()
Definition VariableStore.h:97
LabeledAxis & input_axis()
Definition VariableStore.h:91
std::string stringify(const T &t)
Definition utils.h:302
Definition CrossRef.cxx:32
double Real
Definition types.h:33
LabeledAxisAccessor VariableName
Definition Variable.h:35
void neml_assert(bool assertion, Args &&... args)
Definition error.h:73