File indexing completed on 2024-04-06 12:04:41
0001 #ifndef DataFormat_Math_choleskyInversion_h
0002 #define DataFormat_Math_choleskyInversion_h
0003
#include <cmath>
#include <type_traits>

#include <Eigen/Core>
0007
0008 namespace math {
0009 namespace cholesky {
0010
/**
 * Generic inversion of a symmetric N x N matrix, used by invert() for sizes
 * that have no hand-unrolled kernel.
 *
 * Only the upper triangle of src (src(i, j) with j >= i) is read: it is copied,
 * transposed, into the local scratch array `a`, factorised and inverted there,
 * and the full result (both triangles, written symmetrically) is stored in dst.
 * All arithmetic is performed in M2::Scalar.
 *
 * No pivoting and no singularity check is done: every pivot is inverted
 * unconditionally, so src is presumably expected to be positive definite
 * (a zero pivot yields inf/NaN in dst).
 *
 * @tparam N  matrix dimension, defaulting to M2::ColsAtCompileTime; must be a
 *            positive compile-time constant because a local N x N array is used.
 * @param src symmetric input matrix (upper triangle read)
 * @param dst output matrix receiving the complete inverse
 */
template <typename M1, typename M2, int N = M2::ColsAtCompileTime>
#ifdef __CUDACC__
__host__ __device__
#endif
inline constexpr void
invertNN(M1 const& src, M2& dst) {
  static_assert(N > 0, "invertNN requires a positive compile-time matrix dimension");

  using T = typename M2::Scalar;

  // a[j][i] = src(i, j) for j >= i: the upper triangle of src stored transposed.
  T a[N][N];
  for (int i = 0; i < N; ++i) {
    a[i][i] = src(i, i);
    for (int j = i + 1; j < N; ++j)
      a[j][i] = src(i, j);
  }

  // Forward pass: each diagonal entry becomes the reciprocal of its pivot, the
  // upper triangle receives the scaled factor entries, and column j+1 of the
  // lower triangle is updated with the contributions of the processed columns.
  for (int j = 0; j < N; ++j) {
    a[j][j] = T(1.) / a[j][j];
    int jp1 = j + 1;
    for (int l = jp1; l < N; ++l) {
      a[j][l] = a[j][j] * a[l][j];
      T s1 = -a[l][jp1];
      for (int i = 0; i < jp1; ++i)
        s1 += a[l][i] * a[i][jp1];
      a[l][jp1] = -s1;
    }
  }

  if constexpr (N == 1) {
    dst(0, 0) = a[0][0];
  } else {
    // Kept inside `else` so that the a[0][1] / a[1][0] accesses below are
    // never instantiated for a 1x1 scratch array.

    // Invert the triangular factor in place.
    a[0][1] = -a[0][1];
    a[1][0] = a[0][1] * a[1][1];
    for (int j = 2; j < N; ++j) {
      int jm1 = j - 1;
      for (int k = 0; k < jm1; ++k) {
        T s31 = a[k][j];
        for (int i = k; i < jm1; ++i)
          s31 += a[k][i + 1] * a[i + 1][j];
        a[k][j] = -s31;
        a[j][k] = -s31 * a[j][j];
      }
      a[jm1][j] = -a[jm1][j];
      a[j][jm1] = a[jm1][j] * a[j][j];
    }

    // Assemble the inverse: diagonal terms (s33) and off-diagonal terms (s32),
    // the latter written to both triangles of dst.
    int j = 0;
    while (j < N - 1) {
      T s33 = a[j][j];
      for (int i = j + 1; i < N; ++i)
        s33 += a[j][i] * a[i][j];
      dst(j, j) = s33;

      ++j;
      for (int k = 0; k < j; ++k) {
        T s32 = 0;
        for (int i = j; i < N; ++i)
          s32 += a[k][i] * a[i][j];
        dst(k, j) = dst(j, k) = s32;
      }
    }
    dst(j, j) = a[j][j];
  }
}
0078
0079
0080
0081
0082
0083
0084
0085
0086
0087
0088
0089
0090
/**
 * Inverts a 1x1 matrix: dst(0, 0) = 1 / src(0, 0). No zero check is made.
 *
 * F is the element type returned by src(0, 0) with reference and cv
 * qualifiers removed. Using decltype(src(0, 0)) directly yields a reference
 * type whenever the accessor returns by reference, and F(1.0) would then be a
 * cast of a literal to a reference, which is ill-formed for a non-const
 * reference accessor.
 */
template <typename M1, typename M2>
inline constexpr void __attribute__((always_inline)) invert11(M1 const& src, M2& dst) {
  using F = std::decay_t<decltype(src(0, 0))>;
  dst(0, 0) = F(1.0) / src(0, 0);
}
0096
/**
 * Inverts a symmetric 2x2 matrix via its root-free Cholesky (LDL^T) form.
 *
 * luc0 and luc2 are the reciprocal pivots; luc1 = src(1,0)^2 / src(0,0).
 * Only the lower triangle of src (row >= column) is read, and only the lower
 * triangle of dst (dst(0,0), dst(1,0), dst(1,1)) is written. The upper half is
 * left untouched; Inverter fills it afterwards with symmetrize22.
 * Pivots are not tested for zero, so src is presumably expected to be positive
 * definite.
 *
 * F is src's element type with reference/cv qualifiers stripped, so F(1.0) is
 * a plain value conversion even when src(i, j) returns a reference.
 */
template <typename M1, typename M2>
inline constexpr void __attribute__((always_inline)) invert22(M1 const& src, M2& dst) {
  using F = std::decay_t<decltype(src(0, 0))>;
  auto luc0 = F(1.0) / src(0, 0);
  auto luc1 = src(1, 0) * src(1, 0) * luc0;
  auto luc2 = F(1.0) / (src(1, 1) - luc1);

  auto li21 = luc1 * luc0 * luc2;

  dst(0, 0) = li21 + luc0;
  dst(1, 0) = -src(1, 0) * luc0 * luc2;
  dst(1, 1) = luc2;
}
0110
/**
 * Inverts a symmetric 3x3 matrix through its root-free Cholesky (LDL^T)
 * factorisation.
 *
 * - luc0, luc2, luc5 are the reciprocal pivots (the entries of D^-1); the
 *   remaining luc* values are the unscaled off-diagonal factor elements.
 * - li21, li31, li32 are the sub-diagonal entries (1-based) of L^-1, whose
 *   diagonal is 1.
 * - dst is assembled as L^-T * D^-1 * L^-1.
 *
 * Only the lower triangle of src (row >= column) is read, and only the lower
 * triangle of dst is written. Inverter fills the upper half with symmetrize33.
 * Pivots are not tested for zero, so src is presumably expected to be positive
 * definite; a zero pivot produces inf/NaN.
 *
 * F is src's element type with reference/cv qualifiers stripped, so F(1.0) is
 * a plain value conversion even when src(i, j) returns a reference.
 */
template <typename M1, typename M2>
inline constexpr void __attribute__((always_inline)) invert33(M1 const& src, M2& dst) {
  using F = std::decay_t<decltype(src(0, 0))>;
  auto luc0 = F(1.0) / src(0, 0);
  auto luc1 = src(1, 0);
  auto luc2 = src(1, 1) - luc0 * luc1 * luc1;
  luc2 = F(1.0) / luc2;
  auto luc3 = src(2, 0);
  auto luc4 = (src(2, 1) - luc0 * luc1 * luc3);
  auto luc5 = src(2, 2) - (luc0 * luc3 * luc3 + (luc2 * luc4) * luc4);
  luc5 = F(1.0) / luc5;

  auto li21 = -luc0 * luc1;
  auto li32 = -(luc2 * luc4);
  auto li31 = (luc1 * (luc2 * luc4) - luc3) * luc0;

  dst(0, 0) = luc5 * li31 * li31 + li21 * li21 * luc2 + luc0;
  dst(1, 0) = luc5 * li31 * li32 + li21 * luc2;
  dst(1, 1) = luc5 * li32 * li32 + luc2;
  dst(2, 0) = luc5 * li31;
  dst(2, 1) = luc5 * li32;
  dst(2, 2) = luc5;
}
0134
/**
 * Inverts a symmetric 4x4 matrix through its root-free Cholesky (LDL^T)
 * factorisation.
 *
 * - luc0, luc2, luc5, luc9 are the reciprocal pivots (the entries of D^-1);
 *   the remaining luc* values are the unscaled off-diagonal factor elements.
 * - li<r><c> (1-based, r > c) are the sub-diagonal entries of L^-1, whose
 *   diagonal is 1.
 * - dst is assembled as L^-T * D^-1 * L^-1.
 *
 * Only the lower triangle of src (row >= column) is read, and only the lower
 * triangle of dst is written. Inverter fills the upper half with symmetrize44.
 * Pivots are not tested for zero, so src is presumably expected to be positive
 * definite; a zero pivot produces inf/NaN.
 *
 * F is src's element type with reference/cv qualifiers stripped, so F(1.0) is
 * a plain value conversion even when src(i, j) returns a reference.
 */
template <typename M1, typename M2>
inline constexpr void __attribute__((always_inline)) invert44(M1 const& src, M2& dst) {
  using F = std::decay_t<decltype(src(0, 0))>;
  auto luc0 = F(1.0) / src(0, 0);
  auto luc1 = src(1, 0);
  auto luc2 = src(1, 1) - luc0 * luc1 * luc1;
  luc2 = F(1.0) / luc2;
  auto luc3 = src(2, 0);
  auto luc4 = (src(2, 1) - luc0 * luc1 * luc3);
  auto luc5 = src(2, 2) - (luc0 * luc3 * luc3 + luc2 * luc4 * luc4);
  luc5 = F(1.0) / luc5;
  auto luc6 = src(3, 0);
  auto luc7 = (src(3, 1) - luc0 * luc1 * luc6);
  auto luc8 = (src(3, 2) - luc0 * luc3 * luc6 - luc2 * luc4 * luc7);
  auto luc9 = src(3, 3) - (luc0 * luc6 * luc6 + luc2 * luc7 * luc7 + luc8 * (luc8 * luc5));
  luc9 = F(1.0) / luc9;

  auto li21 = -luc1 * luc0;
  auto li32 = -luc2 * luc4;
  auto li31 = (luc1 * (luc2 * luc4) - luc3) * luc0;
  auto li43 = -(luc8 * luc5);
  auto li42 = (luc4 * luc8 * luc5 - luc7) * luc2;
  auto li41 = (-luc1 * (luc2 * luc4) * (luc8 * luc5) + luc1 * (luc2 * luc7) + luc3 * (luc8 * luc5) - luc6) * luc0;

  dst(0, 0) = luc9 * li41 * li41 + luc5 * li31 * li31 + luc2 * li21 * li21 + luc0;
  dst(1, 0) = luc9 * li41 * li42 + luc5 * li31 * li32 + luc2 * li21;
  dst(1, 1) = luc9 * li42 * li42 + luc5 * li32 * li32 + luc2;
  dst(2, 0) = luc9 * li41 * li43 + luc5 * li31;
  dst(2, 1) = luc9 * li42 * li43 + luc5 * li32;
  dst(2, 2) = luc9 * li43 * li43 + luc5;
  dst(3, 0) = luc9 * li41;
  dst(3, 1) = luc9 * li42;
  dst(3, 2) = luc9 * li43;
  dst(3, 3) = luc9;
}
0170
/**
 * Inverts a symmetric 5x5 matrix through its root-free Cholesky (LDL^T)
 * factorisation.
 *
 * - luc0, luc2, luc5, luc9, luc14 are the reciprocal pivots (the entries of
 *   D^-1); the remaining luc* values are the unscaled off-diagonal factor
 *   elements.
 * - li<r><c> (1-based, r > c) are the sub-diagonal entries of L^-1, whose
 *   diagonal is 1.
 * - dst is assembled as L^-T * D^-1 * L^-1.
 *
 * Only the lower triangle of src (row >= column) is read, and only the lower
 * triangle of dst is written. Inverter fills the upper half with symmetrize55.
 * Pivots are not tested for zero, so src is presumably expected to be positive
 * definite; a zero pivot produces inf/NaN.
 *
 * F is src's element type with reference/cv qualifiers stripped, so F(1.0) is
 * a plain value conversion even when src(i, j) returns a reference.
 */
template <typename M1, typename M2>
inline constexpr void __attribute__((always_inline)) invert55(M1 const& src, M2& dst) {
  using F = std::decay_t<decltype(src(0, 0))>;
  auto luc0 = F(1.0) / src(0, 0);
  auto luc1 = src(1, 0);
  auto luc2 = src(1, 1) - luc0 * luc1 * luc1;
  luc2 = F(1.0) / luc2;
  auto luc3 = src(2, 0);
  auto luc4 = (src(2, 1) - luc0 * luc1 * luc3);
  auto luc5 = src(2, 2) - (luc0 * luc3 * luc3 + luc2 * luc4 * luc4);
  luc5 = F(1.0) / luc5;
  auto luc6 = src(3, 0);
  auto luc7 = (src(3, 1) - luc0 * luc1 * luc6);
  auto luc8 = (src(3, 2) - luc0 * luc3 * luc6 - luc2 * luc4 * luc7);
  auto luc9 = src(3, 3) - (luc0 * luc6 * luc6 + luc2 * luc7 * luc7 + luc8 * (luc8 * luc5));
  luc9 = F(1.0) / luc9;
  auto luc10 = src(4, 0);
  auto luc11 = (src(4, 1) - luc0 * luc1 * luc10);
  auto luc12 = (src(4, 2) - luc0 * luc3 * luc10 - luc2 * luc4 * luc11);
  auto luc13 = (src(4, 3) - luc0 * luc6 * luc10 - luc2 * luc7 * luc11 - luc5 * luc8 * luc12);
  auto luc14 =
      src(4, 4) - (luc0 * luc10 * luc10 + luc2 * luc11 * luc11 + luc5 * luc12 * luc12 + luc9 * luc13 * luc13);
  luc14 = F(1.0) / luc14;

  auto li21 = -luc1 * luc0;
  auto li32 = -luc2 * luc4;
  auto li31 = (luc1 * (luc2 * luc4) - luc3) * luc0;
  auto li43 = -(luc8 * luc5);
  auto li42 = (luc4 * luc8 * luc5 - luc7) * luc2;
  auto li41 = (-luc1 * (luc2 * luc4) * (luc8 * luc5) + luc1 * (luc2 * luc7) + luc3 * (luc8 * luc5) - luc6) * luc0;
  auto li54 = -luc13 * luc9;
  auto li53 = (luc13 * luc8 * luc9 - luc12) * luc5;
  auto li52 = (-luc4 * luc8 * luc13 * luc5 * luc9 + luc4 * luc12 * luc5 + luc7 * luc13 * luc9 - luc11) * luc2;
  auto li51 = (luc1 * luc4 * luc8 * luc13 * luc2 * luc5 * luc9 - luc13 * luc8 * luc3 * luc9 * luc5 -
               luc12 * luc4 * luc1 * luc2 * luc5 - luc13 * luc7 * luc1 * luc9 * luc2 + luc11 * luc1 * luc2 +
               luc12 * luc3 * luc5 + luc13 * luc6 * luc9 - luc10) *
              luc0;

  dst(0, 0) = luc14 * li51 * li51 + luc9 * li41 * li41 + luc5 * li31 * li31 + luc2 * li21 * li21 + luc0;
  dst(1, 0) = luc14 * li51 * li52 + luc9 * li41 * li42 + luc5 * li31 * li32 + luc2 * li21;
  dst(1, 1) = luc14 * li52 * li52 + luc9 * li42 * li42 + luc5 * li32 * li32 + luc2;
  dst(2, 0) = luc14 * li51 * li53 + luc9 * li41 * li43 + luc5 * li31;
  dst(2, 1) = luc14 * li52 * li53 + luc9 * li42 * li43 + luc5 * li32;
  dst(2, 2) = luc14 * li53 * li53 + luc9 * li43 * li43 + luc5;
  dst(3, 0) = luc14 * li51 * li54 + luc9 * li41;
  dst(3, 1) = luc14 * li52 * li54 + luc9 * li42;
  dst(3, 2) = luc14 * li53 * li54 + luc9 * li43;
  dst(3, 3) = luc14 * li54 * li54 + luc9;
  dst(4, 0) = luc14 * li51;
  dst(4, 1) = luc14 * li52;
  dst(4, 2) = luc14 * li53;
  dst(4, 3) = luc14 * li54;
  dst(4, 4) = luc14;
}
0225
/**
 * Inverts a symmetric 6x6 matrix through its root-free Cholesky (LDL^T)
 * factorisation.
 *
 * - luc0, luc2, luc5, luc9, luc14, luc20 are the reciprocal pivots (the
 *   entries of D^-1); the remaining luc* values are the unscaled off-diagonal
 *   factor elements.
 * - li<r><c> (1-based, r > c) are the sub-diagonal entries of L^-1, whose
 *   diagonal is 1.
 * - dst is assembled as L^-T * D^-1 * L^-1.
 *
 * Only the lower triangle of src (row >= column) is read, and only the lower
 * triangle of dst is written. Inverter fills the upper half with symmetrize66.
 * Pivots are not tested for zero, so src is presumably expected to be positive
 * definite; a zero pivot produces inf/NaN.
 *
 * F is src's element type with reference/cv qualifiers stripped, so F(1.0) is
 * a plain value conversion even when src(i, j) returns a reference.
 */
template <typename M1, typename M2>
inline constexpr void __attribute__((always_inline)) invert66(M1 const& src, M2& dst) {
  using F = std::decay_t<decltype(src(0, 0))>;
  auto luc0 = F(1.0) / src(0, 0);
  auto luc1 = src(1, 0);
  auto luc2 = src(1, 1) - luc0 * luc1 * luc1;
  luc2 = F(1.0) / luc2;
  auto luc3 = src(2, 0);
  auto luc4 = (src(2, 1) - luc0 * luc1 * luc3);
  auto luc5 = src(2, 2) - (luc0 * luc3 * luc3 + luc2 * luc4 * luc4);
  luc5 = F(1.0) / luc5;
  auto luc6 = src(3, 0);
  auto luc7 = (src(3, 1) - luc0 * luc1 * luc6);
  auto luc8 = (src(3, 2) - luc0 * luc3 * luc6 - luc2 * luc4 * luc7);
  auto luc9 = src(3, 3) - (luc0 * luc6 * luc6 + luc2 * luc7 * luc7 + luc8 * (luc8 * luc5));
  luc9 = F(1.0) / luc9;
  auto luc10 = src(4, 0);
  auto luc11 = (src(4, 1) - luc0 * luc1 * luc10);
  auto luc12 = (src(4, 2) - luc0 * luc3 * luc10 - luc2 * luc4 * luc11);
  auto luc13 = (src(4, 3) - luc0 * luc6 * luc10 - luc2 * luc7 * luc11 - luc5 * luc8 * luc12);
  auto luc14 =
      src(4, 4) - (luc0 * luc10 * luc10 + luc2 * luc11 * luc11 + luc5 * luc12 * luc12 + luc9 * luc13 * luc13);
  luc14 = F(1.0) / luc14;
  auto luc15 = src(5, 0);
  auto luc16 = (src(5, 1) - luc0 * luc1 * luc15);
  auto luc17 = (src(5, 2) - luc0 * luc3 * luc15 - luc2 * luc4 * luc16);
  auto luc18 = (src(5, 3) - luc0 * luc6 * luc15 - luc2 * luc7 * luc16 - luc5 * luc8 * luc17);
  auto luc19 =
      (src(5, 4) - luc0 * luc10 * luc15 - luc2 * luc11 * luc16 - luc5 * luc12 * luc17 - luc9 * luc13 * luc18);
  auto luc20 = src(5, 5) - (luc0 * luc15 * luc15 + luc2 * luc16 * luc16 + luc5 * luc17 * luc17 +
                            luc9 * luc18 * luc18 + luc14 * luc19 * luc19);
  luc20 = F(1.0) / luc20;

  auto li21 = -luc1 * luc0;
  auto li32 = -luc2 * luc4;
  auto li31 = (luc1 * (luc2 * luc4) - luc3) * luc0;
  auto li43 = -(luc8 * luc5);
  auto li42 = (luc4 * luc8 * luc5 - luc7) * luc2;
  auto li41 = (-luc1 * (luc2 * luc4) * (luc8 * luc5) + luc1 * (luc2 * luc7) + luc3 * (luc8 * luc5) - luc6) * luc0;
  auto li54 = -luc13 * luc9;
  auto li53 = (luc13 * luc8 * luc9 - luc12) * luc5;
  auto li52 = (-luc4 * luc8 * luc13 * luc5 * luc9 + luc4 * luc12 * luc5 + luc7 * luc13 * luc9 - luc11) * luc2;
  auto li51 = (luc1 * luc4 * luc8 * luc13 * luc2 * luc5 * luc9 - luc13 * luc8 * luc3 * luc9 * luc5 -
               luc12 * luc4 * luc1 * luc2 * luc5 - luc13 * luc7 * luc1 * luc9 * luc2 + luc11 * luc1 * luc2 +
               luc12 * luc3 * luc5 + luc13 * luc6 * luc9 - luc10) *
              luc0;

  auto li65 = -luc19 * luc14;
  auto li64 = (luc19 * luc14 * luc13 - luc18) * luc9;
  auto li63 =
      (-luc8 * luc13 * (luc19 * luc14) * luc9 + luc8 * luc9 * luc18 + luc12 * (luc19 * luc14) - luc17) * luc5;
  auto li62 = (luc4 * (luc8 * luc9) * luc13 * luc5 * (luc19 * luc14) - luc18 * luc4 * (luc8 * luc9) * luc5 -
               luc19 * luc12 * luc4 * luc14 * luc5 - luc19 * luc13 * luc7 * luc14 * luc9 + luc17 * luc4 * luc5 +
               luc18 * luc7 * luc9 + luc19 * luc11 * luc14 - luc16) *
              luc2;
  auto li61 =
      (-luc19 * luc13 * luc8 * luc4 * luc1 * luc2 * luc5 * luc9 * luc14 +
       luc18 * luc8 * luc4 * luc1 * luc2 * luc5 * luc9 + luc19 * luc12 * luc4 * luc1 * luc2 * luc5 * luc14 +
       luc19 * luc13 * luc7 * luc1 * luc2 * luc9 * luc14 + luc19 * luc13 * luc8 * luc3 * luc5 * luc9 * luc14 -
       luc17 * luc4 * luc1 * luc2 * luc5 - luc18 * luc7 * luc1 * luc2 * luc9 - luc19 * luc11 * luc1 * luc2 * luc14 -
       luc18 * luc8 * luc3 * luc5 * luc9 - luc19 * luc12 * luc3 * luc5 * luc14 -
       luc19 * luc13 * luc6 * luc9 * luc14 + luc16 * luc1 * luc2 + luc17 * luc3 * luc5 + luc18 * luc6 * luc9 +
       luc19 * luc10 * luc14 - luc15) *
      luc0;

  dst(0, 0) = luc20 * li61 * li61 + luc14 * li51 * li51 + luc9 * li41 * li41 + luc5 * li31 * li31 +
              luc2 * li21 * li21 + luc0;
  dst(1, 0) = luc20 * li61 * li62 + luc14 * li51 * li52 + luc9 * li41 * li42 + luc5 * li31 * li32 + luc2 * li21;
  dst(1, 1) = luc20 * li62 * li62 + luc14 * li52 * li52 + luc9 * li42 * li42 + luc5 * li32 * li32 + luc2;
  dst(2, 0) = luc20 * li61 * li63 + luc14 * li51 * li53 + luc9 * li41 * li43 + luc5 * li31;
  dst(2, 1) = luc20 * li62 * li63 + luc14 * li52 * li53 + luc9 * li42 * li43 + luc5 * li32;
  dst(2, 2) = luc20 * li63 * li63 + luc14 * li53 * li53 + luc9 * li43 * li43 + luc5;
  dst(3, 0) = luc20 * li61 * li64 + luc14 * li51 * li54 + luc9 * li41;
  dst(3, 1) = luc20 * li62 * li64 + luc14 * li52 * li54 + luc9 * li42;
  dst(3, 2) = luc20 * li63 * li64 + luc14 * li53 * li54 + luc9 * li43;
  dst(3, 3) = luc20 * li64 * li64 + luc14 * li54 * li54 + luc9;
  dst(4, 0) = luc20 * li61 * li65 + luc14 * li51;
  dst(4, 1) = luc20 * li62 * li65 + luc14 * li52;
  dst(4, 2) = luc20 * li63 * li65 + luc14 * li53;
  dst(4, 3) = luc20 * li64 * li65 + luc14 * li54;
  dst(4, 4) = luc20 * li65 * li65 + luc14;
  dst(5, 0) = luc20 * li61;
  dst(5, 1) = luc20 * li62;
  dst(5, 2) = luc20 * li63;
  dst(5, 3) = luc20 * li64;
  dst(5, 4) = luc20 * li65;
  dst(5, 5) = luc20;
}
0314
/// A 1x1 matrix is trivially symmetric, so there is nothing to copy. The
/// parameter is left unnamed to avoid an unused-parameter warning.
template <typename M>
inline constexpr void symmetrize11(M& /*dst*/) {}
0317
/// Mirrors the strictly-lower triangle of a 2x2 matrix into its upper
/// triangle: dst(row, col) = dst(col, row) for every row < col.
template <typename M>
inline constexpr void symmetrize22(M& dst) {
  for (int col = 1; col < 2; ++col)
    for (int row = 0; row < col; ++row)
      dst(row, col) = dst(col, row);
}
0322
/// Mirrors the strictly-lower triangle of a 3x3 matrix into its upper
/// triangle, column by column: (0,1), then (0,2), (1,2).
template <typename M>
inline constexpr void symmetrize33(M& dst) {
  for (int col = 1; col < 3; ++col)
    for (int row = 0; row < col; ++row)
      dst(row, col) = dst(col, row);
}
0329
/// Mirrors the strictly-lower triangle of a 4x4 matrix into its upper
/// triangle, column by column from column 1 to column 3.
template <typename M>
inline constexpr void symmetrize44(M& dst) {
  for (int col = 1; col < 4; ++col)
    for (int row = 0; row < col; ++row)
      dst(row, col) = dst(col, row);
}
0337
/// Mirrors the strictly-lower triangle of a 5x5 matrix into its upper
/// triangle, column by column from column 1 to column 4.
template <typename M>
inline constexpr void symmetrize55(M& dst) {
  for (int col = 1; col < 5; ++col)
    for (int row = 0; row < col; ++row)
      dst(row, col) = dst(col, row);
}
0346
/// Mirrors the strictly-lower triangle of a 6x6 matrix into its upper
/// triangle, column by column from column 1 to column 5.
template <typename M>
inline constexpr void symmetrize66(M& dst) {
  for (int col = 1; col < 6; ++col)
    for (int row = 0; row < col; ++row)
      dst(row, col) = dst(col, row);
}
0356
/**
 * Fallback for any size N without a dedicated specialisation below: delegates
 * to src.inverse() and assigns the result to dst.
 *
 * NOTE(review): for Eigen types, inverse() is declared in Core but presumably
 * defined in the LU module, which this header does not include; callers that
 * reach this path (e.g. dynamic-size matrices) likely need <Eigen/LU> — confirm.
 */
template <typename M1, typename M2, int N>
struct Inverter {
  static constexpr void eval(M1 const& src, M2& dst) {
    dst = src.inverse();
  }
};
0361
0362 template <typename M1, typename M2>
0363 struct Inverter<M1, M2, 1> {
0364 static constexpr void eval(M1 const& src, M2& dst) { invert11(src, dst); }
0365 };
0366
0367 template <typename M1, typename M2>
0368 struct Inverter<M1, M2, 2> {
0369 static constexpr void __attribute__((always_inline)) eval(M1 const& src, M2& dst) {
0370 invert22(src, dst);
0371 symmetrize22(dst);
0372 }
0373 };
0374
0375 template <typename M1, typename M2>
0376 struct Inverter<M1, M2, 3> {
0377 static constexpr void __attribute__((always_inline)) eval(M1 const& src, M2& dst) {
0378 invert33(src, dst);
0379 symmetrize33(dst);
0380 }
0381 };
0382
0383 template <typename M1, typename M2>
0384 struct Inverter<M1, M2, 4> {
0385 static constexpr void __attribute__((always_inline)) eval(M1 const& src, M2& dst) {
0386 invert44(src, dst);
0387 symmetrize44(dst);
0388 }
0389 };
0390
0391 template <typename M1, typename M2>
0392 struct Inverter<M1, M2, 5> {
0393 static constexpr void __attribute__((always_inline)) eval(M1 const& src, M2& dst) {
0394 invert55(src, dst);
0395 symmetrize55(dst);
0396 }
0397 };
0398
0399 template <typename M1, typename M2>
0400 struct Inverter<M1, M2, 6> {
0401 static constexpr void __attribute__((always_inline)) eval(M1 const& src, M2& dst) {
0402 invert66(src, dst);
0403 symmetrize66(dst);
0404 }
0405 };
0406
0407
0408 template <typename M1, typename M2>
0409 inline constexpr void __attribute__((always_inline)) invert(M1 const& src, M2& dst) {
0410 if constexpr (M2::ColsAtCompileTime < 7)
0411 Inverter<M1, M2, M2::ColsAtCompileTime>::eval(src, dst);
0412 else
0413 invertNN(src, dst);
0414 }
0415
0416 }
0417 }
0418
0419 #endif