27#include "neml2/tensors/FixedDimTensor.h"
56 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
57 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>,
Derived>>>
63 net.insert(
net.end(), a.base_dim(), torch::indexing::None);
69 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
70 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
79 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
80 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
86 net.insert(
net.end(), a.base_dim(), torch::indexing::None);
92 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
93 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
102 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
103 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
109 net.insert(
net.end(), a.base_dim(), torch::indexing::None);
115 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
116 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
123Scalar
operator*(
const Scalar & a,
const Scalar & b);
127 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
128 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
134 net.insert(
net.end(), a.base_dim(), torch::indexing::None);
140 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
141 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
147 net.insert(
net.end(), b.base_dim(), torch::indexing::None);
155 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
156 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>, Derived>>>
162 net.insert(
net.end(), a.base_dim(), torch::indexing::None);
168 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
169 typename =
typename std::enable_if_t<std::is_base_of_v<BatchTensorBase<Derived>,
Derived>>>
175 net.insert(
net.end(),
n.base_dim(), torch::indexing::None);
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
FixedDimTensor inherits from BatchTensorBase and additionally templates on the base shape.
Definition FixedDimTensor.h:38
FixedDimTensor()=default
Default constructor.
The (logical) scalar.
Definition Scalar.h:38
Scalar(Real init, const torch::TensorOptions &options)
Definition Scalar.cxx:29
static Scalar identity_map(const torch::TensorOptions &options=default_tensor_options())
The derivative of a Scalar with respect to itself.
Definition Scalar.cxx:35
Derived pow(const Derived &a, const Real &n)
Definition BatchTensorBase.h:332
Definition CrossRef.cxx:32
Derived operator-(const Derived &a, const Real &b)
Definition BatchTensorBase.h:256
BatchTensor operator*(const BatchTensor &a, const BatchTensor &b)
Definition BatchTensor.cxx:153
const torch::TensorOptions default_tensor_options()
Definition types.cxx:30
void neml_assert_batch_broadcastable_dbg(T &&...)
A helper function to assert that (in Debug mode) all tensors are batch-broadcastable.
TorchSize broadcast_batch_dim(T &&...)
The batch dimension after broadcasting.
double Real
Definition types.h:33
Derived operator+(const Derived &a, const Real &b)
Definition BatchTensorBase.h:228
std::vector< at::indexing::TensorIndex > TorchSlice
Definition types.h:39
Derived operator/(const Derived &a, const Real &b)
Definition BatchTensorBase.h:302
Scalar abs(const Scalar &a)
Absolute value.
Definition Scalar.cxx:48