27#include "neml2/tensors/BatchTensor.h"
32template <
typename F,
typename T1,
typename T2>
36 return BatchTensor(
f(a, b.list_unsqueeze()), b.batch_dim());
40template <
typename F,
typename T1,
typename T2>
48template <
typename F,
typename T1,
typename T2>
52 return BatchTensor(
f(a.batch_unsqueeze(-1), b.batch_unsqueeze(-2)), a.batch_dim() - 1)
BatchTensor base_movedim(TorchSize d1, TorchSize d2) const
Move two base dimensions.
Definition BatchTensorBase.cxx:308
Definition BatchTensor.h:32
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
Definition CrossRef.cxx:32
BatchTensor list_derivative_outer_product_ab(F &&f, const T1 &a, const T2 &b)
outer product on lists where both inputs are list tensors
Definition list_tensors.h:50
BatchTensor list_derivative_outer_product_b(F &&f, const T1 &a, const T2 &b)
outer product on lists, where the second input is a list tensor
Definition list_tensors.h:42
BatchTensor list_derivative_outer_product_a(F &&f, const T1 &a, const T2 &b)
outer product on lists, where the first input is a list tensor
Definition list_tensors.h:34