25#include "neml2/tensors/BatchTensor.h"
38 const torch::TensorOptions & options)
53 const torch::TensorOptions & options)
68 const torch::TensorOptions & options)
84 const torch::TensorOptions & options)
99 const torch::TensorOptions & options)
111 "The first tensor in bmm has base dimension ",
115 "The second tensor in bmm has base dimension ",
126 "The first tensor in bmv has base dimension ",
130 "The second tensor in bmv has base dimension ",
141 "The first tensor in bvv has base dimension ",
145 "The second tensor in bvv has base dimension ",
Definition BatchTensor.h:32
static BatchTensor zeros(const TorchShapeRef &base_shape, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor filled with zeros given base shape.
Definition BatchTensor.cxx:45
static BatchTensor full(const TorchShapeRef &base_shape, Real init, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor filled with a given value given base shape.
Definition BatchTensor.cxx:75
static BatchTensor identity(TorchSize n, const torch::TensorOptions &options=default_tensor_options())
Unbatched identity tensor.
Definition BatchTensor.cxx:91
static BatchTensor ones(const TorchShapeRef &base_shape, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor filled with ones given base shape.
Definition BatchTensor.cxx:60
static BatchTensor empty(const TorchShapeRef &base_shape, const torch::TensorOptions &options=default_tensor_options())
Unbatched empty tensor given base shape.
Definition BatchTensor.cxx:30
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
BatchTensor bvv(const BatchTensor &a, const BatchTensor &b)
Batched vector-vector (dot) product.
Definition BatchTensor.cxx:137
BatchTensor bmv(const BatchTensor &a, const BatchTensor &v)
Batched matrix-vector product.
Definition BatchTensor.cxx:122
BatchTensor bmm(const BatchTensor &a, const BatchTensor &b)
Batched matrix-matrix product.
Definition BatchTensor.cxx:107
TorchShape add_shapes(S &&... shape)
Definition utils.h:294
Definition CrossRef.cxx:32
BatchTensor operator*(const BatchTensor &a, const BatchTensor &b)
Definition BatchTensor.cxx:153
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:85
int64_t TorchSize
Definition types.h:35
void neml_assert_batch_broadcastable_dbg(T &&...)
A helper function to assert that (in Debug mode) all tensors are batch-broadcastable.
TorchSize broadcast_batch_dim(T &&...)
The batch dimension after broadcasting.
double Real
Definition types.h:33
torch::IntArrayRef TorchShapeRef
Definition types.h:37
void neml_assert_broadcastable_dbg(T &&...)
A helper function to assert (in Debug mode) that all tensors are broadcastable.