NEML2 1.4.0
Loading...
Searching...
No Matches
LabeledTensor.h
1// Copyright 2024, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#pragma once
26
27#include "neml2/misc/types.h"
28#include "neml2/tensors/LabeledAxis.h"
29#include "neml2/tensors/Tensor.h"
30
31namespace neml2
32{
42template <class Derived, Size D>
44{
45public:
47 LabeledTensor() = default;
48
50 LabeledTensor(const torch::Tensor & tensor, const std::array<const LabeledAxis *, D> & axes);
51
53 LabeledTensor(const Tensor & tensor, const std::array<const LabeledAxis *, D> & axes);
54
57
60
63 operator Tensor() const;
64 operator torch::Tensor() const;
66
68 [[nodiscard]] static Derived
70 const std::array<const LabeledAxis *, D> & axes,
71 const torch::TensorOptions & options = default_tensor_options());
72
74 [[nodiscard]] static Derived
76 const std::array<const LabeledAxis *, D> & axes,
77 const torch::TensorOptions & options = default_tensor_options());
78
81 const Tensor & tensor() const { return _tensor; }
82 Tensor & tensor() { return _tensor; }
84
86 // These methods mirror TensorBase
89 Derived clone(torch::MemoryFormat memory_format = torch::MemoryFormat::Contiguous) const;
91 Derived detach() const;
93 void detach_();
95 Derived to(const torch::TensorOptions & options) const;
97 void copy_(const torch::Tensor & other);
99 void zero_();
101 bool requires_grad() const;
103 void requires_grad_(bool req = true);
105 Derived operator-() const;
107
109 // These methods mirror TensorBase
112 torch::TensorOptions options() const;
114 torch::Dtype scalar_type() const;
116 torch::Device device() const;
118 Size dim() const;
120 TensorShapeRef sizes() const;
122 Size size(Size dim) const;
124 bool batched() const;
126 Size batch_dim() const;
128 constexpr Size base_dim() const { return D; }
132 Size batch_size(Size d) const;
136 Size base_size(Size d) const;
138 Size base_storage() const;
140
142 // These methods mirror TensorBase
149 void batch_index_put_(indexing::TensorIndicesRef indices, const torch::Tensor & other);
153
155 template <typename T, typename = std::enable_if_t<std::is_base_of_v<TensorBase<T>, T>>>
157
159 const std::array<const LabeledAxis *, D> & axes() const { return _axes; }
161 const LabeledAxis & axis(Size i = 0) const { return *_axes[i]; }
162
163protected:
166
168 std::array<const LabeledAxis *, D> _axes;
169
170private:
172 TensorShape storage_shape(indexing::TensorLabelsRef) const;
173
175 indexing::TensorIndices labels_to_indices(indexing::TensorLabelsRef) const;
176};
177
178template <class Derived, Size D>
179template <typename T, typename>
180T
182{
183 return base_index(indices).base_reshape(T::const_base_sizes);
184}
185} // namespace neml2
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:56
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
The primary data structure in NEML2 for working with labeled tensor views.
Definition LabeledTensor.h:44
Derived clone(torch::MemoryFormat memory_format=torch::MemoryFormat::Contiguous) const
Definition LabeledTensor.cxx:127
LabeledTensor()=default
Default constructor.
void zero_()
Set all entries to zero.
Definition LabeledTensor.cxx:162
bool requires_grad() const
Get the requires_grad property.
Definition LabeledTensor.cxx:169
void requires_grad_(bool req=true)
Set the requires_grad property.
Definition LabeledTensor.cxx:176
Size batch_size(Size d) const
Return the length of some batch axis.
Definition LabeledTensor.cxx:253
Derived batch_index(indexing::TensorIndicesRef indices) const
Definition LabeledTensor.cxx:281
const Tensor & tensor() const
Definition LabeledTensor.h:81
void base_index_put_(indexing::TensorLabelsRef labels, const Tensor &other)
Set values by slicing on the base dimensions.
Definition LabeledTensor.cxx:303
torch::Dtype scalar_type() const
Scalar type (dtype) of the tensor.
Definition LabeledTensor.cxx:197
Tensor base_index(indexing::TensorLabelsRef labels) const
Get a tensor by slicing on the base dimensions.
Definition LabeledTensor.cxx:288
Size base_storage() const
Return the flattened storage needed just for the base indices.
Definition LabeledTensor.cxx:274
bool batched() const
Whether the tensor is batched.
Definition LabeledTensor.cxx:232
Size batch_dim() const
Return the number of batch dimensions.
Definition LabeledTensor.cxx:239
Derived detach() const
Return a copy without gradient graphs.
Definition LabeledTensor.cxx:134
Size dim() const
Number of tensor dimensions.
Definition LabeledTensor.cxx:211
TensorShapeRef sizes() const
Tensor shape.
Definition LabeledTensor.cxx:218
T reinterpret(indexing::TensorLabelsRef indices) const
Get a tensor by slicing on the base dimensions AND reinterpret it as a primitive tensor.
Definition LabeledTensor.h:181
operator torch::Tensor() const
Definition LabeledTensor.cxx:90
Derived operator-() const
Negation.
Definition LabeledTensor.cxx:183
Size base_size(Size d) const
Return the length of some base axis.
Definition LabeledTensor.cxx:267
TensorShapeRef batch_sizes() const
Return the batch size.
Definition LabeledTensor.cxx:246
static Derived zeros(TensorShapeRef batch_shape, const std::array< const LabeledAxis *, D > &axes, const torch::TensorOptions &options=default_tensor_options())
Setup new storage with zeros.
Definition LabeledTensor.cxx:112
torch::Device device() const
Device the tensor resides on.
Definition LabeledTensor.cxx:204
torch::TensorOptions options() const
Definition LabeledTensor.cxx:190
Derived to(const torch::TensorOptions &options) const
Change tensor options.
Definition LabeledTensor.cxx:148
const LabeledAxis & axis(Size i=0) const
Get a specific labeled axis.
Definition LabeledTensor.h:161
static Derived empty(TensorShapeRef batch_shape, const std::array< const LabeledAxis *, D > &axes, const torch::TensorOptions &options=default_tensor_options())
Setup new empty storage.
Definition LabeledTensor.cxx:97
constexpr Size base_dim() const
Return the number of base dimensions.
Definition LabeledTensor.h:128
void copy_(const torch::Tensor &other)
Copy another tensor.
Definition LabeledTensor.cxx:155
Tensor _tensor
The tensor.
Definition LabeledTensor.h:165
void detach_()
Detach from gradient graphs.
Definition LabeledTensor.cxx:141
void batch_index_put_(indexing::TensorIndicesRef indices, const torch::Tensor &other)
Set values by slicing on the batch dimensions.
Definition LabeledTensor.cxx:295
Size size(Size dim) const
Length of a single tensor dimension.
Definition LabeledTensor.cxx:225
TensorShapeRef base_sizes() const
Return the base size.
Definition LabeledTensor.cxx:260
Tensor & tensor()
Definition LabeledTensor.h:82
std::array< const LabeledAxis *, D > _axes
The labeled axes of this tensor.
Definition LabeledTensor.h:168
LabeledTensor< Derived, D > & operator=(const Derived &other)
Assignment operator.
Definition LabeledTensor.cxx:76
const std::array< const LabeledAxis *, D > & axes() const
Get all the labeled axes.
Definition LabeledTensor.h:159
Definition Tensor.h:32
c10::ArrayRef< LabeledAxisAccessor > TensorLabelsRef
Definition LabeledAxisAccessor.h:158
torch::SmallVector< TensorIndex > TensorIndices
Definition types.h:41
torch::ArrayRef< TensorIndex > TensorIndicesRef
Definition types.h:42
Definition CrossRef.cxx:30
torch::TensorOptions & default_tensor_options()
Definition types.cxx:30
torch::SmallVector< Size > TensorShape
Definition types.h:34
int64_t Size
Definition types.h:33
torch::IntArrayRef TensorShapeRef
Definition types.h:35