NEML2 1.4.0
Loading...
Searching...
No Matches
FixedDimTensor.h
1// Copyright 2023, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/tensors/BatchTensor.h"
28
29namespace neml2
30{
36template <class Derived, TorchSize... S>
37class FixedDimTensor : public BatchTensorBase<Derived>
38{
39public:
41 static inline const TorchShape const_base_sizes = {S...};
42
44 static constexpr TorchSize const_base_dim = sizeof...(S);
45
48
50 FixedDimTensor() = default;
51
53 explicit FixedDimTensor(const torch::Tensor & tensor, TorchSize batch_dim);
54
56 FixedDimTensor(const torch::Tensor & tensor);
57
59 operator BatchTensor() const;
60
62 [[nodiscard]] static Derived
63 empty(const torch::TensorOptions & options = default_tensor_options());
65 [[nodiscard]] static Derived
66 empty(TorchShapeRef batch_shape, const torch::TensorOptions & options = default_tensor_options());
68 [[nodiscard]] static Derived
69 zeros(const torch::TensorOptions & options = default_tensor_options());
71 [[nodiscard]] static Derived
72 zeros(TorchShapeRef batch_shape, const torch::TensorOptions & options = default_tensor_options());
74 [[nodiscard]] static Derived
75 ones(const torch::TensorOptions & options = default_tensor_options());
77 [[nodiscard]] static Derived
78 ones(TorchShapeRef batch_shape, const torch::TensorOptions & options = default_tensor_options());
80 [[nodiscard]] static Derived
81 full(Real init, const torch::TensorOptions & options = default_tensor_options());
83 [[nodiscard]] static Derived
85 Real init,
86 const torch::TensorOptions & options = default_tensor_options());
87
89 [[nodiscard]] static BatchTensor identity_map(const torch::TensorOptions &)
90 {
91 throw NEMLException("Not implemented");
92 }
93};
94
96// Implementations
98
99template <class Derived, TorchSize... S>
100FixedDimTensor<Derived, S...>::FixedDimTensor(const torch::Tensor & tensor, TorchSize batch_dim)
101 : BatchTensorBase<Derived>(tensor, batch_dim)
102{
103 neml_assert_dbg(this->base_sizes() == const_base_sizes,
104 "Base shape mismatch: trying to create a tensor with base shape ",
106 " from a tensor with base shape ",
107 this->base_sizes());
108}
109
110template <class Derived, TorchSize... S>
112 : BatchTensorBase<Derived>(tensor, tensor.dim() - const_base_dim)
113{
114 neml_assert_dbg(this->base_sizes() == const_base_sizes,
115 "Base shape mismatch: trying to create a tensor with base shape ",
117 " from a tensor with shape ",
118 tensor.sizes());
119}
120
121template <class Derived, TorchSize... S>
122FixedDimTensor<Derived, S...>::operator BatchTensor() const
123{
124 return BatchTensor(*this, this->batch_dim());
125}
126
127template <class Derived, TorchSize... S>
129FixedDimTensor<Derived, S...>::empty(const torch::TensorOptions & options)
130{
131 return Derived(torch::empty(const_base_sizes, options), 0);
132}
133
134template <class Derived, TorchSize... S>
137 const torch::TensorOptions & options)
138{
139 return Derived(torch::empty(utils::add_shapes(batch_shape, const_base_sizes), options),
140 batch_shape.size());
141}
142
143template <class Derived, TorchSize... S>
145FixedDimTensor<Derived, S...>::zeros(const torch::TensorOptions & options)
146{
147 return Derived(torch::zeros(const_base_sizes, options), 0);
148}
149
150template <class Derived, TorchSize... S>
153 const torch::TensorOptions & options)
154{
155 return Derived(torch::zeros(utils::add_shapes(batch_shape, const_base_sizes), options),
156 batch_shape.size());
157}
158
159template <class Derived, TorchSize... S>
161FixedDimTensor<Derived, S...>::ones(const torch::TensorOptions & options)
162{
163 return Derived(torch::ones(const_base_sizes, options), 0);
164}
165
166template <class Derived, TorchSize... S>
169{
170 return Derived(torch::ones(utils::add_shapes(batch_shape, const_base_sizes), options),
171 batch_shape.size());
172}
173
174template <class Derived, TorchSize... S>
176FixedDimTensor<Derived, S...>::full(Real init, const torch::TensorOptions & options)
177{
178 return Derived(torch::full(const_base_sizes, init, options), 0);
179}
180
181template <class Derived, TorchSize... S>
184 Real init,
185 const torch::TensorOptions & options)
186{
187 return Derived(torch::full(utils::add_shapes(batch_shape, const_base_sizes), init, options),
188 batch_shape.size());
189}
190} // namespace neml2
NEML2's enhanced tensor type.
Definition BatchTensorBase.h:46
TorchSize batch_dim() const
Return the number of batch dimensions.
Definition BatchTensorBase.cxx:128
TorchShapeRef base_sizes() const
Return the base size.
Definition BatchTensorBase.cxx:163
Definition BatchTensor.h:32
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
CrossRef()=default
FixedDimTensor inherits from BatchTensorBase and additionally templates on the base shape.
Definition FixedDimTensor.h:38
static Derived full(Real init, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor of the base shape filled with a given value.
Definition FixedDimTensor.h:176
static Derived full(TorchShapeRef batch_shape, Real init, const torch::TensorOptions &options=default_tensor_options())
Tensor filled with a given value, given batch shape.
Definition FixedDimTensor.h:183
static BatchTensor identity_map(const torch::TensorOptions &)
Derived tensor classes should define identity_map where appropriate.
Definition FixedDimTensor.h:89
static constexpr TorchSize const_base_dim
The base dim.
Definition FixedDimTensor.h:44
static Derived ones(TorchShapeRef batch_shape, const torch::TensorOptions &options=default_tensor_options())
Tensor filled with ones, given batch shape.
Definition FixedDimTensor.h:168
FixedDimTensor()=default
Default constructor.
static const TorchSize const_base_storage
The base storage.
Definition FixedDimTensor.h:47
static Derived empty(const torch::TensorOptions &options=default_tensor_options())
Unbatched empty tensor.
Definition FixedDimTensor.h:129
FixedDimTensor(const torch::Tensor &tensor)
Construct from another torch::Tensor and infer batch dimension.
Definition FixedDimTensor.h:111
FixedDimTensor(const torch::Tensor &tensor, TorchSize batch_dim)
Construct from another torch::Tensor given batch dimension.
Definition FixedDimTensor.h:100
static Derived zeros(const torch::TensorOptions &options=default_tensor_options())
Unbatched zero tensor.
Definition FixedDimTensor.h:145
static Derived ones(const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor filled with ones.
Definition FixedDimTensor.h:161
static Derived zeros(TorchShapeRef batch_shape, const torch::TensorOptions &options=default_tensor_options())
Zero tensor given batch shape.
Definition FixedDimTensor.h:152
static const TorchShape const_base_sizes
The base shape.
Definition FixedDimTensor.h:41
static Derived empty(TorchShapeRef batch_shape, const torch::TensorOptions &options=default_tensor_options())
Empty tensor given batch shape.
Definition FixedDimTensor.h:136
Definition error.h:33
TorchSize storage_size(TorchShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:32
TorchShape add_shapes(S &&... shape)
Definition utils.h:294
Definition CrossRef.cxx:32
const torch::TensorOptions default_tensor_options()
Definition types.cxx:30
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:85
int64_t TorchSize
Definition types.h:33
std::vector< TorchSize > TorchShape
Definition types.h:34
double Real
Definition types.h:31
torch::IntArrayRef TorchShapeRef
Definition types.h:35