NEML2 1.4.0
Loading...
Searching...
No Matches
LabeledMatrix.cxx
// Copyright 2023, UChicago Argonne, LLC
// All Rights Reserved
// Software Name: NEML2 -- the New Engineering material Model Library, version 2
// By: Argonne National Laboratory
// OPEN SOURCE LICENSE (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
24
25#include "neml2/tensors/LabeledMatrix.h"
26#include "neml2/tensors/LabeledVector.h"
27
28using namespace torch::indexing;
29
30namespace neml2
31{
32LabeledMatrix
34 const LabeledAxis & axis,
35 const torch::TensorOptions & options)
36{
37 return LabeledMatrix(BatchTensor::identity(batch_size, axis.storage_size(), options),
38 {&axis, &axis});
39}
40
41void
43{
44 neml_assert_dbg(axis(1) == other.axis(1), "Can only accumulate matrices with conformal y axes");
45 const auto indices0 = axis(0).common_indices(other.axis(0), recursive);
46 for (const auto & [idxi, idxi_other] : indices0)
47 _tensor.base_index({idxi}) += other.base_index({idxi_other});
48}
49
50void
51LabeledMatrix::fill(const LabeledMatrix & other, bool recursive)
52{
53 neml_assert_dbg(axis(1) == other.axis(1), "Can only accumulate matrices with conformal y axes");
54 const auto indices0 = axis(0).common_indices(other.axis(0), recursive);
55 for (const auto & [idxi, idxi_other] : indices0)
56 _tensor.base_index_put({idxi}, other.base_index({idxi_other}));
57}
58
59LabeledMatrix
60LabeledMatrix::chain(const LabeledMatrix & other) const
61{
62 // This function expresses a chain rule, which is just a dot product between the values of this
63 // and the values of the input The main annoyance is just getting the names correct
64
65 // Check that we are conformal
66 neml_assert_dbg(batch_sizes() == other.batch_sizes(),
67 "LabeledMatrix batch sizes are not the same");
68 neml_assert_dbg(axis(1) == other.axis(0), "Labels are not conformal");
69
70 // If all the sizes are correct then executing the chain rule is pretty easy
71 return LabeledMatrix(math::bmm(*this, other), {&axis(0), &other.axis(1)});
72}
73
75LabeledMatrix::inverse() const
76{
77 neml_assert_dbg(axis(0).storage_size() == axis(1).storage_size(),
78 "Can only invert square derivatives");
79
80 return LabeledMatrix(BatchTensor(torch::linalg::inv(tensor()), batch_dim()),
81 {&axis(1), &axis(0)});
82}
83} // namespace neml2
BatchTensor base_index(const TorchSlice &indices) const
Return an index sliced on the base dimensions.
Definition BatchTensorBase.cxx:193
Definition BatchTensor.h:32
static BatchTensor identity(TorchSize n, const torch::TensorOptions &options=default_tensor_options())
Unbatched identity tensor.
Definition BatchTensor.cxx:91
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
A single-batched, logically 2D LabeledTensor.
Definition LabeledMatrix.h:38
static LabeledMatrix identity(TorchShapeRef batch_size, const LabeledAxis &axis, const torch::TensorOptions &options=default_tensor_options())
Create a labeled identity tensor.
Definition LabeledMatrix.cxx:33
void accumulate(const LabeledMatrix &other, bool recursive=true)
Definition LabeledMatrix.cxx:42
const LabeledAxis & axis(TorchSize i=0) const
Get a specific labeled axis.
Definition LabeledTensor.h:130
torch::TensorOptions options() const
Get the tensor options.
Definition LabeledTensor.h:112
BatchTensor _tensor
The tensor.
Definition LabeledTensor.h:215
Definition CrossRef.cxx:32
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:85
torch::IntArrayRef TorchShapeRef
Definition types.h:35