NEML2 1.4.0
Loading...
Searching...
No Matches
LabeledTensor3D.cxx
1// Copyright 2023, UChicago Argonne, LLC
2// All Rights Reserved
3// Software Name: NEML2 -- the New Engineering material Model Library, version 2
4// By: Argonne National Laboratory
5// OPEN SOURCE LICENSE (MIT)
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25#include "neml2/tensors/LabeledTensor3D.h"
26#include "neml2/tensors/LabeledMatrix.h"
27
28using namespace torch::indexing;
29
30namespace neml2
31{
32void
34{
35 neml_assert_dbg(axis(1) == other.axis(1), "Can only accumulate 3D tensors with conformal y axes");
36 neml_assert_dbg(axis(2) == other.axis(2), "Can only accumulate 3D tensors with conformal z axes");
37 const auto indices0 = axis(0).common_indices(other.axis(0), recursive);
38 for (const auto & [idxi, idxi_other] : indices0)
39 _tensor.base_index({idxi}) += other.base_index({idxi_other});
40}
41
42void
43LabeledTensor3D::fill(const LabeledTensor3D & other, bool recursive)
44{
45 neml_assert_dbg(axis(1) == other.axis(1), "Can only accumulate 3D tensors with conformal y axes");
46 neml_assert_dbg(axis(2) == other.axis(2), "Can only accumulate 3D tensors with conformal z axes");
47 const auto indices0 = axis(0).common_indices(other.axis(0), recursive);
48 for (const auto & [idxi, idxi_other] : indices0)
49 _tensor.base_index_put({idxi}, other.base_index({idxi_other}));
50}
51
52LabeledTensor3D
53LabeledTensor3D::chain(const LabeledTensor3D & other,
54 const LabeledMatrix & dself,
55 const LabeledMatrix & dother) const
56{
57 // This function expresses the second oreder chain rule, which can be expressed as
58 // d2y/dx2 = d2y/du2 du/dx du/dx + dy/du d2u/dx2
59 // In index notation this is
60 // (d2y/dx2)_{ijk} = (d2y/du2)_{ipq} (du/dx)_{pj} (du/dx)_{qk} + (dy/du)_{ip} (d2u/dx2)_{pjk}
61
62 // Make sure we are conformal
63 neml_assert_dbg(batch_sizes() == other.batch_sizes(), "Batch sizes are not the same");
64 neml_assert_dbg(batch_sizes() == dself.batch_sizes(), "Batch sizes are not the same");
65 neml_assert_dbg(batch_sizes() == dother.batch_sizes(), "Batch sizes are not the same");
66 neml_assert_dbg(axis(1) == axis(2), "Self labels are not conformal");
67 neml_assert_dbg(other.axis(1) == other.axis(2), "Other labels are not conformal");
68 neml_assert_dbg(axis(2) == other.axis(0), "Self and other labels are not conformal");
69
70 // If all the sizes are correct then executing the chain rule is pretty easy
71 return LabeledTensor3D(torch::einsum("...ipq,...pj,...qk", {*this, dother, dother}) +
72 torch::einsum("...ip,...pjk", {dself, other}),
74 {&axis(0), &other.axis(1), &other.axis(2)});
75}
76} // namespace neml2
BatchTensor base_index(const TorchSlice &indices) const
Return an index sliced on the base dimensions.
Definition BatchTensorBase.cxx:193
The wrapper (decorator) for cross-referencing unresolved values at parse time.
Definition CrossRef.h:52
A single-batched, logically 2D LabeledTensor.
Definition LabeledMatrix.h:38
A single-batched, logically 3D LabeledTensor.
Definition LabeledTensor3D.h:38
void accumulate(const LabeledTensor3D &other, bool recursive=true)
Definition LabeledTensor3D.cxx:33
const LabeledAxis & axis(TorchSize i=0) const
Get a specific labeled axis.
Definition LabeledTensor.h:130
BatchTensor _tensor
The tensor.
Definition LabeledTensor.h:215
Definition CrossRef.cxx:32
void neml_assert_dbg(bool assertion, Args &&... args)
Definition error.h:85
TorchSize broadcast_batch_dim(T &&...)
The batch dimension after broadcasting.