27#include "neml2/misc/types.h"
28#include "neml2/misc/error.h"
124template <
typename...
S>
154template <
typename...
S>
181 return std::max({tensor.batch_dim()...});
191 " operands are not broadcastable. The batch shapes are ",
193 ", and the base shapes are ",
205 " operands are not broadcastable. The batch shapes are ",
207 ", and the base shapes are ",
219 " operands are not batch-broadcastable. The batch shapes are ",
231 " operands are not batch-broadcastable. The batch shapes are ",
255 for (
size_t i = 0;
i <
dim;
i++)
284 for (
size_t i = 0;
i <
dim;
i++)
292template <
typename...
S>
297 return details::add_shapes_impl(
net, std::forward<S>(
shape)...);
304 std::ostringstream
os;
313 return t ?
"true" :
"false";
318template <
typename...
S>
322 net.insert(
net.end(),
s.begin(),
s.end());
323 return add_shapes_impl(
net, std::forward<S>(
rest)...);
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
TorchShape pad_append(TorchShapeRef s, TorchSize dim, TorchSize pad)
Pad the shape s to dim dimensions by appending entries of size pad.
Definition utils.cxx:47
bool sizes_broadcastable(T &&... shapes)
Check if the shapes are broadcastable.
Definition utils.h:250
bool sizes_same(T &&... shapes)
Check if all shapes are the same.
Definition utils.h:239
TorchShape broadcast_sizes(T &&... shapes)
Return the broadcast shape of all the shapes.
Definition utils.h:276
std::string stringify(const T &t)
Definition utils.h:302
TorchSize storage_size(TorchShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:32
std::string indentation(int level, int indent)
Definition utils.cxx:56
TorchShape pad_prepend(TorchShapeRef s, TorchSize dim, TorchSize pad)
Pad the shape s to dim dimensions by prepending entries of size pad.
Definition utils.cxx:39
TorchShape add_shapes(S &&... shape)
Definition utils.h:294
Definition CrossRef.cxx:32
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:85
void neml_assert_broadcastable(T &&...)
A helper function to assert that all tensors are broadcastable.
int64_t TorchSize
Definition types.h:35
void neml_assert_batch_broadcastable_dbg(T &&...)
A helper function to assert (in Debug mode) that all tensors are batch-broadcastable.
std::vector< TorchSize > TorchShape
Definition types.h:36
bool broadcastable(T &&... tensors)
Definition utils.h:170
TorchSize broadcast_batch_dim(T &&...)
The batch dimension after broadcasting.
void neml_assert_batch_broadcastable(T &&...)
A helper function to assert that all tensors are batch-broadcastable.
torch::IntArrayRef TorchShapeRef
Definition types.h:37
void neml_assert_broadcastable_dbg(T &&...)
A helper function to assert (in Debug mode) that all tensors are broadcastable.
void neml_assert(bool assertion, Args &&... args)
Definition error.h:73