NEML2 1.4.0
Loading...
Searching...
No Matches
PrimitiveTensor.h
// Copyright 2024, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/tensors/Tensor.h"
28
29namespace neml2
30{
36template <class Derived, Size... S>
37class PrimitiveTensor : public TensorBase<Derived>
38{
39public:
41 static inline const TensorShape const_base_sizes = {S...};
42
44 static constexpr Size const_base_dim = sizeof...(S);
45
47 static inline const Size const_base_storage = utils::storage_size({S...});
48
50 PrimitiveTensor() = default;
51
53 explicit PrimitiveTensor(const torch::Tensor & tensor, Size batch_dim);
54
56 PrimitiveTensor(const torch::Tensor & tensor);
57
59 operator Tensor() const;
60
62 [[nodiscard]] static Derived
63 empty(const torch::TensorOptions & options = default_tensor_options());
65 [[nodiscard]] static Derived
67 const torch::TensorOptions & options = default_tensor_options());
69 [[nodiscard]] static Derived
70 zeros(const torch::TensorOptions & options = default_tensor_options());
72 [[nodiscard]] static Derived
74 const torch::TensorOptions & options = default_tensor_options());
76 [[nodiscard]] static Derived
77 ones(const torch::TensorOptions & options = default_tensor_options());
79 [[nodiscard]] static Derived
80 ones(TensorShapeRef batch_shape, const torch::TensorOptions & options = default_tensor_options());
82 [[nodiscard]] static Derived
83 full(Real init, const torch::TensorOptions & options = default_tensor_options());
85 [[nodiscard]] static Derived
87 Real init,
88 const torch::TensorOptions & options = default_tensor_options());
89
91 [[nodiscard]] static Tensor identity_map(const torch::TensorOptions &)
92 {
93 throw NEMLException("Not implemented");
94 }
95};
96
// Implementations
100
101template <class Derived, Size... S>
102PrimitiveTensor<Derived, S...>::PrimitiveTensor(const torch::Tensor & tensor, Size batch_dim)
103 : TensorBase<Derived>(tensor, batch_dim)
104{
105 neml_assert_dbg(this->base_sizes() == const_base_sizes,
106 "Base shape mismatch: trying to create a tensor with base shape ",
108 " from a tensor with base shape ",
109 this->base_sizes());
110}
111
112template <class Derived, Size... S>
114 : TensorBase<Derived>(tensor, tensor.dim() - const_base_dim)
115{
116 neml_assert_dbg(this->base_sizes() == const_base_sizes,
117 "Base shape mismatch: trying to create a tensor with base shape ",
119 " from a tensor with shape ",
120 tensor.sizes());
121}
122
123template <class Derived, Size... S>
124PrimitiveTensor<Derived, S...>::operator Tensor() const
125{
126 return Tensor(*this, this->batch_dim());
127}
128
129template <class Derived, Size... S>
131PrimitiveTensor<Derived, S...>::empty(const torch::TensorOptions & options)
132{
133 return Derived(torch::empty(const_base_sizes, options), 0);
134}
135
136template <class Derived, Size... S>
139 const torch::TensorOptions & options)
140{
141 return Derived(torch::empty(utils::add_shapes(batch_shape, const_base_sizes), options),
142 batch_shape.size());
143}
144
145template <class Derived, Size... S>
147PrimitiveTensor<Derived, S...>::zeros(const torch::TensorOptions & options)
148{
149 return Derived(torch::zeros(const_base_sizes, options), 0);
150}
151
152template <class Derived, Size... S>
155 const torch::TensorOptions & options)
156{
157 return Derived(torch::zeros(utils::add_shapes(batch_shape, const_base_sizes), options),
158 batch_shape.size());
159}
160
161template <class Derived, Size... S>
163PrimitiveTensor<Derived, S...>::ones(const torch::TensorOptions & options)
164{
165 return Derived(torch::ones(const_base_sizes, options), 0);
166}
167
168template <class Derived, Size... S>
171 const torch::TensorOptions & options)
172{
173 return Derived(torch::ones(utils::add_shapes(batch_shape, const_base_sizes), options),
174 batch_shape.size());
175}
176
177template <class Derived, Size... S>
179PrimitiveTensor<Derived, S...>::full(Real init, const torch::TensorOptions & options)
180{
181 return Derived(torch::full(const_base_sizes, init, options), 0);
182}
183
184template <class Derived, Size... S>
187 Real init,
188 const torch::TensorOptions & options)
189{
190 return Derived(torch::full(utils::add_shapes(batch_shape, const_base_sizes), init, options),
191 batch_shape.size());
192}
193} // namespace neml2
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:56
CrossRef()=default
Definition error.h:33
PrimitiveTensor inherits from TensorBase and additionally templates on the base shape.
Definition PrimitiveTensor.h:38
static Derived full(Real init, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor with the base shape, filled with a given value.
Definition PrimitiveTensor.h:179
static Derived empty(TensorShapeRef batch_shape, const torch::TensorOptions &options=default_tensor_options())
Empty tensor given batch shape.
Definition PrimitiveTensor.h:138
PrimitiveTensor(const torch::Tensor &tensor)
Construct from another torch::Tensor and infer batch dimension.
Definition PrimitiveTensor.h:113
PrimitiveTensor(const torch::Tensor &tensor, Size batch_dim)
Construct from another torch::Tensor given batch dimension.
Definition PrimitiveTensor.h:102
static Derived empty(const torch::TensorOptions &options=default_tensor_options())
Unbatched empty tensor.
Definition PrimitiveTensor.h:131
static Derived full(TensorShapeRef batch_shape, Real init, const torch::TensorOptions &options=default_tensor_options())
Full tensor given batch shape.
Definition PrimitiveTensor.h:186
static const TensorShape const_base_sizes
The base shape.
Definition PrimitiveTensor.h:41
static Derived ones(TensorShapeRef batch_shape, const torch::TensorOptions &options=default_tensor_options())
Unit tensor given batch shape.
Definition PrimitiveTensor.h:170
static Derived zeros(const torch::TensorOptions &options=default_tensor_options())
Unbatched zero tensor.
Definition PrimitiveTensor.h:147
static constexpr Size const_base_dim
The base dim.
Definition PrimitiveTensor.h:44
static Derived ones(const torch::TensorOptions &options=default_tensor_options())
Unbatched unit tensor.
Definition PrimitiveTensor.h:163
static const Size const_base_storage
The base storage.
Definition PrimitiveTensor.h:47
static Derived zeros(TensorShapeRef batch_shape, const torch::TensorOptions &options=default_tensor_options())
Zero tensor given batch shape.
Definition PrimitiveTensor.h:154
static Tensor identity_map(const torch::TensorOptions &)
Derived tensor classes should define identity_map where appropriate.
Definition PrimitiveTensor.h:91
PrimitiveTensor()=default
Default constructor.
NEML2's enhanced tensor type.
Definition TensorBase.h:46
Size batch_dim() const
Return the number of batch dimensions.
Definition TensorBase.cxx:142
TensorShapeRef base_sizes() const
Return the base size.
Definition TensorBase.cxx:170
Definition Tensor.h:32
Size storage_size(TensorShapeRef shape)
The flattened storage size of a tensor with given shape.
Definition utils.cxx:40
TensorShape add_shapes(S &&... shape)
Definition utils.h:298
Definition CrossRef.cxx:30
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:76
torch::TensorOptions & default_tensor_options()
Definition types.cxx:30
double Real
Definition types.h:31
torch::SmallVector< Size > TensorShape
Definition types.h:34
int64_t Size
Definition types.h:33
torch::IntArrayRef TensorShapeRef
Definition types.h:35