NEML2 1.4.0
Loading...
Searching...
No Matches
TransientDriver.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/drivers/Driver.h"
28#include "neml2/tensors/tensors.h"
29
30#include <filesystem>
31#include <torch/nn/modules/container/moduledict.h>
32#include <torch/serialize.h>
33
34namespace neml2
35{
/// The driver for a transient initial-value problem.
///
/// NOTE(review): this listing was extracted from generated documentation and is
/// incomplete. The "Definition TransientDriver.h:<line>" entries on this page place
/// the following declarations on source lines that are not shown here:
/// - _model (96), _time (101), _step_count (103), _time_name (105), _nsteps (107),
///   _nbatch (109), _in (111), _out (113);
/// - _result_in (127), _result_out (129), _cp_elastic_scale (145).
/// The body of set_IC (line 155) is also missing. The page lists a
/// `static OptionSet expected_options()` (defined in TransientDriver.cxx:34) that
/// does not appear below either. Consult the real header before relying on this text.
40class TransientDriver : public Driver
41{
42public:
44
 /// Construct a new TransientDriver object.
50 TransientDriver(const OptionSet & options);
51
 /// Check for common problems.
52 virtual void diagnose(std::vector<Diagnosis> &) const override;
53
 /// Run the driver; returns true upon successful completion and false otherwise.
54 bool run() override;
55
 /// Read-only access to the model which the driver uses to perform constitutive updates.
56 const Model & model() const { return _model; }
57
 /// The destination file/path to save the results.
59 virtual std::string save_as_path() const;
60
 /// The results (input and output) from all time steps.
67 virtual torch::nn::ModuleDict result() const;
68
69protected:
 /// Solve the initial value problem.
71 virtual bool solve();
72
73 // @{ Routines that are called every step
 /// Advance in time: the current state becomes the old state, and the current
 /// forces become the old forces.
75 virtual void advance_step();
 /// Update the driving forces for the current time step.
77 virtual void update_forces();
 /// Apply the initial conditions.
79 virtual void apply_ic();
 /// Apply the predictor to calculate the initial guess for the current time step.
81 virtual void apply_predictor();
 /// Perform the constitutive update for the current time step.
83 virtual void solve_step();
 /// Save the input of the current time step.
85 virtual void store_input();
 /// Save the output of the current time step.
87 virtual void store_output();
88 // @}
89
 /// Save the results into the destination file/path.
91 virtual void output() const;
92
 /// Automatic-differentiation switch; by its name, true enables AD.
 /// NOTE(review): confirm the intended meaning against TransientDriver.cxx.
94 const bool _enable_AD;
 /// NOTE(review): the declaration of _model (source line 96) is missing from this listing.
 /// The device on which to evaluate the model.
98 const torch::Device _device;
99
 /// NOTE(review): source lines 100-113 (_time, _step_count, _time_name, _nsteps,
 /// _nbatch, _in, _out) are missing from this listing.
114
 /// The predictor used to set the initial guess.
116 std::string _predictor;
 /// The destination file name or file path.
118 std::string _save_as;
 /// Set to true to list all the model parameters at the beginning.
120 const bool _show_params;
 /// Set to true to show the model's input axis at the beginning.
122 const bool _show_input;
 /// Set to true to show the model's output axis at the beginning.
124 const bool _show_output;
125
 /// NOTE(review): source lines 126-129 (_result_in, _result_out) are missing from this listing.
130
 /// Names for the scalar initial conditions.
132 std::vector<VariableName> _ic_scalar_names;
 /// Values for the scalar initial conditions.
134 std::vector<CrossRef<Scalar>> _ic_scalar_values;
 /// Names for the Rot initial conditions.
136 std::vector<VariableName> _ic_rot_names;
 /// Values for the Rot initial conditions.
138 std::vector<CrossRef<Rot>> _ic_rot_values;
 /// Names for the SR2 initial conditions.
140 std::vector<VariableName> _ic_sr2_names;
 /// Values for the SR2 initial conditions.
142 std::vector<CrossRef<SR2>> _ic_sr2_values;
143
 /// NOTE(review): source lines 144-145 (_cp_elastic_scale) are missing from this listing.
146
147private:
 /// NOTE(review): undocumented in the generated docs. It is presumably the helper
 /// behind output() that writes the results to the file at `out`; confirm against
 /// TransientDriver.cxx.
148 void output_pt(const std::filesystem::path & out) const;
149
 /// Presumably applies a list of named initial conditions of type T.
 /// @param ic_names  Names of the variables receiving initial conditions.
 /// @param ic_values Initial-condition values, presumably matched to ic_names by index.
 /// NOTE(review): only the loop header is visible here. The loop body (source
 /// line 155) is missing from this listing.
 /// NOTE(review): the loop bound uses ic_names.size() only, so ic_values presumably
 /// must be at least as long as ic_names. Confirm a size check exists upstream.
150 template <typename T>
151 void set_IC(const std::vector<VariableName> & ic_names,
152 const std::vector<CrossRef<T>> & ic_values)
153 {
154 for (size_t i = 0; i < ic_names.size(); i++)
156 }
157};
158} // namespace neml2
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:56
The Driver drives the execution of a NEML2 Model.
Definition Driver.h:46
The accessor containing all the information needed to access an item in a LabeledAxis.
Definition LabeledAxisAccessor.h:47
void base_index_put_(indexing::TensorLabelsRef labels, const Tensor &other)
Set values by slicing on the base dimensions.
Definition LabeledTensor.cxx:303
A single-batched, logically 1D LabeledTensor.
Definition LabeledVector.h:38
The base class for all constitutive models.
Definition Model.h:55
A custom map-like data structure. The keys are strings, and the values can be of heterogeneous types.
Definition OptionSet.h:100
The (logical) scalar.
Definition Scalar.h:38
The driver for a transient initial-value problem.
Definition TransientDriver.h:41
virtual bool solve()
Solve the initial value problem.
Definition TransientDriver.cxx:173
const torch::Device _device
The device on which to evaluate the model.
Definition TransientDriver.h:98
virtual void advance_step()
Advance in time: the current state becomes the old state, and the current forces become the old forces.
Definition TransientDriver.cxx:208
const Model & model() const
Definition TransientDriver.h:56
LabeledVector _result_in
Inputs from all time steps.
Definition TransientDriver.h:127
const bool _show_input
Set to true to show the model's input axis at the beginning.
Definition TransientDriver.h:122
Model & _model
The model which the driver uses to perform constitutive updates.
Definition TransientDriver.h:96
const bool _show_output
Set to true to show the model's output axis at the beginning.
Definition TransientDriver.h:124
virtual void diagnose(std::vector< Diagnosis > &) const override
Check for common problems.
Definition TransientDriver.cxx:139
std::vector< VariableName > _ic_rot_names
Names for the Rot initial conditions.
Definition TransientDriver.h:136
TransientDriver(const OptionSet &options)
Construct a new TransientDriver object.
Definition TransientDriver.cxx:104
virtual void update_forces()
Update the driving forces for the current time step.
Definition TransientDriver.cxx:218
Size _nsteps
Total number of steps.
Definition TransientDriver.h:107
LabeledVector _result_out
Outputs from all time steps.
Definition TransientDriver.h:129
bool run() override
Run the driver; return true upon successful completion and false otherwise.
Definition TransientDriver.cxx:150
const bool _show_params
Set to true to list all the model parameters at the beginning.
Definition TransientDriver.h:120
Scalar _time
The current time.
Definition TransientDriver.h:101
virtual std::string save_as_path() const
The destination file/path to save the results.
Definition TransientDriver.cxx:305
virtual void apply_predictor()
Apply the predictor to calculate the initial guess for the current time step.
Definition TransientDriver.cxx:236
virtual void store_input()
Save the input of the current time step.
Definition TransientDriver.cxx:293
std::vector< CrossRef< SR2 > > _ic_sr2_values
Values for the SR2 initial conditions.
Definition TransientDriver.h:142
virtual torch::nn::ModuleDict result() const
The results (input and output) from all time steps.
Definition TransientDriver.cxx:311
std::vector< CrossRef< Scalar > > _ic_scalar_values
Values for the scalar initial conditions.
Definition TransientDriver.h:134
std::vector< CrossRef< Rot > > _ic_rot_values
Values for the Rot initial conditions.
Definition TransientDriver.h:138
Size _step_count
The current step count.
Definition TransientDriver.h:103
std::vector< VariableName > _ic_sr2_names
Names for the SR2 initial conditions.
Definition TransientDriver.h:140
std::string _save_as
The destination file name or file path.
Definition TransientDriver.h:118
LabeledVector & _in
The input to the constitutive model.
Definition TransientDriver.h:111
Size _nbatch
The batch size.
Definition TransientDriver.h:109
std::vector< VariableName > _ic_scalar_names
Names for the scalar initial conditions.
Definition TransientDriver.h:132
LabeledVector & _out
The output of the constitutive model.
Definition TransientDriver.h:113
virtual void output() const
Save the results into the destination file/path.
Definition TransientDriver.cxx:333
const bool _enable_AD
Whether to enable automatic differentiation.
Definition TransientDriver.h:94
VariableName _time_name
Name of the time variable.
Definition TransientDriver.h:105
static OptionSet expected_options()
Definition TransientDriver.cxx:34
std::string _predictor
The predictor used to set the initial guess.
Definition TransientDriver.h:116
virtual void store_output()
Save the output of the current time step.
Definition TransientDriver.cxx:299
virtual void apply_ic()
Apply the initial conditions.
Definition TransientDriver.cxx:228
Real _cp_elastic_scale
Scale value for the initial CP predictor.
Definition TransientDriver.h:145
virtual void solve_step()
Perform the constitutive update for the current time step.
Definition TransientDriver.cxx:287
Definition CrossRef.cxx:30