NEML2 1.4.0
Loading...
Searching...
No Matches
utils.h
1// Copyright 2023, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/misc/types.h"
28#include "neml2/misc/error.h"
29
30namespace neml2
31{
37template <class... T>
38bool broadcastable(T &&... tensors);
39
45template <class... T>
47
55template <class... T>
57
65template <class... T>
67
75template <class... T>
77
85template <class... T>
87
88namespace utils
89{
91template <class... T>
92bool sizes_same(T &&... shapes);
93
101template <class... T>
102bool sizes_broadcastable(T &&... shapes);
103
107template <class... T>
109
123
124template <typename... S>
126
136
146
147std::string indentation(int level, int indent = 2);
148
149template <typename T>
150std::string stringify(const T & t);
151
152namespace details
153{
154template <typename... S>
155TorchShape add_shapes_impl(TorchShape &, TorchShapeRef, S &&...);
156
157TorchShape add_shapes_impl(TorchShape &);
158} // namespace details
159} // namespace utils
160} // namespace neml2
161
163// Implementation
165
166namespace neml2
167{
168template <class... T>
169bool
171{
172 if (!utils::sizes_same(tensors.base_sizes()...))
173 return false;
174 return utils::sizes_broadcastable(tensors.batch_sizes()...);
175}
176
177template <class... T>
179broadcast_batch_dim(T &&... tensor)
180{
181 return std::max({tensor.batch_dim()...});
182}
183
184template <class... T>
185void
187{
189 "The ",
190 sizeof...(tensors),
191 " operands are not broadcastable. The batch shapes are ",
192 tensors.batch_sizes()...,
193 ", and the base shapes are ",
194 tensors.base_sizes()...);
195}
196
197template <class... T>
198void
200{
201#ifndef NDEBUG
203 "The ",
204 sizeof...(tensors),
205 " operands are not broadcastable. The batch shapes are ",
206 tensors.batch_sizes()...,
207 ", and the base shapes are ",
208 tensors.base_sizes()...);
209#endif
210}
211
212template <class... T>
213void
215{
217 "The ",
218 sizeof...(tensors),
219 " operands are not batch-broadcastable. The batch shapes are ",
220 tensors.batch_sizes()...);
221}
222
223template <class... T>
224void
226{
227#ifndef NDEBUG
229 "The ",
230 sizeof...(tensors),
231 " operands are not batch-broadcastable. The batch shapes are ",
232 tensors.batch_sizes()...);
233#endif
234}
235namespace utils
236{
237template <class... T>
238bool
240{
241 auto all_shapes = std::vector<TorchShapeRef>{shapes...};
242 for (size_t i = 0; i < all_shapes.size() - 1; i++)
243 if (all_shapes[i] != all_shapes[i + 1])
244 return false;
245 return true;
246}
247
248template <class... T>
249bool
251{
252 auto dim = std::max({shapes.size()...});
253 auto all_shapes_padded = std::vector<TorchShape>{pad_prepend(shapes, dim)...};
254
255 for (size_t i = 0; i < dim; i++)
256 {
257 TorchSize max_sz = 1;
258 for (const auto & s : all_shapes_padded)
259 {
260 if (max_sz == 1)
261 {
262 neml_assert_dbg(s[i] > 0, "Found a size equal or less than 0.");
263 if (s[i] > 1)
264 max_sz = s[i];
265 }
266 else if (s[i] != 1 && s[i] != max_sz)
267 return false;
268 }
269 }
270
271 return true;
272}
273
274template <class... T>
277{
278 neml_assert_dbg(sizes_broadcastable(shapes...), "Shapes not broadcastable: ", shapes...);
279
280 auto dim = std::max({shapes.size()...});
281 auto all_shapes_padded = std::vector<TorchShape>{pad_prepend(shapes, dim)...};
282 auto bshape = TorchShape(dim, 1);
283
284 for (size_t i = 0; i < dim; i++)
285 for (const auto & s : all_shapes_padded)
286 if (s[i] > bshape[i])
287 bshape[i] = s[i];
288
289 return bshape;
290}
291
292template <typename... S>
295{
297 return details::add_shapes_impl(net, std::forward<S>(shape)...);
298}
299
/// Convert any value with a stream insertion operator (operator<<) into a std::string.
template <typename T>
std::string
stringify(const T & t)
{
  std::ostringstream buffer;
  buffer << t;
  std::string text = buffer.str();
  return text;
}

/// Specialization for bool: spell the value out as "true"/"false" rather than the
/// "1"/"0" that a default-formatted output stream would produce.
template <>
inline std::string
stringify(const bool & t)
{
  if (!t)
    return "false";
  return "true";
}
315
316namespace details
317{
318template <typename... S>
320add_shapes_impl(TorchShape & net, TorchShapeRef s, S &&... rest)
321{
322 net.insert(net.end(), s.begin(), s.end());
323 return add_shapes_impl(net, std::forward<S>(rest)...);
324}
325} // namespace details
326} // namespace utils
327} // namespace neml2
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
TorchShape pad_append(TorchShapeRef s, TorchSize dim, TorchSize pad)
Pad shape s to dimension dim by appending sizes equal to pad.
Definition utils.cxx:47
bool sizes_broadcastable(T &&... shapes)
Check if the shapes are broadcastable.
Definition utils.h:250
bool sizes_same(T &&... shapes)
Check if all shapes are the same.
Definition utils.h:239
TorchShape broadcast_sizes(T &&... shapes)
Return the broadcast shape of all the shapes.
Definition utils.h:276
std::string stringify(const T &t)
Definition utils.h:302
TorchSize storage_size(TorchShapeRef shape)
The flattened storage size of a tensor with the given shape.
Definition utils.cxx:32
std::string indentation(int level, int indent)
Definition utils.cxx:56
TorchShape pad_prepend(TorchShapeRef s, TorchSize dim, TorchSize pad)
Pad shape s to dimension dim by prepending sizes equal to pad.
Definition utils.cxx:39
TorchShape add_shapes(S &&... shape)
Definition utils.h:294
Definition CrossRef.cxx:32
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:85
void neml_assert_broadcastable(T &&...)
A helper function to assert that all tensors are broadcastable.
int64_t TorchSize
Definition types.h:35
void neml_assert_batch_broadcastable_dbg(T &&...)
A helper function to assert (in Debug mode) that all tensors are batch-broadcastable.
std::vector< TorchSize > TorchShape
Definition types.h:36
bool broadcastable(T &&... tensors)
Definition utils.h:170
TorchSize broadcast_batch_dim(T &&...)
The batch dimension after broadcasting.
void neml_assert_batch_broadcastable(T &&...)
A helper function to assert that all tensors are batch-broadcastable.
torch::IntArrayRef TorchShapeRef
Definition types.h:37
void neml_assert_broadcastable_dbg(T &&...)
A helper function to assert (in Debug mode) that all tensors are broadcastable.
void neml_assert(bool assertion, Args &&... args)
Definition error.h:73