NEML2 1.4.0
Loading...
Searching...
No Matches
utils.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/misc/types.h"
28#include "neml2/misc/error.h"
29
30namespace neml2
31{
32
38template <class... T>
39bool broadcastable(T &&... tensors);
40
46template <class... T>
48
56template <class... T>
58
66template <class... T>
68
76template <class... T>
78
86template <class... T>
88
89namespace utils
90{
92std::string demangle(const char * name);
93
95template <class... T>
96bool sizes_same(T &&... shapes);
97
105template <class... T>
106bool sizes_broadcastable(T &&... shapes);
107
111template <class... T>
113
127
128template <typename... S>
130
140
150
151std::string indentation(int level, int indent = 2);
152
153template <typename T>
154std::string stringify(const T & t);
155
156namespace details
157{
158template <typename... S>
159TensorShape add_shapes_impl(TensorShape &, TensorShapeRef, S &&...);
160
161TensorShape add_shapes_impl(TensorShape &);
162} // namespace details
163} // namespace utils
164} // namespace neml2
165
167// Implementation
169
170namespace neml2
171{
172template <class... T>
173bool
175{
176 if (!utils::sizes_same(tensors.base_sizes()...))
177 return false;
178 return utils::sizes_broadcastable(tensors.batch_sizes()...);
179}
180
181template <class... T>
182Size
183broadcast_batch_dim(T &&... tensor)
184{
185 return std::max({tensor.batch_dim()...});
186}
187
/// Assert (in all build modes) that the tensors are broadcastable, reporting every operand's batch
/// and base shapes on failure.
template <class... T>
void
neml_assert_broadcastable(T &&... tensors)
{
  neml_assert(broadcastable(tensors...),
              "The ",
              sizeof...(tensors),
              " operands are not broadcastable. The batch shapes are ",
              tensors.batch_sizes()...,
              ", and the base shapes are ",
              tensors.base_sizes()...);
}
200
/// Debug-only counterpart of neml_assert_broadcastable. When NDEBUG is defined, nothing is
/// evaluated, hence the [[maybe_unused]] on the pack.
template <class... T>
void
neml_assert_broadcastable_dbg([[maybe_unused]] T &&... tensors)
{
#ifndef NDEBUG
  neml_assert_dbg(broadcastable(tensors...),
                  "The ",
                  sizeof...(tensors),
                  " operands are not broadcastable. The batch shapes are ",
                  tensors.batch_sizes()...,
                  ", and the base shapes are ",
                  tensors.base_sizes()...);
#endif
}
215
216template <class... T>
217void
219{
221 "The ",
222 sizeof...(tensors),
223 " operands are not batch-broadcastable. The batch shapes are ",
224 tensors.batch_sizes()...);
225}
226
227template <class... T>
228void
230{
231#ifndef NDEBUG
233 "The ",
234 sizeof...(tensors),
235 " operands are not batch-broadcastable. The batch shapes are ",
236 tensors.batch_sizes()...);
237#endif
238}
239namespace utils
240{
241template <class... T>
242bool
244{
245 auto all_shapes = std::vector<TensorShapeRef>{shapes...};
246 for (size_t i = 0; i < all_shapes.size() - 1; i++)
247 if (all_shapes[i] != all_shapes[i + 1])
248 return false;
249 return true;
250}
251
252template <class... T>
253bool
255{
256 auto dim = std::max({shapes.size()...});
257 auto all_shapes_padded = std::vector<TensorShape>{pad_prepend(shapes, dim)...};
258
259 for (size_t i = 0; i < dim; i++)
260 {
261 Size max_sz = 1;
262 for (const auto & s : all_shapes_padded)
263 {
264 if (max_sz == 1)
265 {
266 neml_assert_dbg(s[i] > 0, "Found a size equal or less than 0.");
267 if (s[i] > 1)
268 max_sz = s[i];
269 }
270 else if (s[i] != 1 && s[i] != max_sz)
271 return false;
272 }
273 }
274
275 return true;
276}
277
278template <class... T>
281{
282 neml_assert_dbg(sizes_broadcastable(shapes...), "Shapes not broadcastable: ", shapes...);
283
284 auto dim = std::max({shapes.size()...});
285 auto all_shapes_padded = std::vector<TensorShape>{pad_prepend(shapes, dim)...};
286 auto bshape = TensorShape(dim, 1);
287
288 for (size_t i = 0; i < dim; i++)
289 for (const auto & s : all_shapes_padded)
290 if (s[i] > bshape[i])
291 bshape[i] = s[i];
292
293 return bshape;
294}
295
296template <typename... S>
299{
301 return details::add_shapes_impl(net, std::forward<S>(shape)...);
302}
303
/// Render @p t as text by streaming it through its operator<<.
template <typename T>
std::string
stringify(const T & t)
{
  std::ostringstream buffer;
  buffer << t;
  std::string text = buffer.str();
  return text;
}
312
313template <>
314inline std::string
315stringify(const bool & t)
316{
317 return t ? "true" : "false";
318}
319
320namespace details
321{
322template <typename... S>
324add_shapes_impl(TensorShape & net, TensorShapeRef s, S &&... rest)
325{
326 net.insert(net.end(), s.begin(), s.end());
327 return add_shapes_impl(net, std::forward<S>(rest)...);
328}
329} // namespace details
330} // namespace utils
331} // namespace neml2
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:56
bool sizes_broadcastable(T &&... shapes)
Check if the shapes are broadcastable.
Definition utils.h:254
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:40
bool sizes_same(T &&... shapes)
Check if all shapes are the same.
Definition utils.h:243
TensorShape pad_prepend(TensorShapeRef s, Size dim, Size pad)
Pad shape s to dimension dim by prepending sizes of pad.
Definition utils.cxx:47
TensorShape broadcast_sizes(T &&... shapes)
Return the broadcast shape of all the shapes.
Definition utils.h:280
TensorShape add_shapes(S &&... shape)
Definition utils.h:298
std::string stringify(const T &t)
Definition utils.h:306
std::string indentation(int level, int indent)
Definition utils.cxx:64
std::string demangle(const char *name)
Demangle a piece of cxx abi type information.
Definition utils.cxx:33
TensorShape pad_append(TensorShapeRef s, Size dim, Size pad)
Pad shape s to dimension dim by appending sizes of pad.
Definition utils.cxx:55
Definition CrossRef.cxx:30
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:76
void neml_assert_broadcastable(T &&...)
A helper function to assert that all tensors are broadcastable.
void neml_assert_batch_broadcastable_dbg(T &&...)
A helper function to assert that (in Debug mode) all tensors are batch-broadcastable.
Size broadcast_batch_dim(T &&...)
The batch dimension after broadcasting.
bool broadcastable(T &&... tensors)
Definition utils.h:174
torch::SmallVector< Size > TensorShape
Definition types.h:34
void neml_assert_batch_broadcastable(T &&...)
A helper function to assert that all tensors are batch-broadcastable.
int64_t Size
Definition types.h:33
void neml_assert_broadcastable_dbg(T &&...)
A helper function to assert (in Debug mode) that all tensors are broadcastable.
torch::IntArrayRef TensorShapeRef
Definition types.h:35
void neml_assert(bool assertion, Args &&... args)
Definition error.h:64