NEML2 1.4.0
Loading...
Searching...
No Matches
TensorValue.h
1// Copyright 2023, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24#pragma once
25
26#include "neml2/tensors/BatchTensor.h"
27#include "neml2/misc/parser_utils.h"
28
29namespace neml2
30{
37{
38public:
39 virtual ~TensorValueBase() = default;
40
42 virtual void to(const torch::TensorOptions &) = 0;
43
45 virtual operator BatchTensor() const = 0;
46
48 virtual void set(const BatchTensor & val) = 0;
49};
50
52template <typename T>
54{
55public:
56 TensorValue() = default;
57
58 TensorValue(const T & value)
59 : _value(value)
60 {
61 }
62
63 virtual void to(const torch::TensorOptions & options) override { _value = _value.to(options); }
64
65 virtual operator BatchTensor() const override { return _value; }
66
67 template <typename T2 = T, typename = typename std::enable_if_t<!std::is_same_v<T2, BatchTensor>>>
68 operator T() const
69 {
70 return _value;
71 }
72
73 T & value() { return _value; }
74
75 virtual void set(const BatchTensor & val) override { _value = val; }
76
77private:
78 T _value;
79};
80} // namespace neml2
Definition BatchTensor.h:32
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
The base class to allow us to set up a polymorphic container of BatchTensors. The concrete definition...
Definition TensorValue.h:37
virtual void set(const BatchTensor &val)=0
Set the parameter value.
virtual void to(const torch::TensorOptions &)=0
Send the value to the target options.
virtual ~TensorValueBase()=default
Concrete definition of tensor value.
Definition TensorValue.h:54
virtual void set(const BatchTensor &val) override
Set the parameter value.
Definition TensorValue.h:75
TensorValue()=default
TensorValue(const T &value)
Definition TensorValue.h:58
T & value()
Definition TensorValue.h:73
virtual void to(const torch::TensorOptions &options) override
Send the value to the target options.
Definition TensorValue.h:63
Definition CrossRef.cxx:32