Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:04:41

0001 #ifndef DataFormat_Math_choleskyInversion_h
0002 #define DataFormat_Math_choleskyInversion_h
0003 
0004 #include <cmath>
0005 
0006 #include <Eigen/Core>
0007 
0008 namespace math {
0009   namespace cholesky {
0010 
0011     template <typename M1, typename M2, int N = M2::ColsAtCompileTime>
0012 // without this: either does not compile or compiles and then fails silently at runtime
0013 #ifdef __CUDACC__
0014     __host__ __device__
0015 #endif
0016         inline constexpr void
0017         invertNN(M1 const& src, M2& dst) {
0018 
0019       // origin: CERNLIB
0020 
0021       using T = typename M2::Scalar;
0022 
0023       T a[N][N];
0024       for (int i = 0; i < N; ++i) {
0025         a[i][i] = src(i, i);
0026         for (int j = i + 1; j < N; ++j)
0027           a[j][i] = src(i, j);
0028       }
0029 
0030       for (int j = 0; j < N; ++j) {
0031         a[j][j] = T(1.) / a[j][j];
0032         int jp1 = j + 1;
0033         for (int l = jp1; l < N; ++l) {
0034           a[j][l] = a[j][j] * a[l][j];
0035           T s1 = -a[l][jp1];
0036           for (int i = 0; i < jp1; ++i)
0037             s1 += a[l][i] * a[i][jp1];
0038           a[l][jp1] = -s1;
0039         }
0040       }
0041 
0042       if constexpr (N == 1) {
0043         dst(0, 0) = a[0][0];
0044         return;
0045       }
0046       a[0][1] = -a[0][1];
0047       a[1][0] = a[0][1] * a[1][1];
0048       for (int j = 2; j < N; ++j) {
0049         int jm1 = j - 1;
0050         for (int k = 0; k < jm1; ++k) {
0051           T s31 = a[k][j];
0052           for (int i = k; i < jm1; ++i)
0053             s31 += a[k][i + 1] * a[i + 1][j];
0054           a[k][j] = -s31;
0055           a[j][k] = -s31 * a[j][j];
0056         }
0057         a[jm1][j] = -a[jm1][j];
0058         a[j][jm1] = a[jm1][j] * a[j][j];
0059       }
0060 
0061       int j = 0;
0062       while (j < N - 1) {
0063         T s33 = a[j][j];
0064         for (int i = j + 1; i < N; ++i)
0065           s33 += a[j][i] * a[i][j];
0066         dst(j, j) = s33;
0067 
0068         ++j;
0069         for (int k = 0; k < j; ++k) {
0070           T s32 = 0;
0071           for (int i = j; i < N; ++i)
0072             s32 += a[k][i] * a[i][j];
0073           dst(k, j) = dst(j, k) = s32;
0074         }
0075       }
0076       dst(j, j) = a[j][j];
0077     }
0078 
    /**
     * Fully inlined, specialized code to invert a symmetric
     * positive-definite matrix of dimension up to 6.
     *
     * Adapted from ROOT::Math::CholeskyDecomp,
     * originally by
     * @author Manuel Schiller
     * @date Aug 29 2008
     */
0090 
0091     template <typename M1, typename M2>
0092     inline constexpr void __attribute__((always_inline)) invert11(M1 const& src, M2& dst) {
0093       using F = decltype(src(0, 0));
0094       dst(0, 0) = F(1.0) / src(0, 0);
0095     }
0096 
0097     template <typename M1, typename M2>
0098     inline constexpr void __attribute__((always_inline)) invert22(M1 const& src, M2& dst) {
0099       using F = decltype(src(0, 0));
0100       auto luc0 = F(1.0) / src(0, 0);
0101       auto luc1 = src(1, 0) * src(1, 0) * luc0;
0102       auto luc2 = F(1.0) / (src(1, 1) - luc1);
0103 
0104       auto li21 = luc1 * luc0 * luc2;
0105 
0106       dst(0, 0) = li21 + luc0;
0107       dst(1, 0) = -src(1, 0) * luc0 * luc2;
0108       dst(1, 1) = luc2;
0109     }
0110 
    /**
     * Inverts a symmetric positive-definite 3x3 matrix with a square-root-free
     * (LDL^T-style) elimination.
     *
     * Reads only the diagonal and lower triangle of src (src(i, j) with i >= j).
     * Writes only the diagonal and lower triangle of dst; the upper triangle is
     * not touched (Inverter<M1, M2, 3> mirrors it afterwards with symmetrize33).
     * There is no pivoting and no pivot check: a zero pivot presumably yields
     * inf/NaN for floating-point element types.
     *
     * Naming: luc0, luc2 and luc5 end up holding the reciprocal pivots 1/d_k.
     * luc1, luc3 and luc4 hold the eliminated off-diagonal entries.
     * li<r><c> (1-based) is element (r-1, c-1) of L^-1, and dst is assembled
     * as L^-T * D^-1 * L^-1.
     */
    template <typename M1, typename M2>
    inline constexpr void __attribute__((always_inline)) invert33(M1 const& src, M2& dst) {
      // Element type as returned by src(i, j) (possibly a const reference);
      // F(1.0) forms the constant one in that precision.
      using F = decltype(src(0, 0));
      // Forward elimination: reciprocal pivots and eliminated off-diagonals.
      auto luc0 = F(1.0) / src(0, 0);
      auto luc1 = src(1, 0);
      auto luc2 = src(1, 1) - luc0 * luc1 * luc1;
      luc2 = F(1.0) / luc2;
      auto luc3 = src(2, 0);
      auto luc4 = (src(2, 1) - luc0 * luc1 * luc3);
      auto luc5 = src(2, 2) - (luc0 * luc3 * luc3 + (luc2 * luc4) * luc4);
      luc5 = F(1.0) / luc5;

      // Off-diagonal elements of L^-1.
      auto li21 = -luc0 * luc1;
      auto li32 = -(luc2 * luc4);
      auto li31 = (luc1 * (luc2 * luc4) - luc3) * luc0;

      // dst = L^-T * D^-1 * L^-1, lower triangle and diagonal only.
      dst(0, 0) = luc5 * li31 * li31 + li21 * li21 * luc2 + luc0;
      dst(1, 0) = luc5 * li31 * li32 + li21 * luc2;
      dst(1, 1) = luc5 * li32 * li32 + luc2;
      dst(2, 0) = luc5 * li31;
      dst(2, 1) = luc5 * li32;
      dst(2, 2) = luc5;
    }
0134 
    /**
     * Inverts a symmetric positive-definite 4x4 matrix with a square-root-free
     * (LDL^T-style) elimination.
     *
     * Reads only the diagonal and lower triangle of src (src(i, j) with i >= j).
     * Writes only the diagonal and lower triangle of dst; the upper triangle is
     * not touched (Inverter<M1, M2, 4> mirrors it afterwards with symmetrize44).
     * There is no pivoting and no pivot check: a zero pivot presumably yields
     * inf/NaN for floating-point element types.
     *
     * Naming: luc0, luc2, luc5 and luc9 end up holding the reciprocal pivots
     * 1/d_k. The remaining lucN hold the eliminated off-diagonal entries.
     * li<r><c> (1-based) is element (r-1, c-1) of L^-1, and dst is assembled
     * as L^-T * D^-1 * L^-1.
     */
    template <typename M1, typename M2>
    inline constexpr void __attribute__((always_inline)) invert44(M1 const& src, M2& dst) {
      // Element type as returned by src(i, j) (possibly a const reference);
      // F(1.0) forms the constant one in that precision.
      using F = decltype(src(0, 0));
      // Forward elimination: reciprocal pivots and eliminated off-diagonals.
      auto luc0 = F(1.0) / src(0, 0);
      auto luc1 = src(1, 0);
      auto luc2 = src(1, 1) - luc0 * luc1 * luc1;
      luc2 = F(1.0) / luc2;
      auto luc3 = src(2, 0);
      auto luc4 = (src(2, 1) - luc0 * luc1 * luc3);
      auto luc5 = src(2, 2) - (luc0 * luc3 * luc3 + luc2 * luc4 * luc4);
      luc5 = F(1.0) / luc5;
      auto luc6 = src(3, 0);
      auto luc7 = (src(3, 1) - luc0 * luc1 * luc6);
      auto luc8 = (src(3, 2) - luc0 * luc3 * luc6 - luc2 * luc4 * luc7);
      auto luc9 = src(3, 3) - (luc0 * luc6 * luc6 + luc2 * luc7 * luc7 + luc8 * (luc8 * luc5));
      luc9 = F(1.0) / luc9;

      // Off-diagonal elements of L^-1.
      auto li21 = -luc1 * luc0;
      auto li32 = -luc2 * luc4;
      auto li31 = (luc1 * (luc2 * luc4) - luc3) * luc0;
      auto li43 = -(luc8 * luc5);
      auto li42 = (luc4 * luc8 * luc5 - luc7) * luc2;
      auto li41 = (-luc1 * (luc2 * luc4) * (luc8 * luc5) + luc1 * (luc2 * luc7) + luc3 * (luc8 * luc5) - luc6) * luc0;

      // dst = L^-T * D^-1 * L^-1, lower triangle and diagonal only.
      dst(0, 0) = luc9 * li41 * li41 + luc5 * li31 * li31 + luc2 * li21 * li21 + luc0;
      dst(1, 0) = luc9 * li41 * li42 + luc5 * li31 * li32 + luc2 * li21;
      dst(1, 1) = luc9 * li42 * li42 + luc5 * li32 * li32 + luc2;
      dst(2, 0) = luc9 * li41 * li43 + luc5 * li31;
      dst(2, 1) = luc9 * li42 * li43 + luc5 * li32;
      dst(2, 2) = luc9 * li43 * li43 + luc5;
      dst(3, 0) = luc9 * li41;
      dst(3, 1) = luc9 * li42;
      dst(3, 2) = luc9 * li43;
      dst(3, 3) = luc9;
    }
0170 
    /**
     * Inverts a symmetric positive-definite 5x5 matrix with a square-root-free
     * (LDL^T-style) elimination.
     *
     * Reads only the diagonal and lower triangle of src (src(i, j) with i >= j).
     * Writes only the diagonal and lower triangle of dst; the upper triangle is
     * not touched (Inverter<M1, M2, 5> mirrors it afterwards with symmetrize55).
     * There is no pivoting and no pivot check: a zero pivot presumably yields
     * inf/NaN for floating-point element types.
     *
     * Naming: luc0, luc2, luc5, luc9 and luc14 end up holding the reciprocal
     * pivots 1/d_k. The remaining lucN hold the eliminated off-diagonal entries.
     * li<r><c> (1-based) is element (r-1, c-1) of L^-1, and dst is assembled
     * as L^-T * D^-1 * L^-1.
     */
    template <typename M1, typename M2>
    inline constexpr void __attribute__((always_inline)) invert55(M1 const& src, M2& dst) {
      // Element type as returned by src(i, j) (possibly a const reference);
      // F(1.0) forms the constant one in that precision.
      using F = decltype(src(0, 0));
      // Forward elimination: reciprocal pivots and eliminated off-diagonals.
      auto luc0 = F(1.0) / src(0, 0);
      auto luc1 = src(1, 0);
      auto luc2 = src(1, 1) - luc0 * luc1 * luc1;
      luc2 = F(1.0) / luc2;
      auto luc3 = src(2, 0);
      auto luc4 = (src(2, 1) - luc0 * luc1 * luc3);
      auto luc5 = src(2, 2) - (luc0 * luc3 * luc3 + luc2 * luc4 * luc4);
      luc5 = F(1.0) / luc5;
      auto luc6 = src(3, 0);
      auto luc7 = (src(3, 1) - luc0 * luc1 * luc6);
      auto luc8 = (src(3, 2) - luc0 * luc3 * luc6 - luc2 * luc4 * luc7);
      auto luc9 = src(3, 3) - (luc0 * luc6 * luc6 + luc2 * luc7 * luc7 + luc8 * (luc8 * luc5));
      luc9 = F(1.0) / luc9;
      auto luc10 = src(4, 0);
      auto luc11 = (src(4, 1) - luc0 * luc1 * luc10);
      auto luc12 = (src(4, 2) - luc0 * luc3 * luc10 - luc2 * luc4 * luc11);
      auto luc13 = (src(4, 3) - luc0 * luc6 * luc10 - luc2 * luc7 * luc11 - luc5 * luc8 * luc12);
      auto luc14 =
          src(4, 4) - (luc0 * luc10 * luc10 + luc2 * luc11 * luc11 + luc5 * luc12 * luc12 + luc9 * luc13 * luc13);
      luc14 = F(1.0) / luc14;

      // Off-diagonal elements of L^-1.
      auto li21 = -luc1 * luc0;
      auto li32 = -luc2 * luc4;
      auto li31 = (luc1 * (luc2 * luc4) - luc3) * luc0;
      auto li43 = -(luc8 * luc5);
      auto li42 = (luc4 * luc8 * luc5 - luc7) * luc2;
      auto li41 = (-luc1 * (luc2 * luc4) * (luc8 * luc5) + luc1 * (luc2 * luc7) + luc3 * (luc8 * luc5) - luc6) * luc0;
      auto li54 = -luc13 * luc9;
      auto li53 = (luc13 * luc8 * luc9 - luc12) * luc5;
      auto li52 = (-luc4 * luc8 * luc13 * luc5 * luc9 + luc4 * luc12 * luc5 + luc7 * luc13 * luc9 - luc11) * luc2;
      auto li51 = (luc1 * luc4 * luc8 * luc13 * luc2 * luc5 * luc9 - luc13 * luc8 * luc3 * luc9 * luc5 -
                   luc12 * luc4 * luc1 * luc2 * luc5 - luc13 * luc7 * luc1 * luc9 * luc2 + luc11 * luc1 * luc2 +
                   luc12 * luc3 * luc5 + luc13 * luc6 * luc9 - luc10) *
                  luc0;

      // dst = L^-T * D^-1 * L^-1, lower triangle and diagonal only.
      dst(0, 0) = luc14 * li51 * li51 + luc9 * li41 * li41 + luc5 * li31 * li31 + luc2 * li21 * li21 + luc0;
      dst(1, 0) = luc14 * li51 * li52 + luc9 * li41 * li42 + luc5 * li31 * li32 + luc2 * li21;
      dst(1, 1) = luc14 * li52 * li52 + luc9 * li42 * li42 + luc5 * li32 * li32 + luc2;
      dst(2, 0) = luc14 * li51 * li53 + luc9 * li41 * li43 + luc5 * li31;
      dst(2, 1) = luc14 * li52 * li53 + luc9 * li42 * li43 + luc5 * li32;
      dst(2, 2) = luc14 * li53 * li53 + luc9 * li43 * li43 + luc5;
      dst(3, 0) = luc14 * li51 * li54 + luc9 * li41;
      dst(3, 1) = luc14 * li52 * li54 + luc9 * li42;
      dst(3, 2) = luc14 * li53 * li54 + luc9 * li43;
      dst(3, 3) = luc14 * li54 * li54 + luc9;
      dst(4, 0) = luc14 * li51;
      dst(4, 1) = luc14 * li52;
      dst(4, 2) = luc14 * li53;
      dst(4, 3) = luc14 * li54;
      dst(4, 4) = luc14;
    }
0225 
    /**
     * Inverts a symmetric positive-definite 6x6 matrix with a square-root-free
     * (LDL^T-style) elimination.
     *
     * Reads only the diagonal and lower triangle of src (src(i, j) with i >= j).
     * Writes only the diagonal and lower triangle of dst; the upper triangle is
     * not touched (Inverter<M1, M2, 6> mirrors it afterwards with symmetrize66).
     * There is no pivoting and no pivot check: a zero pivot presumably yields
     * inf/NaN for floating-point element types.
     *
     * Naming: luc0, luc2, luc5, luc9, luc14 and luc20 end up holding the
     * reciprocal pivots 1/d_k. The remaining lucN hold the eliminated
     * off-diagonal entries. li<r><c> (1-based) is element (r-1, c-1) of L^-1,
     * and dst is assembled as L^-T * D^-1 * L^-1.
     */
    template <typename M1, typename M2>
    inline constexpr void __attribute__((always_inline)) invert66(M1 const& src, M2& dst) {
      // Element type as returned by src(i, j) (possibly a const reference);
      // F(1.0) forms the constant one in that precision.
      using F = decltype(src(0, 0));
      // Forward elimination: reciprocal pivots and eliminated off-diagonals.
      auto luc0 = F(1.0) / src(0, 0);
      auto luc1 = src(1, 0);
      auto luc2 = src(1, 1) - luc0 * luc1 * luc1;
      luc2 = F(1.0) / luc2;
      auto luc3 = src(2, 0);
      auto luc4 = (src(2, 1) - luc0 * luc1 * luc3);
      auto luc5 = src(2, 2) - (luc0 * luc3 * luc3 + luc2 * luc4 * luc4);
      luc5 = F(1.0) / luc5;
      auto luc6 = src(3, 0);
      auto luc7 = (src(3, 1) - luc0 * luc1 * luc6);
      auto luc8 = (src(3, 2) - luc0 * luc3 * luc6 - luc2 * luc4 * luc7);
      auto luc9 = src(3, 3) - (luc0 * luc6 * luc6 + luc2 * luc7 * luc7 + luc8 * (luc8 * luc5));
      luc9 = F(1.0) / luc9;
      auto luc10 = src(4, 0);
      auto luc11 = (src(4, 1) - luc0 * luc1 * luc10);
      auto luc12 = (src(4, 2) - luc0 * luc3 * luc10 - luc2 * luc4 * luc11);
      auto luc13 = (src(4, 3) - luc0 * luc6 * luc10 - luc2 * luc7 * luc11 - luc5 * luc8 * luc12);
      auto luc14 =
          src(4, 4) - (luc0 * luc10 * luc10 + luc2 * luc11 * luc11 + luc5 * luc12 * luc12 + luc9 * luc13 * luc13);
      luc14 = F(1.0) / luc14;
      auto luc15 = src(5, 0);
      auto luc16 = (src(5, 1) - luc0 * luc1 * luc15);
      auto luc17 = (src(5, 2) - luc0 * luc3 * luc15 - luc2 * luc4 * luc16);
      auto luc18 = (src(5, 3) - luc0 * luc6 * luc15 - luc2 * luc7 * luc16 - luc5 * luc8 * luc17);
      auto luc19 =
          (src(5, 4) - luc0 * luc10 * luc15 - luc2 * luc11 * luc16 - luc5 * luc12 * luc17 - luc9 * luc13 * luc18);
      auto luc20 = src(5, 5) - (luc0 * luc15 * luc15 + luc2 * luc16 * luc16 + luc5 * luc17 * luc17 +
                                luc9 * luc18 * luc18 + luc14 * luc19 * luc19);
      luc20 = F(1.0) / luc20;

      // Off-diagonal elements of L^-1 (rows 2..5).
      auto li21 = -luc1 * luc0;
      auto li32 = -luc2 * luc4;
      auto li31 = (luc1 * (luc2 * luc4) - luc3) * luc0;
      auto li43 = -(luc8 * luc5);
      auto li42 = (luc4 * luc8 * luc5 - luc7) * luc2;
      auto li41 = (-luc1 * (luc2 * luc4) * (luc8 * luc5) + luc1 * (luc2 * luc7) + luc3 * (luc8 * luc5) - luc6) * luc0;
      auto li54 = -luc13 * luc9;
      auto li53 = (luc13 * luc8 * luc9 - luc12) * luc5;
      auto li52 = (-luc4 * luc8 * luc13 * luc5 * luc9 + luc4 * luc12 * luc5 + luc7 * luc13 * luc9 - luc11) * luc2;
      auto li51 = (luc1 * luc4 * luc8 * luc13 * luc2 * luc5 * luc9 - luc13 * luc8 * luc3 * luc9 * luc5 -
                   luc12 * luc4 * luc1 * luc2 * luc5 - luc13 * luc7 * luc1 * luc9 * luc2 + luc11 * luc1 * luc2 +
                   luc12 * luc3 * luc5 + luc13 * luc6 * luc9 - luc10) *
                  luc0;

      // Off-diagonal elements of L^-1 (row 6).
      auto li65 = -luc19 * luc14;
      auto li64 = (luc19 * luc14 * luc13 - luc18) * luc9;
      auto li63 =
          (-luc8 * luc13 * (luc19 * luc14) * luc9 + luc8 * luc9 * luc18 + luc12 * (luc19 * luc14) - luc17) * luc5;
      auto li62 = (luc4 * (luc8 * luc9) * luc13 * luc5 * (luc19 * luc14) - luc18 * luc4 * (luc8 * luc9) * luc5 -
                   luc19 * luc12 * luc4 * luc14 * luc5 - luc19 * luc13 * luc7 * luc14 * luc9 + luc17 * luc4 * luc5 +
                   luc18 * luc7 * luc9 + luc19 * luc11 * luc14 - luc16) *
                  luc2;
      auto li61 =
          (-luc19 * luc13 * luc8 * luc4 * luc1 * luc2 * luc5 * luc9 * luc14 +
           luc18 * luc8 * luc4 * luc1 * luc2 * luc5 * luc9 + luc19 * luc12 * luc4 * luc1 * luc2 * luc5 * luc14 +
           luc19 * luc13 * luc7 * luc1 * luc2 * luc9 * luc14 + luc19 * luc13 * luc8 * luc3 * luc5 * luc9 * luc14 -
           luc17 * luc4 * luc1 * luc2 * luc5 - luc18 * luc7 * luc1 * luc2 * luc9 - luc19 * luc11 * luc1 * luc2 * luc14 -
           luc18 * luc8 * luc3 * luc5 * luc9 - luc19 * luc12 * luc3 * luc5 * luc14 -
           luc19 * luc13 * luc6 * luc9 * luc14 + luc16 * luc1 * luc2 + luc17 * luc3 * luc5 + luc18 * luc6 * luc9 +
           luc19 * luc10 * luc14 - luc15) *
          luc0;

      // dst = L^-T * D^-1 * L^-1, lower triangle and diagonal only.
      dst(0, 0) = luc20 * li61 * li61 + luc14 * li51 * li51 + luc9 * li41 * li41 + luc5 * li31 * li31 +
                  luc2 * li21 * li21 + luc0;
      dst(1, 0) = luc20 * li61 * li62 + luc14 * li51 * li52 + luc9 * li41 * li42 + luc5 * li31 * li32 + luc2 * li21;
      dst(1, 1) = luc20 * li62 * li62 + luc14 * li52 * li52 + luc9 * li42 * li42 + luc5 * li32 * li32 + luc2;
      dst(2, 0) = luc20 * li61 * li63 + luc14 * li51 * li53 + luc9 * li41 * li43 + luc5 * li31;
      dst(2, 1) = luc20 * li62 * li63 + luc14 * li52 * li53 + luc9 * li42 * li43 + luc5 * li32;
      dst(2, 2) = luc20 * li63 * li63 + luc14 * li53 * li53 + luc9 * li43 * li43 + luc5;
      dst(3, 0) = luc20 * li61 * li64 + luc14 * li51 * li54 + luc9 * li41;
      dst(3, 1) = luc20 * li62 * li64 + luc14 * li52 * li54 + luc9 * li42;
      dst(3, 2) = luc20 * li63 * li64 + luc14 * li53 * li54 + luc9 * li43;
      dst(3, 3) = luc20 * li64 * li64 + luc14 * li54 * li54 + luc9;
      dst(4, 0) = luc20 * li61 * li65 + luc14 * li51;
      dst(4, 1) = luc20 * li62 * li65 + luc14 * li52;
      dst(4, 2) = luc20 * li63 * li65 + luc14 * li53;
      dst(4, 3) = luc20 * li64 * li65 + luc14 * li54;
      dst(4, 4) = luc20 * li65 * li65 + luc14;
      dst(5, 0) = luc20 * li61;
      dst(5, 1) = luc20 * li62;
      dst(5, 2) = luc20 * li63;
      dst(5, 3) = luc20 * li64;
      dst(5, 4) = luc20 * li65;
      dst(5, 5) = luc20;
    }
0314 
0315     template <typename M>
0316     inline constexpr void symmetrize11(M& dst) {}
0317 
0318     template <typename M>
0319     inline constexpr void symmetrize22(M& dst) {
0320       dst(0, 1) = dst(1, 0);
0321     }
0322 
0323     template <typename M>
0324     inline constexpr void symmetrize33(M& dst) {
0325       symmetrize22(dst);
0326       dst(0, 2) = dst(2, 0);
0327       dst(1, 2) = dst(2, 1);
0328     }
0329 
0330     template <typename M>
0331     inline constexpr void symmetrize44(M& dst) {
0332       symmetrize33(dst);
0333       dst(0, 3) = dst(3, 0);
0334       dst(1, 3) = dst(3, 1);
0335       dst(2, 3) = dst(3, 2);
0336     }
0337 
0338     template <typename M>
0339     inline constexpr void symmetrize55(M& dst) {
0340       symmetrize44(dst);
0341       dst(0, 4) = dst(4, 0);
0342       dst(1, 4) = dst(4, 1);
0343       dst(2, 4) = dst(4, 2);
0344       dst(3, 4) = dst(4, 3);
0345     }
0346 
0347     template <typename M>
0348     inline constexpr void symmetrize66(M& dst) {
0349       symmetrize55(dst);
0350       dst(0, 5) = dst(5, 0);
0351       dst(1, 5) = dst(5, 1);
0352       dst(2, 5) = dst(5, 2);
0353       dst(3, 5) = dst(5, 3);
0354       dst(4, 5) = dst(5, 4);
0355     }
0356 
0357     template <typename M1, typename M2, int N>
0358     struct Inverter {
0359       static constexpr void eval(M1 const& src, M2& dst) { dst = src.inverse(); }
0360     };
0361 
0362     template <typename M1, typename M2>
0363     struct Inverter<M1, M2, 1> {
0364       static constexpr void eval(M1 const& src, M2& dst) { invert11(src, dst); }
0365     };
0366 
0367     template <typename M1, typename M2>
0368     struct Inverter<M1, M2, 2> {
0369       static constexpr void __attribute__((always_inline)) eval(M1 const& src, M2& dst) {
0370         invert22(src, dst);
0371         symmetrize22(dst);
0372       }
0373     };
0374 
0375     template <typename M1, typename M2>
0376     struct Inverter<M1, M2, 3> {
0377       static constexpr void __attribute__((always_inline)) eval(M1 const& src, M2& dst) {
0378         invert33(src, dst);
0379         symmetrize33(dst);
0380       }
0381     };
0382 
0383     template <typename M1, typename M2>
0384     struct Inverter<M1, M2, 4> {
0385       static constexpr void __attribute__((always_inline)) eval(M1 const& src, M2& dst) {
0386         invert44(src, dst);
0387         symmetrize44(dst);
0388       }
0389     };
0390 
0391     template <typename M1, typename M2>
0392     struct Inverter<M1, M2, 5> {
0393       static constexpr void __attribute__((always_inline)) eval(M1 const& src, M2& dst) {
0394         invert55(src, dst);
0395         symmetrize55(dst);
0396       }
0397     };
0398 
0399     template <typename M1, typename M2>
0400     struct Inverter<M1, M2, 6> {
0401       static constexpr void __attribute__((always_inline)) eval(M1 const& src, M2& dst) {
0402         invert66(src, dst);
0403         symmetrize66(dst);
0404       }
0405     };
0406 
0407     // Eigen interface
0408     template <typename M1, typename M2>
0409     inline constexpr void __attribute__((always_inline)) invert(M1 const& src, M2& dst) {
0410       if constexpr (M2::ColsAtCompileTime < 7)
0411         Inverter<M1, M2, M2::ColsAtCompileTime>::eval(src, dst);
0412       else
0413         invertNN(src, dst);
0414     }
0415 
0416   }  // namespace cholesky
0417 }  // namespace math
0418 
0419 #endif  // DataFormat_Math_choleskyInversion_h