NEML2 1.4.0
Loading...
Searching...
No Matches
LabeledMatrix.cxx
1// Copyright 2023, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#include "neml2/tensors/LabeledMatrix.h"
26#include "neml2/tensors/LabeledVector.h"
27#include "neml2/misc/math.h"
28
29using namespace torch::indexing;
30
31namespace neml2
32{
33LabeledMatrix
35 const LabeledAxis & axis,
36 const torch::TensorOptions & options)
37{
38 return LabeledMatrix(BatchTensor::identity(batch_size, axis.storage_size(), options),
39 {&axis, &axis});
40}
41
42void
44{
45 neml_assert_dbg(axis(1) == other.axis(1), "Can only accumulate matrices with conformal y axes");
46 const auto indices0 = axis(0).common_indices(other.axis(0), recursive);
47 for (const auto & [idxi, idxi_other] : indices0)
48 _tensor.base_index({idxi}) += other.base_index({idxi_other});
49}
50
51void
52LabeledMatrix::fill(const LabeledMatrix & other, bool recursive)
53{
54 neml_assert_dbg(axis(1) == other.axis(1), "Can only accumulate matrices with conformal y axes");
55 const auto indices0 = axis(0).common_indices(other.axis(0), recursive);
56 for (const auto & [idxi, idxi_other] : indices0)
57 _tensor.base_index_put({idxi}, other.base_index({idxi_other}));
58}
59
60LabeledMatrix
61LabeledMatrix::chain(const LabeledMatrix & other) const
62{
63 // This function expresses a chain rule, which is just a dot product between the values of this
64 // and the values of the input The main annoyance is just getting the names correct
65
66 // Check that we are conformal
67 neml_assert_dbg(batch_sizes() == other.batch_sizes(),
68 "LabeledMatrix batch sizes are not the same");
69 neml_assert_dbg(axis(1) == other.axis(0), "Labels are not conformal");
70
71 // If all the sizes are correct then executing the chain rule is pretty easy
72 return LabeledMatrix(math::bmm(*this, other), {&axis(0), &other.axis(1)});
73}
74
76LabeledMatrix::inverse() const
77{
78 neml_assert_dbg(axis(0).storage_size() == axis(1).storage_size(),
79 "Can only invert square derivatives");
80
81 return LabeledMatrix(math::linalg::inv(tensor()), {&axis(1), &axis(0)});
82}
83} // namespace neml2
BatchTensor base_index(const TorchSlice &indices) const
Return an index sliced on the base dimensions.
Definition BatchTensorBase.cxx:193
static BatchTensor identity(TorchSize n, const torch::TensorOptions &options=default_tensor_options())
Unbatched identity tensor.
Definition BatchTensor.cxx:91
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
A labeled axis used to associate layout of a tensor with human-interpretable names.
Definition LabeledAxis.h:55
A single-batched, logically 2D LabeledTensor.
Definition LabeledMatrix.h:38
static LabeledMatrix identity(TorchShapeRef batch_size, const LabeledAxis &axis, const torch::TensorOptions &options=default_tensor_options())
Create a labeled identity tensor.
Definition LabeledMatrix.cxx:34
void accumulate(const LabeledMatrix &other, bool recursive=true)
Definition LabeledMatrix.cxx:43
const LabeledAxis & axis(TorchSize i=0) const
Get a specific labeled axis.
Definition LabeledTensor.h:130
torch::TensorOptions options() const
Get the tensor options.
Definition LabeledTensor.h:112
BatchTensor _tensor
The tensor.
Definition LabeledTensor.h:215
Definition CrossRef.cxx:32
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:85
torch::IntArrayRef TorchShapeRef
Definition types.h:37