27#include "neml2/tensors/PrimitiveTensor.h"
55 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
56 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>,
Derived>>>
62 net.insert(
net.end(), a.base_dim(), torch::indexing::None);
66template <
class Derived,
67 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
68 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
75template <
class Derived,
76 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
77 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
83 net.insert(
net.end(), a.base_dim(), torch::indexing::None);
87template <
class Derived,
88 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
89 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
96template <
class Derived,
97 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
98 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
104 net.insert(
net.end(), a.base_dim(), torch::indexing::None);
108template <
class Derived,
109 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
110 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
117Scalar
operator*(
const Scalar & a,
const Scalar & b);
119template <
class Derived,
120 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
121 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
127 net.insert(
net.end(), a.base_dim(), torch::indexing::None);
131template <
class Derived,
132 typename =
typename std::enable_if_t<!std::is_same_v<Derived, Scalar>>,
133 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
139 net.insert(
net.end(), b.base_dim(), torch::indexing::None);
145template <
class Derived,
146 typename =
typename std::enable_if_t<std::is_base_of_v<TensorBase<Derived>, Derived>>>
152 net.insert(
net.end(), a.base_dim(), torch::indexing::None);
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:56
PrimitiveTensor inherits from TensorBase and additionally templates on the base shape.
Definition PrimitiveTensor.h:38
PrimitiveTensor()=default
Default constructor.
The (logical) scalar.
Definition Scalar.h:38
Scalar(Real init, const torch::TensorOptions &options)
Definition Scalar.cxx:29
static Scalar identity_map(const torch::TensorOptions &options=default_tensor_options())
The derivative of a Scalar with respect to itself.
Definition Scalar.cxx:35
torch::SmallVector< TensorIndex > TensorIndices
Definition types.h:41
Tensor pow(const Real &a, const Tensor &n)
Definition math.cxx:309
Definition CrossRef.cxx:30
Vec operator*(const Derived1 &A, const Derived2 &b)
matrix-vector product
Definition R2Base.cxx:233
void neml_assert_batch_broadcastable_dbg(T &&...)
A helper function to assert that (in Debug mode) all tensors are batch-broadcastable.
Derived operator-(const Derived &a, const Scalar &b)
Definition Scalar.h:79
Derived operator+(const Derived &a, const Scalar &b)
Definition Scalar.h:58
Size broadcast_batch_dim(T &&...)
The batch dimension after broadcasting.
torch::TensorOptions & default_tensor_options()
Definition types.cxx:30
double Real
Definition types.h:31
Derived operator/(const Derived &a, const Scalar &b)
Definition Scalar.h:123
Scalar abs(const Scalar &a)
Absolute value.
Definition Scalar.cxx:48