NEML2 1.4.0
Loading...
Searching...
No Matches
LabeledTensor.cxx
1// Copyright 2023, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#include "neml2/tensors/LabeledTensor.h"
26
27#include "neml2/tensors/LabeledVector.h"
28#include "neml2/tensors/LabeledMatrix.h"
29#include "neml2/tensors/LabeledTensor3D.h"
30
31namespace neml2
32{
33template <class Derived, TorchSize D>
34LabeledTensor<Derived, D>::LabeledTensor(const torch::Tensor & tensor,
35 TorchSize batch_dim,
36 const std::vector<const LabeledAxis *> & axes)
37 : _tensor(tensor, batch_dim),
38 _axes(axes)
39{
40 neml_assert_dbg(axes.size() == D, "Wrong labeled dimension");
41
42 // Check that the size of the tensor was compatible
43 neml_assert_dbg(base_sizes() == storage_size(), "LabeledTensor does not have the right size");
44}
45
// Construct from an existing tensor object `tensor` and one labeled axis per base
// dimension; `_tensor` is initialized directly from `tensor` (no batch_dim is passed).
// NOTE(review): the declarator line (original line 47) is missing from this listing,
// so the exact type of `tensor` cannot be confirmed here.
46template <class Derived, TorchSize D>
48 const std::vector<const LabeledAxis *> & axes)
49 : _tensor(tensor),
50 _axes(axes)
51{
52 neml_assert_dbg(axes.size() == D, "Wrong labeled dimension");
53
// NOTE(review): storage_size() just returns base_sizes(), so the comparison below is
// always true and does not actually validate the tensor shape against the axes.
54 // Check that the size of the tensor was compatible
55 neml_assert_dbg(base_sizes() == storage_size(), "LabeledTensor does not have the right size");
56}
57
// Construct from `other`: the underlying tensor is initialized from `other` itself and
// the axes are copied from `other.axes()`.
// NOTE(review): the declarator line (original line 59) is missing from this listing;
// presumably this is the constructor taking a `const Derived &` -- confirm in the header.
58template <class Derived, TorchSize D>
60 : _tensor(other),
61 _axes(other.axes())
62{
63}
64
// Assignment operator: replace the underlying tensor with `other.tensor()` and the axes
// with `other.axes()`. It returns void, so assignments cannot be chained.
// NOTE(review): the declarator line (original line 67) is missing from this listing.
65template <class Derived, TorchSize D>
66void
68{
69 _tensor = other.tensor();
70 _axes = other.axes();
71}
72
// Returns the underlying `_tensor` member.
// NOTE(review): the declarator line (original line 74) is missing from this listing;
// presumably a conversion operator to the underlying tensor type -- confirm in the header.
73template <class Derived, TorchSize D>
75{
76 return _tensor;
77}
// Returns the underlying `_tensor` member.
// NOTE(review): the declarator line (original line 80) is missing from this listing;
// presumably a conversion operator to a tensor type -- confirm in the header.
79template <class Derived, TorchSize D>
81{
82 return _tensor;
83}
84
// Allocate new, uninitialized storage. The base shape is built from each axis'
// storage_size(), and the result is wrapped together with `axes`.
// NOTE(review): original lines 86-87 (return type and the start of the declarator that
// introduces `batch_size`) and 90-91 (presumably the opening brace and the declaration
// of the shape vector `s`) are missing from this listing.
85template <class Derived, TorchSize D>
88 const std::vector<const LabeledAxis *> & axes,
89 const torch::TensorOptions & options)
92 s.reserve(axes.size());
93 std::transform(axes.begin(),
94 axes.end(),
95 std::back_inserter(s),
96 [](const LabeledAxis * axis) { return axis->storage_size(); });
97 return Derived(BatchTensor::empty(batch_size, s, options), axes);
98}
99
// Set up new empty storage like `other`
// (declared as `static Derived empty_like(const Derived & other)`).
// NOTE(review): original lines 101-104 (return type, declarator and body) are missing
// from this listing.
100template <class Derived, TorchSize D>
105}
106
// Allocate new zero-filled storage. The base shape is built from each axis'
// storage_size(), and the result is wrapped together with `axes`.
// NOTE(review): original lines 108-109 (return type and the start of the declarator that
// introduces `batch_size`) and 113 (presumably the declaration of the shape vector `s`)
// are missing from this listing.
107template <class Derived, TorchSize D>
110 const std::vector<const LabeledAxis *> & axes,
111 const torch::TensorOptions & options)
112{
114 s.reserve(axes.size());
115 std::transform(axes.begin(),
116 axes.end(),
117 std::back_inserter(s),
118 [](const LabeledAxis * axis) { return axis->storage_size(); });
119 return Derived(BatchTensor::zeros(batch_size, s, options), axes);
120}
// Set up new zero-filled storage like `other`
// (declared as `static Derived zeros_like(const Derived & other)`).
// NOTE(review): original lines 123-124 (return type and declarator) and 126 (the body)
// are missing from this listing.
122template <class Derived, TorchSize D>
125{
127}
128
// Clone the underlying tensor with `_tensor.clone(memory_format)` and wrap it with the
// same axes. The clone shares the axis pointers (`_axes`) with this object.
// NOTE(review): original lines 130-131 (return type and declarator, declared as
// `Derived clone(torch::MemoryFormat memory_format = torch::MemoryFormat::Contiguous) const`)
// are missing from this listing.
129template <class Derived, TorchSize D>
132{
133 return Derived(_tensor.clone(memory_format), _axes);
134}
135
// Return a new object wrapping `_tensor.detach()` with the same axes, i.e. a copy
// without gradient graphs.
// NOTE(review): original lines 137-138 (return type and declarator) are missing from
// this listing.
136template <class Derived, TorchSize D>
139{
140 return Derived(_tensor.detach(), _axes);
141}
142
// Detach the underlying tensor from gradient graphs in place (`_tensor.detach_()`).
// NOTE(review): the declarator line (original line 145) is missing from this listing.
143template <class Derived, TorchSize D>
144void
146{
147 _tensor.detach_();
148}
// Zero out the underlying tensor in place (`_tensor.zero_()`).
// NOTE(review): the declarator line (original line 152) is missing from this listing.
150template <class Derived, TorchSize D>
151void
153{
154 _tensor.zero_();
155}
// Number of batch dimensions, forwarded from `_tensor.batch_dim()`.
// NOTE(review): original lines 158-159 (return type and declarator) and 162 (presumably
// the closing brace) are missing from this listing.
157template <class Derived, TorchSize D>
160{
161 return _tensor.batch_dim();
163
// Number of base (labeled) dimensions; always the template parameter D.
// NOTE(review): original lines 165-166 (return type and declarator) are missing from
// this listing.
164template <class Derived, TorchSize D>
167{
168 return D;
169}
170
// Sizes of the batch dimensions, forwarded from `_tensor.batch_sizes()`.
// NOTE(review): original lines 172-173 (return type and declarator) are missing from
// this listing.
171template <class Derived, TorchSize D>
174{
175 return _tensor.batch_sizes();
176}
177
// Sizes of the base dimensions, forwarded from `_tensor.base_sizes()`.
// NOTE(review): original lines 179-180 (return type and declarator) are missing from
// this listing.
178template <class Derived, TorchSize D>
181{
182 return _tensor.base_sizes();
183}
184
// Storage shape of this tensor. It is exactly base_sizes(); it is not derived from the
// axes.
// NOTE(review): original lines 186-187 (return type and declarator) are missing from
// this listing.
185template <class Derived, TorchSize D>
188{
189 return base_sizes();
190}
191
// Slice base dimension `i` by the variable/sub-axis `name`. Every other base dimension
// keeps its full range (Slice()), dimension `i` takes `_axes[i]->indices(name)`, and in
// the result axis `i` is replaced by a pointer to `_axes[i]->subaxis(name)`.
// No bounds check on `i` is performed in this body.
// NOTE(review): the return-type line (original line 193) is missing from this listing.
192template <class Derived, TorchSize D>
194LabeledTensor<Derived, D>::slice(TorchSize i, const std::string & name) const
195{
196 TorchSlice idx(base_dim(), torch::indexing::Slice());
197 idx[i] = _axes[i]->indices(name);
198
199 auto new_axes = _axes;
200 new_axes[i] = &_axes[i]->subaxis(name);
201
202 return Derived(_tensor.base_index(idx), new_axes);
203}
204
// Index the batch dimensions with `indices`; the result reuses the same axes.
// NOTE(review): original lines 206-208 (return type, declarator and opening brace,
// declared as `Derived batch_index(TorchSlice indices) const`) are missing from this listing.
205template <class Derived, TorchSize D>
209 return Derived(_tensor.batch_index(indices), _axes);
210}
// Write `other` into the batch-indexed region `indices` of the underlying tensor.
// NOTE(review): the declarator line (original line 214) is missing from this listing.
212template <class Derived, TorchSize D>
213void
215{
216 _tensor.batch_index_put(indices, other);
217}
218
// Index the base dimensions with `indices`. Returns the plain result of
// `_tensor.base_index(indices)` (a BatchTensor per the declaration), so axis labels are
// not carried over.
// NOTE(review): original lines 220-221 (return type and declarator) are missing from
// this listing.
219template <class Derived, TorchSize D>
222{
223 return _tensor.base_index(indices);
224}
225
// Write `other` into the base-indexed region `indices` of the underlying tensor.
// NOTE(review): the declarator line (original line 228) is missing from this listing.
226template <class Derived, TorchSize D>
227void
229{
230 _tensor.base_index_put(indices, other);
231}
232
// Negation: wrap `-_tensor` with the same axes.
// NOTE(review): original lines 234-235 (return type and declarator) are missing from
// this listing.
233template <class Derived, TorchSize D>
236{
237 return Derived(-_tensor, _axes);
238}
239
// Change tensor options: wrap `_tensor.to(options)` with the same axes.
// NOTE(review): the return-type line (original line 241) is missing from this listing.
240template <class Derived, TorchSize D>
242LabeledTensor<Derived, D>::to(const torch::TensorOptions & options) const
243{
244 return Derived(_tensor.to(options), _axes);
245}
246
250}
BatchTensor base_index(const TorchSlice &indices) const
Return an index sliced on the base dimensions.
Definition BatchTensorBase.cxx:193
static BatchTensor zeros_like(const BatchTensor &other)
Zero tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition BatchTensorBase.cxx:59
static BatchTensor empty_like(const BatchTensor &other)
Empty tensor like another, i.e. same batch and base shapes, same tensor options, etc.
Definition BatchTensorBase.cxx:52
Definition BatchTensor.h:32
static BatchTensor zeros(const TorchShapeRef &base_shape, const torch::TensorOptions &options=default_tensor_options())
Unbatched tensor filled with zeros given base shape.
Definition BatchTensor.cxx:45
static BatchTensor empty(const TorchShapeRef &base_shape, const torch::TensorOptions &options=default_tensor_options())
Unbatched empty tensor given base shape.
Definition BatchTensor.cxx:30
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
CrossRef()=default
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
The primary data structure in NEML2 for working with labeled tensor views.
Definition LabeledTensor.h:44
Derived clone(torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
Clone this LabeledTensor.
Definition LabeledTensor.cxx:131
void operator=(const Derived &other)
Assignment operator.
Definition LabeledTensor.cxx:67
LabeledTensor()=default
Default constructor.
void zero_()
Zero out this tensor.
Definition LabeledTensor.cxx:152
Derived slice(TorchSize i, const std::string &name) const
Slice the tensor on the given dimension by a single variable or sub-axis.
Definition LabeledTensor.cxx:194
TorchShapeRef batch_sizes() const
Return the sizes of the batch dimensions.
Definition LabeledTensor.cxx:173
Derived detach() const
Return a copy without gradient graphs.
Definition LabeledTensor.cxx:138
static Derived zeros(TorchShapeRef batch_shape, const std::vector< const LabeledAxis * > &axes, const torch::TensorOptions &options=default_tensor_options())
Set up new storage filled with zeros.
Definition LabeledTensor.cxx:109
Derived operator-() const
Negation.
Definition LabeledTensor.cxx:235
TorchSize batch_dim() const
Return the number of batch dimensions.
Definition LabeledTensor.cxx:159
Derived batch_index(TorchSlice indices) const
Get a batch.
Definition LabeledTensor.cxx:207
TorchShapeRef base_sizes() const
Return the sizes of the base dimensions.
Definition LabeledTensor.cxx:180
BatchTensor base_index(TorchSlice indices) const
Return an index sliced on the base dimensions.
Definition LabeledTensor.cxx:221
Derived to(const torch::TensorOptions &options) const
Change tensor options.
Definition LabeledTensor.cxx:242
void base_index_put(TorchSlice indices, const torch::Tensor &other)
Set an index sliced on the base dimensions to a value.
Definition LabeledTensor.cxx:228
static Derived zeros_like(const Derived &other)
Set up new storage filled with zeros, like another LabeledTensor.
Definition LabeledTensor.cxx:124
static Derived empty(TorchShapeRef batch_shape, const std::vector< const LabeledAxis * > &axes, const torch::TensorOptions &options=default_tensor_options())
Set up new empty storage.
Definition LabeledTensor.cxx:87
static Derived empty_like(const Derived &other)
Set up new empty storage like another LabeledTensor.
Definition LabeledTensor.cxx:102
TorchSize base_dim() const
Return the number of base dimensions.
Definition LabeledTensor.cxx:166
void detach_()
Detach from gradient graphs.
Definition LabeledTensor.cxx:145
TorchShapeRef storage_size() const
The storage shape of the LabeledTensor, i.e. the shape of its base dimensions (identical to base_sizes()).
Definition LabeledTensor.cxx:187
void batch_index_put(TorchSlice indices, const torch::Tensor &other)
Set an index sliced on the batch dimensions to a value.
Definition LabeledTensor.cxx:214
const std::vector< const LabeledAxis * > & axes() const
Get all the labeled axes.
Definition LabeledTensor.h:127
Definition CrossRef.cxx:32
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:85
int64_t TorchSize
Definition types.h:33
std::vector< TorchSize > TorchShape
Definition types.h:34
torch::IntArrayRef TorchShapeRef
Definition types.h:35
std::vector< at::indexing::TensorIndex > TorchSlice
Definition types.h:37