NEML2
1.4.0
Loading...
Searching...
No Matches
CrossRef.cxx
1
// Copyright 2023, UChicago Argonne, LLC
2
// All Rights Reserved
3
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4
// By: Argonne National Laboratory
5
// OPEN SOURCE LICENSE (MIT)
6
//
7
// Permission is hereby granted, free of charge, to any person obtaining a copy
8
// of this software and associated documentation files (the "Software"), to deal
9
// in the Software without restriction, including without limitation the rights
10
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
// copies of the Software, and to permit persons to whom the Software is
12
// furnished to do so, subject to the following conditions:
13
//
14
// The above copyright notice and this permission notice shall be included in
15
// all copies or substantial portions of the Software.
16
//
17
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23
// THE SOFTWARE.
24
25
#include "neml2/base/CrossRef.h"
26
#include "neml2/base/Factory.h"
27
#include "neml2/misc/parser_utils.h"
28
#include "neml2/tensors/tensors.h"
29
#include "neml2/tensors/macros.h"
30
31
namespace
neml2
32
{
33
template
<>
34
CrossRef<torch::Tensor>::operator
torch::Tensor()
const
35
{
36
try
37
{
38
// If it is just a number, we can still create a tensor out of it
39
return
torch::tensor(utils::parse<Real>(_raw_str),
default_tensor_options
());
40
}
41
catch
(
const
ParserException
&
e
)
42
{
43
// Conversion to a number failed, so it might be the name of another tensor
44
return
Factory::get_object<torch::Tensor>(
"Tensors"
, _raw_str);
45
}
46
}
47
48
template
<>
49
CrossRef<BatchTensor>::operator
BatchTensor
()
const
50
{
51
try
52
{
53
// If it is just a number, we can still create a Scalar out of it
54
return
BatchTensor::full
({}, {}, utils::parse<Real>(_raw_str),
default_tensor_options
());
55
}
56
catch
(
const
ParserException
&
e
)
57
{
58
// Conversion to a number failed, so it might be the name of another BatchTensor
59
return
Factory::get_object<BatchTensor>(
"Tensors"
, _raw_str);
60
}
61
}
62
63
// Explicit instantiation of the CrossRef class template for torch::Tensor.
// It is placed after the explicit specialization of the conversion operator
// for the same type, which appears earlier in this file.
template
class
CrossRef<torch::Tensor>
;
64
#define CROSSREF_SPECIALIZE_FIXEDDIMTENSOR(T) \
65
template <> \
66
CrossRef<T>::operator T() const \
67
{ \
68
try \
69
{ \
70
return T::full(utils::parse<Real>(_raw_str)); \
71
} \
72
catch (const ParserException & e) \
73
{ \
74
return Factory::get_object<T>("Tensors", _raw_str); \
75
} \
76
} \
77
static_assert(true)
78
79
FOR_ALL_FIXEDDIMTENSOR(CROSSREF_SPECIALIZE_FIXEDDIMTENSOR);
80
}
// namespace neml2
neml2::BatchTensor
Definition
BatchTensor.h:32
neml2::BatchTensor::full
static BatchTensor full(const TorchShapeRef &base_shape, Real init, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor filled with a given value given base shape.
Definition
BatchTensor.cxx:75
neml2::CrossRef
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition
CrossRef.h:52
neml2::ParserException
Definition
parser_utils.h:35
neml2
Definition
CrossRef.cxx:32
neml2::default_tensor_options
const torch::TensorOptions default_tensor_options()
Definition
types.cxx:30
src
neml2
base
CrossRef.cxx
Generated by
1.10.0