27#include "neml2/tensors/BatchTensorBase.h"
113BatchTensor
operator*(
const BatchTensor & a,
const BatchTensor & b);
NEML2's enhanced tensor type.
Definition BatchTensorBase.h:46
BatchTensorBase()=default
Default constructor.
Definition BatchTensor.h:32
static BatchTensor zeros(const TorchShapeRef &base_shape, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor filled with zeros given base shape.
Definition BatchTensor.cxx:45
static BatchTensor full(const TorchShapeRef &base_shape, Real init, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor filled with a given value given base shape.
Definition BatchTensor.cxx:75
static BatchTensor identity(TorchSize n, const torch::TensorOptions &options=default_tensor_options())
Unbatched identity tensor.
Definition BatchTensor.cxx:91
static BatchTensor ones(const TorchShapeRef &base_shape, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor filled with ones given base shape.
Definition BatchTensor.cxx:60
static BatchTensor empty(const TorchShapeRef &base_shape, const torch::TensorOptions &options=default_tensor_options())
Unbatched empty tensor given base shape.
Definition BatchTensor.cxx:30
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
BatchTensor bvv(const BatchTensor &a, const BatchTensor &b)
Batched vector-vector (dot) product.
Definition BatchTensor.cxx:137
BatchTensor bmv(const BatchTensor &a, const BatchTensor &v)
Batched matrix-vector product.
Definition BatchTensor.cxx:122
BatchTensor bmm(const BatchTensor &a, const BatchTensor &b)
Batched matrix-matrix product.
Definition BatchTensor.cxx:107
Definition CrossRef.cxx:32
BatchTensor operator*(const BatchTensor &a, const BatchTensor &b)
Definition BatchTensor.cxx:153
const torch::TensorOptions default_tensor_options()
Definition types.cxx:30
int64_t TorchSize
Definition types.h:35
double Real
Definition types.h:33
torch::IntArrayRef TorchShapeRef
Definition types.h:37