Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2022-02-21 23:14:15

0001 #ifndef RecoTracker_MkFitCore_src_Matriplex_Matriplex_h
0002 #define RecoTracker_MkFitCore_src_Matriplex_Matriplex_h
0003 
0004 #include "MatriplexCommon.h"
0005 
0006 namespace Matriplex {
0007 
0008   //------------------------------------------------------------------------------
0009 
0010   template <typename T, idx_t D1, idx_t D2, idx_t N>
0011   class Matriplex {
0012   public:
0013     typedef T value_type;
0014 
0015     /// return no. of matrix rows
0016     static constexpr int kRows = D1;
0017     /// return no. of matrix columns
0018     static constexpr int kCols = D2;
0019     /// return no of elements: rows*columns
0020     static constexpr int kSize = D1 * D2;
0021     /// size of the whole matriplex
0022     static constexpr int kTotSize = N * kSize;
0023 
0024     T fArray[kTotSize] __attribute__((aligned(64)));
0025 
0026     Matriplex() {}
0027     Matriplex(T v) { setVal(v); }
0028 
0029     idx_t plexSize() const { return N; }
0030 
0031     void setVal(T v) {
0032       for (idx_t i = 0; i < kTotSize; ++i) {
0033         fArray[i] = v;
0034       }
0035     }
0036 
0037     void add(const Matriplex& v) {
0038       for (idx_t i = 0; i < kTotSize; ++i) {
0039         fArray[i] += v.fArray[i];
0040       }
0041     }
0042 
0043     void scale(T scale) {
0044       for (idx_t i = 0; i < kTotSize; ++i) {
0045         fArray[i] *= scale;
0046       }
0047     }
0048 
0049     T operator[](idx_t xx) const { return fArray[xx]; }
0050     T& operator[](idx_t xx) { return fArray[xx]; }
0051 
0052     const T& constAt(idx_t n, idx_t i, idx_t j) const { return fArray[(i * D2 + j) * N + n]; }
0053 
0054     T& At(idx_t n, idx_t i, idx_t j) { return fArray[(i * D2 + j) * N + n]; }
0055 
0056     T& operator()(idx_t n, idx_t i, idx_t j) { return fArray[(i * D2 + j) * N + n]; }
0057     const T& operator()(idx_t n, idx_t i, idx_t j) const { return fArray[(i * D2 + j) * N + n]; }
0058 
0059     Matriplex& operator=(const Matriplex& m) {
0060       memcpy(fArray, m.fArray, sizeof(T) * kTotSize);
0061       return *this;
0062     }
0063 
0064     void copySlot(idx_t n, const Matriplex& m) {
0065       for (idx_t i = n; i < kTotSize; i += N) {
0066         fArray[i] = m.fArray[i];
0067       }
0068     }
0069 
0070     void copyIn(idx_t n, const T* arr) {
0071       for (idx_t i = n; i < kTotSize; i += N) {
0072         fArray[i] = *(arr++);
0073       }
0074     }
0075 
0076     void copyIn(idx_t n, const Matriplex& m, idx_t in) {
0077       for (idx_t i = n; i < kTotSize; i += N, in += N) {
0078         fArray[i] = m[in];
0079       }
0080     }
0081 
0082     void copy(idx_t n, idx_t in) {
0083       for (idx_t i = n; i < kTotSize; i += N, in += N) {
0084         fArray[i] = fArray[in];
0085       }
0086     }
0087 
0088 #if defined(AVX512_INTRINSICS)
0089 
0090     template <typename U>
0091     void slurpIn(const T* arr, __m512i& vi, const U&, const int N_proc = N) {
0092       //_mm512_prefetch_i32gather_ps(vi, arr, 1, _MM_HINT_T0);
0093 
0094       const __m512 src = {0};
0095       const __mmask16 k = N_proc == N ? -1 : (1 << N_proc) - 1;
0096 
0097       for (int i = 0; i < kSize; ++i, ++arr) {
0098         //_mm512_prefetch_i32gather_ps(vi, arr+2, 1, _MM_HINT_NTA);
0099 
0100         __m512 reg = _mm512_mask_i32gather_ps(src, k, vi, arr, sizeof(U));
0101         _mm512_mask_store_ps(&fArray[i * N], k, reg);
0102       }
0103     }
0104 
0105     // Experimental methods, slurpIn() seems to be at least as fast.
0106     // See comments in mkFit/MkFitter.cc MkFitter::addBestHit().
0107     void ChewIn(const char* arr, int off, int vi[N], const char* tmp, __m512i& ui) {
0108       // This is a hack ... we know sizeof(Hit) = 64 = cache line = vector width.
0109 
0110       for (int i = 0; i < N; ++i) {
0111         __m512 reg = _mm512_load_ps(arr + vi[i]);
0112         _mm512_store_ps((void*)(tmp + 64 * i), reg);
0113       }
0114 
0115       for (int i = 0; i < kSize; ++i) {
0116         __m512 reg = _mm512_i32gather_ps(ui, tmp + off + i * sizeof(T), 1);
0117         _mm512_store_ps(&fArray[i * N], reg);
0118       }
0119     }
0120 
0121     void Contaginate(const char* arr, int vi[N], const char* tmp) {
0122       // This is a hack ... we know sizeof(Hit) = 64 = cache line = vector width.
0123 
0124       for (int i = 0; i < N; ++i) {
0125         __m512 reg = _mm512_load_ps(arr + vi[i]);
0126         _mm512_store_ps((void*)(tmp + 64 * i), reg);
0127       }
0128     }
0129 
0130     void Plexify(const char* tmp, __m512i& ui) {
0131       for (int i = 0; i < kSize; ++i) {
0132         __m512 reg = _mm512_i32gather_ps(ui, tmp + i * sizeof(T), 1);
0133         _mm512_store_ps(&fArray[i * N], reg);
0134       }
0135     }
0136 
0137 #elif defined(AVX2_INTRINSICS)
0138 
0139     template <typename U>
0140     void slurpIn(const T* arr, __m256i& vi, const U&, const int N_proc = N) {
0141       // Casts to float* needed to "support" also T=HitOnTrack.
0142       // Note that sizeof(float) == sizeof(HitOnTrack) == 4.
0143 
0144       const __m256 src = {0};
0145 
0146       __m256i k = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
0147       __m256i k_sel = _mm256_set1_epi32(N_proc);
0148       __m256i k_master = _mm256_cmpgt_epi32(k_sel, k);
0149 
0150       k = k_master;
0151       for (int i = 0; i < kSize; ++i, ++arr) {
0152         __m256 reg = _mm256_mask_i32gather_ps(src, (float*)arr, vi, (__m256)k, sizeof(U));
0153         // Restore mask (docs say gather clears it but it doesn't seem to).
0154         k = k_master;
0155         _mm256_maskstore_ps((float*)&fArray[i * N], k, reg);
0156       }
0157     }
0158 
0159 #else
0160 
0161     void slurpIn(const T* arr, int vi[N], const int N_proc = N) {
0162       // Separate N_proc == N case (gains about 7% in fit test).
0163       if (N_proc == N) {
0164         for (int i = 0; i < kSize; ++i) {
0165           for (int j = 0; j < N; ++j) {
0166             fArray[i * N + j] = *(arr + i + vi[j]);
0167           }
0168         }
0169       } else {
0170         for (int i = 0; i < kSize; ++i) {
0171           for (int j = 0; j < N_proc; ++j) {
0172             fArray[i * N + j] = *(arr + i + vi[j]);
0173           }
0174         }
0175       }
0176     }
0177 
0178 #endif
0179 
0180     void copyOut(idx_t n, T* arr) const {
0181       for (idx_t i = n; i < kTotSize; i += N) {
0182         *(arr++) = fArray[i];
0183       }
0184     }
0185   };
0186 
0187   template <typename T, idx_t D1, idx_t D2, idx_t N>
0188   using MPlex = Matriplex<T, D1, D2, N>;
0189 
0190   //==============================================================================
0191   // Multiplications
0192   //==============================================================================
0193 
0194   template <typename T, idx_t D1, idx_t D2, idx_t D3, idx_t N>
0195   void multiplyGeneral(const MPlex<T, D1, D2, N>& A, const MPlex<T, D2, D3, N>& B, MPlex<T, D1, D3, N>& C) {
0196     for (idx_t i = 0; i < D1; ++i) {
0197       for (idx_t j = 0; j < D3; ++j) {
0198         const idx_t ijo = N * (i * D3 + j);
0199 
0200 #pragma omp simd
0201         for (idx_t n = 0; n < N; ++n) {
0202           C.fArray[ijo + n] = 0;
0203         }
0204 
0205         for (idx_t k = 0; k < D2; ++k) {
0206           const idx_t iko = N * (i * D2 + k);
0207           const idx_t kjo = N * (k * D3 + j);
0208 
0209 #pragma omp simd
0210           for (idx_t n = 0; n < N; ++n) {
0211             C.fArray[ijo + n] += A.fArray[iko + n] * B.fArray[kjo + n];
0212           }
0213         }
0214       }
0215     }
0216   }
0217 
0218   //------------------------------------------------------------------------------
0219 
0220   template <typename T, idx_t D, idx_t N>
0221   struct MultiplyCls {
0222     static void multiply(const MPlex<T, D, D, N>& A, const MPlex<T, D, D, N>& B, MPlex<T, D, D, N>& C) {
0223       throw std::runtime_error("general multiplication not supported, well, call multiplyGeneral()");
0224     }
0225   };
0226 
0227   template <typename T, idx_t N>
0228   struct MultiplyCls<T, 3, N> {
0229     static void multiply(const MPlex<T, 3, 3, N>& A, const MPlex<T, 3, 3, N>& B, MPlex<T, 3, 3, N>& C) {
0230       const T* a = A.fArray;
0231       ASSUME_ALIGNED(a, 64);
0232       const T* b = B.fArray;
0233       ASSUME_ALIGNED(b, 64);
0234       T* c = C.fArray;
0235       ASSUME_ALIGNED(c, 64);
0236 
0237 #pragma omp simd
0238       for (idx_t n = 0; n < N; ++n) {
0239         c[0 * N + n] = a[0 * N + n] * b[0 * N + n] + a[1 * N + n] * b[3 * N + n] + a[2 * N + n] * b[6 * N + n];
0240         c[1 * N + n] = a[0 * N + n] * b[1 * N + n] + a[1 * N + n] * b[4 * N + n] + a[2 * N + n] * b[7 * N + n];
0241         c[2 * N + n] = a[0 * N + n] * b[2 * N + n] + a[1 * N + n] * b[5 * N + n] + a[2 * N + n] * b[8 * N + n];
0242         c[3 * N + n] = a[3 * N + n] * b[0 * N + n] + a[4 * N + n] * b[3 * N + n] + a[5 * N + n] * b[6 * N + n];
0243         c[4 * N + n] = a[3 * N + n] * b[1 * N + n] + a[4 * N + n] * b[4 * N + n] + a[5 * N + n] * b[7 * N + n];
0244         c[5 * N + n] = a[3 * N + n] * b[2 * N + n] + a[4 * N + n] * b[5 * N + n] + a[5 * N + n] * b[8 * N + n];
0245         c[6 * N + n] = a[6 * N + n] * b[0 * N + n] + a[7 * N + n] * b[3 * N + n] + a[8 * N + n] * b[6 * N + n];
0246         c[7 * N + n] = a[6 * N + n] * b[1 * N + n] + a[7 * N + n] * b[4 * N + n] + a[8 * N + n] * b[7 * N + n];
0247         c[8 * N + n] = a[6 * N + n] * b[2 * N + n] + a[7 * N + n] * b[5 * N + n] + a[8 * N + n] * b[8 * N + n];
0248       }
0249     }
0250   };
0251 
  // Hand-unrolled C = A * B for N 6x6 matrices.
  // Element k (= 6 * row + col) of matrix n lives at index k * N + n, so
  // c[(6 * r + s) * N + n] = sum_{t=0..5} a[(6 * r + t) * N + n] * b[(6 * t + s) * N + n].
  // Each C element is written while later A/B elements of the same slot are still
  // to be read, so C must not alias A or B.
  template <typename T, idx_t N>
  struct MultiplyCls<T, 6, N> {
    static void multiply(const MPlex<T, 6, 6, N>& A, const MPlex<T, 6, 6, N>& B, MPlex<T, 6, 6, N>& C) {
      const T* a = A.fArray;
      ASSUME_ALIGNED(a, 64);
      const T* b = B.fArray;
      ASSUME_ALIGNED(b, 64);
      T* c = C.fArray;
      ASSUME_ALIGNED(c, 64);
      // One iteration per matrix slot n; the 36 row-by-column dot products follow.
#pragma omp simd
      for (idx_t n = 0; n < N; ++n) {
        c[0 * N + n] = a[0 * N + n] * b[0 * N + n] + a[1 * N + n] * b[6 * N + n] + a[2 * N + n] * b[12 * N + n] +
                       a[3 * N + n] * b[18 * N + n] + a[4 * N + n] * b[24 * N + n] + a[5 * N + n] * b[30 * N + n];
        c[1 * N + n] = a[0 * N + n] * b[1 * N + n] + a[1 * N + n] * b[7 * N + n] + a[2 * N + n] * b[13 * N + n] +
                       a[3 * N + n] * b[19 * N + n] + a[4 * N + n] * b[25 * N + n] + a[5 * N + n] * b[31 * N + n];
        c[2 * N + n] = a[0 * N + n] * b[2 * N + n] + a[1 * N + n] * b[8 * N + n] + a[2 * N + n] * b[14 * N + n] +
                       a[3 * N + n] * b[20 * N + n] + a[4 * N + n] * b[26 * N + n] + a[5 * N + n] * b[32 * N + n];
        c[3 * N + n] = a[0 * N + n] * b[3 * N + n] + a[1 * N + n] * b[9 * N + n] + a[2 * N + n] * b[15 * N + n] +
                       a[3 * N + n] * b[21 * N + n] + a[4 * N + n] * b[27 * N + n] + a[5 * N + n] * b[33 * N + n];
        c[4 * N + n] = a[0 * N + n] * b[4 * N + n] + a[1 * N + n] * b[10 * N + n] + a[2 * N + n] * b[16 * N + n] +
                       a[3 * N + n] * b[22 * N + n] + a[4 * N + n] * b[28 * N + n] + a[5 * N + n] * b[34 * N + n];
        c[5 * N + n] = a[0 * N + n] * b[5 * N + n] + a[1 * N + n] * b[11 * N + n] + a[2 * N + n] * b[17 * N + n] +
                       a[3 * N + n] * b[23 * N + n] + a[4 * N + n] * b[29 * N + n] + a[5 * N + n] * b[35 * N + n];
        c[6 * N + n] = a[6 * N + n] * b[0 * N + n] + a[7 * N + n] * b[6 * N + n] + a[8 * N + n] * b[12 * N + n] +
                       a[9 * N + n] * b[18 * N + n] + a[10 * N + n] * b[24 * N + n] + a[11 * N + n] * b[30 * N + n];
        c[7 * N + n] = a[6 * N + n] * b[1 * N + n] + a[7 * N + n] * b[7 * N + n] + a[8 * N + n] * b[13 * N + n] +
                       a[9 * N + n] * b[19 * N + n] + a[10 * N + n] * b[25 * N + n] + a[11 * N + n] * b[31 * N + n];
        c[8 * N + n] = a[6 * N + n] * b[2 * N + n] + a[7 * N + n] * b[8 * N + n] + a[8 * N + n] * b[14 * N + n] +
                       a[9 * N + n] * b[20 * N + n] + a[10 * N + n] * b[26 * N + n] + a[11 * N + n] * b[32 * N + n];
        c[9 * N + n] = a[6 * N + n] * b[3 * N + n] + a[7 * N + n] * b[9 * N + n] + a[8 * N + n] * b[15 * N + n] +
                       a[9 * N + n] * b[21 * N + n] + a[10 * N + n] * b[27 * N + n] + a[11 * N + n] * b[33 * N + n];
        c[10 * N + n] = a[6 * N + n] * b[4 * N + n] + a[7 * N + n] * b[10 * N + n] + a[8 * N + n] * b[16 * N + n] +
                        a[9 * N + n] * b[22 * N + n] + a[10 * N + n] * b[28 * N + n] + a[11 * N + n] * b[34 * N + n];
        c[11 * N + n] = a[6 * N + n] * b[5 * N + n] + a[7 * N + n] * b[11 * N + n] + a[8 * N + n] * b[17 * N + n] +
                        a[9 * N + n] * b[23 * N + n] + a[10 * N + n] * b[29 * N + n] + a[11 * N + n] * b[35 * N + n];
        c[12 * N + n] = a[12 * N + n] * b[0 * N + n] + a[13 * N + n] * b[6 * N + n] + a[14 * N + n] * b[12 * N + n] +
                        a[15 * N + n] * b[18 * N + n] + a[16 * N + n] * b[24 * N + n] + a[17 * N + n] * b[30 * N + n];
        c[13 * N + n] = a[12 * N + n] * b[1 * N + n] + a[13 * N + n] * b[7 * N + n] + a[14 * N + n] * b[13 * N + n] +
                        a[15 * N + n] * b[19 * N + n] + a[16 * N + n] * b[25 * N + n] + a[17 * N + n] * b[31 * N + n];
        c[14 * N + n] = a[12 * N + n] * b[2 * N + n] + a[13 * N + n] * b[8 * N + n] + a[14 * N + n] * b[14 * N + n] +
                        a[15 * N + n] * b[20 * N + n] + a[16 * N + n] * b[26 * N + n] + a[17 * N + n] * b[32 * N + n];
        c[15 * N + n] = a[12 * N + n] * b[3 * N + n] + a[13 * N + n] * b[9 * N + n] + a[14 * N + n] * b[15 * N + n] +
                        a[15 * N + n] * b[21 * N + n] + a[16 * N + n] * b[27 * N + n] + a[17 * N + n] * b[33 * N + n];
        c[16 * N + n] = a[12 * N + n] * b[4 * N + n] + a[13 * N + n] * b[10 * N + n] + a[14 * N + n] * b[16 * N + n] +
                        a[15 * N + n] * b[22 * N + n] + a[16 * N + n] * b[28 * N + n] + a[17 * N + n] * b[34 * N + n];
        c[17 * N + n] = a[12 * N + n] * b[5 * N + n] + a[13 * N + n] * b[11 * N + n] + a[14 * N + n] * b[17 * N + n] +
                        a[15 * N + n] * b[23 * N + n] + a[16 * N + n] * b[29 * N + n] + a[17 * N + n] * b[35 * N + n];
        c[18 * N + n] = a[18 * N + n] * b[0 * N + n] + a[19 * N + n] * b[6 * N + n] + a[20 * N + n] * b[12 * N + n] +
                        a[21 * N + n] * b[18 * N + n] + a[22 * N + n] * b[24 * N + n] + a[23 * N + n] * b[30 * N + n];
        c[19 * N + n] = a[18 * N + n] * b[1 * N + n] + a[19 * N + n] * b[7 * N + n] + a[20 * N + n] * b[13 * N + n] +
                        a[21 * N + n] * b[19 * N + n] + a[22 * N + n] * b[25 * N + n] + a[23 * N + n] * b[31 * N + n];
        c[20 * N + n] = a[18 * N + n] * b[2 * N + n] + a[19 * N + n] * b[8 * N + n] + a[20 * N + n] * b[14 * N + n] +
                        a[21 * N + n] * b[20 * N + n] + a[22 * N + n] * b[26 * N + n] + a[23 * N + n] * b[32 * N + n];
        c[21 * N + n] = a[18 * N + n] * b[3 * N + n] + a[19 * N + n] * b[9 * N + n] + a[20 * N + n] * b[15 * N + n] +
                        a[21 * N + n] * b[21 * N + n] + a[22 * N + n] * b[27 * N + n] + a[23 * N + n] * b[33 * N + n];
        c[22 * N + n] = a[18 * N + n] * b[4 * N + n] + a[19 * N + n] * b[10 * N + n] + a[20 * N + n] * b[16 * N + n] +
                        a[21 * N + n] * b[22 * N + n] + a[22 * N + n] * b[28 * N + n] + a[23 * N + n] * b[34 * N + n];
        c[23 * N + n] = a[18 * N + n] * b[5 * N + n] + a[19 * N + n] * b[11 * N + n] + a[20 * N + n] * b[17 * N + n] +
                        a[21 * N + n] * b[23 * N + n] + a[22 * N + n] * b[29 * N + n] + a[23 * N + n] * b[35 * N + n];
        c[24 * N + n] = a[24 * N + n] * b[0 * N + n] + a[25 * N + n] * b[6 * N + n] + a[26 * N + n] * b[12 * N + n] +
                        a[27 * N + n] * b[18 * N + n] + a[28 * N + n] * b[24 * N + n] + a[29 * N + n] * b[30 * N + n];
        c[25 * N + n] = a[24 * N + n] * b[1 * N + n] + a[25 * N + n] * b[7 * N + n] + a[26 * N + n] * b[13 * N + n] +
                        a[27 * N + n] * b[19 * N + n] + a[28 * N + n] * b[25 * N + n] + a[29 * N + n] * b[31 * N + n];
        c[26 * N + n] = a[24 * N + n] * b[2 * N + n] + a[25 * N + n] * b[8 * N + n] + a[26 * N + n] * b[14 * N + n] +
                        a[27 * N + n] * b[20 * N + n] + a[28 * N + n] * b[26 * N + n] + a[29 * N + n] * b[32 * N + n];
        c[27 * N + n] = a[24 * N + n] * b[3 * N + n] + a[25 * N + n] * b[9 * N + n] + a[26 * N + n] * b[15 * N + n] +
                        a[27 * N + n] * b[21 * N + n] + a[28 * N + n] * b[27 * N + n] + a[29 * N + n] * b[33 * N + n];
        c[28 * N + n] = a[24 * N + n] * b[4 * N + n] + a[25 * N + n] * b[10 * N + n] + a[26 * N + n] * b[16 * N + n] +
                        a[27 * N + n] * b[22 * N + n] + a[28 * N + n] * b[28 * N + n] + a[29 * N + n] * b[34 * N + n];
        c[29 * N + n] = a[24 * N + n] * b[5 * N + n] + a[25 * N + n] * b[11 * N + n] + a[26 * N + n] * b[17 * N + n] +
                        a[27 * N + n] * b[23 * N + n] + a[28 * N + n] * b[29 * N + n] + a[29 * N + n] * b[35 * N + n];
        c[30 * N + n] = a[30 * N + n] * b[0 * N + n] + a[31 * N + n] * b[6 * N + n] + a[32 * N + n] * b[12 * N + n] +
                        a[33 * N + n] * b[18 * N + n] + a[34 * N + n] * b[24 * N + n] + a[35 * N + n] * b[30 * N + n];
        c[31 * N + n] = a[30 * N + n] * b[1 * N + n] + a[31 * N + n] * b[7 * N + n] + a[32 * N + n] * b[13 * N + n] +
                        a[33 * N + n] * b[19 * N + n] + a[34 * N + n] * b[25 * N + n] + a[35 * N + n] * b[31 * N + n];
        c[32 * N + n] = a[30 * N + n] * b[2 * N + n] + a[31 * N + n] * b[8 * N + n] + a[32 * N + n] * b[14 * N + n] +
                        a[33 * N + n] * b[20 * N + n] + a[34 * N + n] * b[26 * N + n] + a[35 * N + n] * b[32 * N + n];
        c[33 * N + n] = a[30 * N + n] * b[3 * N + n] + a[31 * N + n] * b[9 * N + n] + a[32 * N + n] * b[15 * N + n] +
                        a[33 * N + n] * b[21 * N + n] + a[34 * N + n] * b[27 * N + n] + a[35 * N + n] * b[33 * N + n];
        c[34 * N + n] = a[30 * N + n] * b[4 * N + n] + a[31 * N + n] * b[10 * N + n] + a[32 * N + n] * b[16 * N + n] +
                        a[33 * N + n] * b[22 * N + n] + a[34 * N + n] * b[28 * N + n] + a[35 * N + n] * b[34 * N + n];
        c[35 * N + n] = a[30 * N + n] * b[5 * N + n] + a[31 * N + n] * b[11 * N + n] + a[32 * N + n] * b[17 * N + n] +
                        a[33 * N + n] * b[23 * N + n] + a[34 * N + n] * b[29 * N + n] + a[35 * N + n] * b[35 * N + n];
      }
    }
  };
0338 
0339   template <typename T, idx_t D, idx_t N>
0340   void multiply(const MPlex<T, D, D, N>& A, const MPlex<T, D, D, N>& B, MPlex<T, D, D, N>& C) {
0341 #ifdef DEBUG
0342     printf("Multipl %d %d\n", D, N);
0343 #endif
0344 
0345     MultiplyCls<T, D, N>::multiply(A, B, C);
0346   }
0347 
0348   //==============================================================================
0349   // Cramer inversion
0350   //==============================================================================
0351 
0352   template <typename T, idx_t D, idx_t N>
0353   struct CramerInverter {
0354     static void invert(MPlex<T, D, D, N>& A, double* determ = nullptr) {
0355       throw std::runtime_error("general cramer inversion not supported");
0356     }
0357   };
0358 
0359   template <typename T, idx_t N>
0360   struct CramerInverter<T, 2, N> {
0361     static void invert(MPlex<T, 2, 2, N>& A, double* determ = nullptr) {
0362       typedef T TT;
0363 
0364       T* a = A.fArray;
0365       ASSUME_ALIGNED(a, 64);
0366 
0367 #pragma omp simd
0368       for (idx_t n = 0; n < N; ++n) {
0369         // Force determinant calculation in double precision.
0370         const double det = (double)a[0 * N + n] * a[3 * N + n] - (double)a[2 * N + n] * a[1 * N + n];
0371         if (determ)
0372           determ[n] = det;
0373 
0374         const TT s = TT(1) / det;
0375         const TT tmp = s * a[3 * N + n];
0376         a[1 * N + n] *= -s;
0377         a[2 * N + n] *= -s;
0378         a[3 * N + n] = s * a[0 * N + n];
0379         a[0 * N + n] = tmp;
0380       }
0381     }
0382   };
0383 
0384   template <typename T, idx_t N>
0385   struct CramerInverter<T, 3, N> {
0386     static void invert(MPlex<T, 3, 3, N>& A, double* determ = nullptr) {
0387       typedef T TT;
0388 
0389       T* a = A.fArray;
0390       ASSUME_ALIGNED(a, 64);
0391 
0392 #pragma omp simd
0393       for (idx_t n = 0; n < N; ++n) {
0394         const TT c00 = a[4 * N + n] * a[8 * N + n] - a[5 * N + n] * a[7 * N + n];
0395         const TT c01 = a[5 * N + n] * a[6 * N + n] - a[3 * N + n] * a[8 * N + n];
0396         const TT c02 = a[3 * N + n] * a[7 * N + n] - a[4 * N + n] * a[6 * N + n];
0397         const TT c10 = a[7 * N + n] * a[2 * N + n] - a[8 * N + n] * a[1 * N + n];
0398         const TT c11 = a[8 * N + n] * a[0 * N + n] - a[6 * N + n] * a[2 * N + n];
0399         const TT c12 = a[6 * N + n] * a[1 * N + n] - a[7 * N + n] * a[0 * N + n];
0400         const TT c20 = a[1 * N + n] * a[5 * N + n] - a[2 * N + n] * a[4 * N + n];
0401         const TT c21 = a[2 * N + n] * a[3 * N + n] - a[0 * N + n] * a[5 * N + n];
0402         const TT c22 = a[0 * N + n] * a[4 * N + n] - a[1 * N + n] * a[3 * N + n];
0403 
0404         // Force determinant calculation in double precision.
0405         const double det = (double)a[0 * N + n] * c00 + (double)a[1 * N + n] * c01 + (double)a[2 * N + n] * c02;
0406         if (determ)
0407           determ[n] = det;
0408 
0409         const TT s = TT(1) / det;
0410         a[0 * N + n] = s * c00;
0411         a[1 * N + n] = s * c10;
0412         a[2 * N + n] = s * c20;
0413         a[3 * N + n] = s * c01;
0414         a[4 * N + n] = s * c11;
0415         a[5 * N + n] = s * c21;
0416         a[6 * N + n] = s * c02;
0417         a[7 * N + n] = s * c12;
0418         a[8 * N + n] = s * c22;
0419       }
0420     }
0421   };
0422 
0423   template <typename T, idx_t D, idx_t N>
0424   void invertCramer(MPlex<T, D, D, N>& A, double* determ = nullptr) {
0425     CramerInverter<T, D, N>::invert(A, determ);
0426   }
0427 
0428   //==============================================================================
0429   // Cholesky inversion
0430   //==============================================================================
0431 
0432   template <typename T, idx_t D, idx_t N>
0433   struct CholeskyInverter {
0434     static void invert(MPlex<T, D, D, N>& A) { throw std::runtime_error("general cholesky inversion not supported"); }
0435   };
0436 
  template <typename T, idx_t N>
  struct CholeskyInverter<T, 3, N> {
    // Note: this only works on symmetric matrices.
    // Optimized version for positive definite matrices, no checks.
    //
    // In-place inverse of each 3x3 matrix in the plex: factor A = L * L^T,
    // invert the lower-triangular factor L, then form A^-1 = (L^-1)^T * L^-1.
    // Only the lower triangle of A (flat elements 0, 3, 4, 6, 7, 8) is read;
    // both triangles of the result are written, so the output is symmetric.
    static void invert(MPlex<T, 3, 3, N>& A) {
      typedef T TT;

      T* a = A.fArray;
      ASSUME_ALIGNED(a, 64);

#pragma omp simd
      for (idx_t n = 0; n < N; ++n) {
        // Factorization: l0, l2, l5 hold the RECIPROCALS of L's diagonal
        // (1/L00, 1/L11, 1/L22); l1, l3, l4 hold L10, L20, L21.
        TT l0 = std::sqrt(T(1) / a[0 * N + n]);
        TT l1 = a[3 * N + n] * l0;
        TT l2 = a[4 * N + n] - l1 * l1;
        l2 = std::sqrt(T(1) / l2);
        TT l3 = a[6 * N + n] * l0;
        TT l4 = (a[7 * N + n] - l1 * l3) * l2;
        TT l5 = a[8 * N + n] - (l3 * l3 + l4 * l4);
        l5 = std::sqrt(T(1) / l5);

        // decomposition done

        // Replace the off-diagonal entries with those of L^-1 (its diagonal is
        // already l0, l2, l5). l3 is updated first because it needs the old l1, l4.
        l3 = (l1 * l4 * l2 - l3) * l0 * l5;
        l1 = -l1 * l0 * l2;
        l4 = -l4 * l2 * l5;

        // A^-1 = (L^-1)^T * L^-1, written to both triangles.
        a[0 * N + n] = l3 * l3 + l1 * l1 + l0 * l0;
        a[1 * N + n] = a[3 * N + n] = l3 * l4 + l1 * l2;
        a[4 * N + n] = l4 * l4 + l2 * l2;
        a[2 * N + n] = a[6 * N + n] = l3 * l5;
        a[5 * N + n] = a[7 * N + n] = l4 * l5;
        a[8 * N + n] = l5 * l5;

        // No failure handling: nothing is zeroed on error. If A is not positive
        // definite, a pivot passed to 1/x and sqrt above is non-positive, and the
        // affected outputs become NaN or inf instead.
      }
    }
  };
0476 
0477   template <typename T, idx_t D, idx_t N>
0478   void invertCholesky(MPlex<T, D, D, N>& A) {
0479     CholeskyInverter<T, D, N>::invert(A);
0480   }
0481 
0482 }  // namespace Matriplex
0483 
0484 #endif