NEML2 1.4.0
Loading...
Searching...
No Matches
LabeledTensor.h
1// Copyright 2023, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/misc/types.h"
28#include "neml2/tensors/LabeledAxis.h"
29#include "neml2/tensors/BatchTensor.h"
30
31namespace neml2
32{
/// The primary data structure in NEML2 for working with labeled tensor views.
///
/// `Derived` is the concrete tensor type returned by the factory and slicing members; `D` is the
/// number of labeled axes (the variadic accessors static_assert that exactly `D` names are given).
/// NOTE(review): this listing omits the class-head line and several member declarations (gaps in
/// the embedded line numbers); only what is visible is documented here.
42template <class Derived, TorchSize D>
44{
45public:
/// Default constructor.
47 LabeledTensor() = default;
48
/// Construct from a torch::Tensor and the labeled axes.
/// NOTE(review): a parameter line between `tensor` and `axes` is missing from this listing.
50 LabeledTensor(const torch::Tensor & tensor,
52 const std::vector<const LabeledAxis *> & axes);
53
/// Construct from a BatchTensor and the labeled axes.
55 LabeledTensor(const BatchTensor & tensor, const std::vector<const LabeledAxis *> & axes);
56
59
/// Assignment operator. Returns void, so assignments cannot be chained.
61 void operator=(const Derived & other);
62
/// Implicit conversion to BatchTensor.
64 // Should we mark it explicit?
65 operator BatchTensor() const;
66
/// Implicit conversion to torch::Tensor.
68 // Should we mark it explicit?
69 operator torch::Tensor() const;
70
/// Setup new empty storage.
/// NOTE(review): the line carrying the function name and first parameter is missing here; the
/// member index of this listing gives it as `empty(TorchShapeRef batch_shape, ...)`.
72 [[nodiscard]] static Derived
74 const std::vector<const LabeledAxis *> & axes,
75 const torch::TensorOptions & options = default_tensor_options());
76
/// Setup new empty storage like another LabeledTensor.
78 [[nodiscard]] static Derived empty_like(const Derived & other);
79
/// Setup new storage with zeros.
/// NOTE(review): the line carrying the function name and first parameter is missing here; the
/// member index of this listing gives it as `zeros(TorchShapeRef batch_shape, ...)`.
81 [[nodiscard]] static Derived
83 const std::vector<const LabeledAxis *> & axes,
84 const torch::TensorOptions & options = default_tensor_options());
85
/// Setup new storage with zeros like another LabeledTensor.
87 [[nodiscard]] static Derived zeros_like(const Derived & other);
88
/// Clone this LabeledTensor.
90 Derived clone(torch::MemoryFormat memory_format = torch::MemoryFormat::Contiguous) const;
91
/// Copy the value from another tensor.
93 template <typename T>
94 void copy_(const T & other);
95
/// Return a copy without gradient graphs.
97 Derived detach() const;
98
/// Detach from gradient graphs.
100 void detach_();
101
/// Zero out this tensor.
103 void zero_();
104
/// Read-only and mutable access to the underlying BatchTensor.
107 const BatchTensor & tensor() const { return _tensor; }
108 BatchTensor & tensor() { return _tensor; }
110
/// Get the tensor options (forwarded from the underlying tensor).
112 torch::TensorOptions options() const { return _tensor.options(); }
113
/// Return the number of batch dimensions.
115 TorchSize batch_dim() const;
116
/// Return the number of base dimensions.
118 TorchSize base_dim() const;
119
122
125
/// Get all the labeled axes.
127 const std::vector<const LabeledAxis *> & axes() const { return _axes; }
128
/// Get a specific labeled axis. `i` is not range-checked before dereferencing.
130 const LabeledAxis & axis(TorchSize i = 0) const { return *_axes[i]; }
131
133 template <typename... S>
135
138
140 template <typename... S>
142
145 template <typename... S>
147
/// Slice the tensor on the given dimension by a single variable or sub-axis.
149 Derived slice(TorchSize i, const std::string & name) const;
150
/// Get the sub-block labeled by the given sub-axis names.
152 template <typename... S>
153 Derived block(S &&... names) const;
154
/// Get a batch.
156 Derived batch_index(TorchSlice indices) const;
157
/// Set an index sliced on the batch dimensions to a value.
159 void batch_index_put(TorchSlice indices, const torch::Tensor & other);
160
/// Index the underlying tensor with `indices` and return the result as a BatchTensor.
/// NOTE(review): the name suggests this indexes base dimensions -- confirm against the
/// implementation.
162 BatchTensor base_index(TorchSlice indices) const;
163
/// Write `other` into the entries selected by `indices`.
/// NOTE(review): the name suggests this indexes base dimensions -- confirm against the
/// implementation.
165 void base_index_put(TorchSlice indices, const torch::Tensor & other);
166
/// Template setup for appropriate variable types: the nested `type` is simply `T`.
/// NOTE(review): the struct-head line is missing from this listing.
168 template <typename T>
170 {
171 typedef T type;
172 };
173
/// Get and interpret the view as an object.
/// The sub-block selected by `names` is viewed with shape (batch sizes + T::const_base_sizes) and
/// wrapped in a `T` that has the same number of batch dimensions as this tensor.
175 template <typename T, typename... S>
176 typename variable_type<T>::type get(S &&... names) const
177 {
178 return T((*this)(names...).view(utils::add_shapes(batch_sizes(), T::const_base_sizes)),
179 batch_dim());
180 }
181
/// Get and interpret the view as a list of objects.
/// The sub-block selected by `names` is reshaped to (batch sizes, -1, T::const_base_sizes); the
/// resulting `T` is given `sizeof...(names)` extra batch dimensions.
/// NOTE(review): the signature line is missing from this listing.
183 template <typename T, typename... S>
185 {
186 return T(((*this)(names...))
187 .reshape(utils::add_shapes(this->batch_sizes(), -1, T::const_base_sizes)),
188 this->batch_dim() + sizeof...(names));
189 }
190
/// Set and interpret the input as an object.
/// `value` is reshaped to (its batch sizes + the storage size of the named sub-block) and written
/// into the sub-block selected by `names`.
192 template <typename T, typename... S>
193 void set(const BatchTensorBase<T> & value, S &&... names)
194 {
195 (*this)(names...).index_put_(
196 {torch::indexing::None},
197 value.reshape(utils::add_shapes(value.batch_sizes(), storage_size(names...))));
198 }
199
/// Set and interpret the input as a list of objects.
/// Re-wraps `value` as a BatchTensor with `sizeof...(names)` fewer batch dimensions and forwards
/// to set().
201 template <typename T, typename... S>
202 void set_list(const BatchTensorBase<T> & value, S &&... names)
203 {
204 this->set(BatchTensor(value, value.batch_dim() - sizeof...(names)), names...);
205 }
206
/// Negation.
208 Derived operator-() const;
209
/// Change tensor options.
211 Derived to(const torch::TensorOptions & options) const;
212
213protected:
216
/// The labeled axes of this tensor (non-owning pointers).
218 // Urgh, I can't use const references here as the elements of a vector have to be "assignable".
219 std::vector<const LabeledAxis *> _axes;
220
221private:
/// Helper for slice_indices(): expands the index sequence `I...` in lockstep with `names...`.
222 template <std::size_t... I, typename... S>
223 TorchSlice slice_indices_impl(std::index_sequence<I...>, S &&... names) const;
224
/// Helper for storage_size(names...): expands the index sequence `I...` in lockstep with
/// `names...`.
225 template <std::size_t... I, typename... S>
226 TorchShape storage_size_impl(std::index_sequence<I...>, S &&... names) const;
227
/// Helper for block(): expands the index sequence `I...` in lockstep with `names...`.
228 template <std::size_t... I, typename... S>
229 Derived block_impl(std::index_sequence<I...>, S &&... names) const;
230};
231
// Copy the value from another tensor by forwarding `other` to the underlying tensor's in-place
// copy_().
// NOTE(review): the signature line (header line 235) is missing from this listing.
232template <class Derived, TorchSize D>
233template <typename T>
234void
236{
237 _tensor.copy_(other);
238}
239
// How to slice the tensor given the names on each axis.
// Exactly D names must be supplied (checked at compile time); the i-th name is resolved against
// the i-th labeled axis by slice_indices_impl().
// NOTE(review): the return-type and signature lines (header lines 242-243) are missing from this
// listing.
240template <class Derived, TorchSize D>
241template <typename... S>
244{
245 static_assert(sizeof...(names) == D, "Wrong labeled dimension in LabeledTensor::slice_indices");
246 return slice_indices_impl(std::make_index_sequence<sizeof...(names)>(),
247 std::forward<S>(names)...);
248}
249
// Implementation of slice_indices(): pairs the I-th labeled axis with the I-th name and collects
// each axis' indices(name) result into the returned brace-initialized list.
// NOTE(review): the return-type line (header line 252) is missing from this listing.
250template <class Derived, TorchSize D>
251template <std::size_t... I, typename... S>
253LabeledTensor<Derived, D>::slice_indices_impl(std::index_sequence<I...>, S &&... names) const
254{
255 return {_axes[I]->indices(names)...};
256}
257
// The shape of a sub-block specified by the names on each dimension.
// Exactly D names must be supplied (checked at compile time); the i-th name is resolved against
// the i-th labeled axis by storage_size_impl().
// NOTE(review): the return-type and signature lines (header lines 260-261) are missing from this
// listing.
258template <class Derived, TorchSize D>
259template <typename... S>
262{
263 static_assert(sizeof...(names) == D, "Wrong labeled dimension in LabeledTensor::storage_size");
264 return storage_size_impl(std::make_index_sequence<D>(), std::forward<S>(names)...);
265}
266
// Implementation of storage_size(names...): pairs the I-th labeled axis with the I-th name and
// collects each axis' storage_size(name) result into the returned brace-initialized list.
// NOTE(review): the return-type line (header line 269) is missing from this listing.
267template <class Derived, TorchSize D>
268template <std::size_t... I, typename... S>
270LabeledTensor<Derived, D>::storage_size_impl(std::index_sequence<I...>, S &&... names) const
271{
272 return {_axes[I]->storage_size(names)...};
273}
274
// Function-call operator: returns the sub-block selected by one name per labeled axis.
// Exactly D names must be supplied (checked at compile time); the names are turned into slice
// indices via slice_indices() and applied with base_index().
// NOTE(review): the signature line (header line 278) is missing from this listing.
275template <class Derived, TorchSize D>
276template <typename... S>
277BatchTensor
279{
280 static_assert(sizeof...(names) == D, "Wrong labeled dimension in LabeledTensor::operator()");
281 return base_index(slice_indices(names...));
282}
283
// Get the sub-block labeled by the given sub-axis names.
// Builds an index sequence with one entry per name and delegates to block_impl(). Unlike
// operator(), no static_assert on the number of names appears here.
// NOTE(review): the return-type and signature lines (header lines 286-287) are missing from this
// listing.
284template <class Derived, TorchSize D>
285template <typename... S>
288{
289 return block_impl(std::make_index_sequence<sizeof...(names)>(), std::forward<S>(names)...);
290}
291
// Implementation of block(): for the I-th labeled axis and I-th name, gathers the axis' indices
// and a pointer to its sub-axis, then builds a Derived from the indexed tensor and the sub-axes.
// NOTE(review): the return-type line (header line 294) is missing from this listing.
292template <class Derived, TorchSize D>
293template <std::size_t... I, typename... S>
295LabeledTensor<Derived, D>::block_impl(std::index_sequence<I...>, S &&... names) const
296{
// Slice indices selecting the named sub-axis on each labeled axis.
297 TorchSlice idx = {_axes[I]->indices(names)...};
// Addresses of whatever subaxis() returns; the new tensor stores these pointers and does not own them.
298 std::vector<const LabeledAxis *> new_axes = {&_axes[I]->subaxis(names)...};
299 return Derived(_tensor.base_index(idx), new_axes);
300}
301} // namespace neml2
Definition BatchTensor.h:32
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
CrossRef()=default
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
The primary data structure in NEML2 for working with labeled tensor views.
Definition LabeledTensor.h:44
Derived clone(torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
Clone this LabeledTensor.
Definition LabeledTensor.cxx:131
const LabeledAxis & axis(TorchSize i=0) const
Get a specific labeled axis.
Definition LabeledTensor.h:130
void operator=(const Derived &other)
Assignment operator.
Definition LabeledTensor.cxx:67
TorchSlice slice_indices(S &&... names) const
How to slice the tensor given the names on each axis.
Definition LabeledTensor.h:243
LabeledTensor()=default
Default constructor.
void zero_()
Zero out this tensor.
Definition LabeledTensor.cxx:152
Derived slice(TorchSize i, const std::string &name) const
Slice the tensor on the given dimension by a single variable or sub-axis.
Definition LabeledTensor.cxx:194
BatchTensor operator()(S &&... names) const
Definition LabeledTensor.h:278
TorchShapeRef batch_sizes() const
Return the batch size.
Definition LabeledTensor.cxx:173
Derived detach() const
Return a copy without gradient graphs.
Definition LabeledTensor.cxx:138
static Derived zeros(TorchShapeRef batch_shape, const std::vector< const LabeledAxis * > &axes, const torch::TensorOptions &options=default_tensor_options())
Setup new storage with zeros.
Definition LabeledTensor.cxx:109
Derived operator-() const
Negation.
Definition LabeledTensor.cxx:235
TorchSize batch_dim() const
Return the number of batch dimensions.
Definition LabeledTensor.cxx:159
Derived batch_index(TorchSlice indices) const
Get a batch.
Definition LabeledTensor.cxx:207
Derived block(S &&... names) const
Get the sub-block labeled by the given sub-axis names.
Definition LabeledTensor.h:287
torch::TensorOptions options() const
Get the tensor options.
Definition LabeledTensor.h:112
TorchShapeRef base_sizes() const
Return the base size.
Definition LabeledTensor.cxx:180
BatchTensor base_index(TorchSlice indices) const
Return an index sliced on the batch dimensions.
Definition LabeledTensor.cxx:221
Derived to(const torch::TensorOptions &options) const
Change tensor options.
Definition LabeledTensor.cxx:242
void base_index_put(TorchSlice indices, const torch::Tensor &other)
Set an index sliced on the batch dimensions to a value.
Definition LabeledTensor.cxx:228
std::vector< const LabeledAxis * > _axes
The labeled axes of this tensor.
Definition LabeledTensor.h:219
static Derived zeros_like(const Derived &other)
Setup new storage with zeros like another LabeledTensor.
Definition LabeledTensor.cxx:124
const BatchTensor & tensor() const
Definition LabeledTensor.h:107
BatchTensor _tensor
The tensor.
Definition LabeledTensor.h:215
void set(const BatchTensorBase< T > &value, S &&... names)
Set and interpret the input as an object.
Definition LabeledTensor.h:193
static Derived empty(TorchShapeRef batch_shape, const std::vector< const LabeledAxis * > &axes, const torch::TensorOptions &options=default_tensor_options())
Setup new empty storage.
Definition LabeledTensor.cxx:87
static Derived empty_like(const Derived &other)
Setup new empty storage like another LabeledTensor.
Definition LabeledTensor.cxx:102
TorchSize base_dim() const
Return the number of base dimensions.
Definition LabeledTensor.cxx:166
void detach_()
Detach from gradient graphs.
Definition LabeledTensor.cxx:145
TorchShapeRef storage_size() const
The shape of the entire LabeledTensor.
Definition LabeledTensor.cxx:187
variable_type< T >::type get_list(S &&... names) const
Get and interpret the view as a list of objects.
Definition LabeledTensor.h:184
void set_list(const BatchTensorBase< T > &value, S &&... names)
Set and interpret the input as a list of objects.
Definition LabeledTensor.h:202
void batch_index_put(TorchSlice indices, const torch::Tensor &other)
Set an index sliced on the batch dimensions to a value.
Definition LabeledTensor.cxx:214
const std::vector< const LabeledAxis * > & axes() const
Get all the labeled axes.
Definition LabeledTensor.h:127
BatchTensor & tensor()
Definition LabeledTensor.h:108
TorchShape storage_size(S &&... names) const
The shape of a sub-block specified by the names on each dimension.
Definition LabeledTensor.h:261
variable_type< T >::type get(S &&... names) const
Get and interpret the view as an object.
Definition LabeledTensor.h:176
void copy_(const T &other)
Copy the value from another tensor.
Definition LabeledTensor.h:235
TorchShape add_shapes(S &&... shape)
Definition utils.h:294
Definition CrossRef.cxx:32
const torch::TensorOptions default_tensor_options()
Definition types.cxx:30
int64_t TorchSize
Definition types.h:33
std::vector< TorchSize > TorchShape
Definition types.h:34
torch::IntArrayRef TorchShapeRef
Definition types.h:35
std::vector< at::indexing::TensorIndex > TorchSlice
Definition types.h:37
Template setup for appropriate variable types.
Definition LabeledTensor.h:170
T type
Definition LabeledTensor.h:171