Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2023-12-19 04:05:13

0001 #ifndef RecoTracker_MkFitCore_src_Matriplex_Matriplex_h
0002 #define RecoTracker_MkFitCore_src_Matriplex_Matriplex_h
0003 
0004 #include "MatriplexCommon.h"
0005 
0006 namespace Matriplex {
0007 
0008   //------------------------------------------------------------------------------
0009 
0010   template <typename T, idx_t D1, idx_t D2, idx_t N>
0011   class __attribute__((aligned(MPLEX_ALIGN))) Matriplex {
0012   public:
0013     typedef T value_type;
0014 
0015     /// return no. of matrix rows
0016     static constexpr int kRows = D1;
0017     /// return no. of matrix columns
0018     static constexpr int kCols = D2;
0019     /// return no of elements: rows*columns
0020     static constexpr int kSize = D1 * D2;
0021     /// size of the whole matriplex
0022     static constexpr int kTotSize = N * kSize;
0023 
0024     T fArray[kTotSize];
0025 
0026     Matriplex() {}
0027     Matriplex(T v) { setVal(v); }
0028 
0029     idx_t plexSize() const { return N; }
0030 
0031     void setVal(T v) {
0032       for (idx_t i = 0; i < kTotSize; ++i) {
0033         fArray[i] = v;
0034       }
0035     }
0036 
0037     void add(const Matriplex& v) {
0038       for (idx_t i = 0; i < kTotSize; ++i) {
0039         fArray[i] += v.fArray[i];
0040       }
0041     }
0042 
0043     void scale(T scale) {
0044       for (idx_t i = 0; i < kTotSize; ++i) {
0045         fArray[i] *= scale;
0046       }
0047     }
0048 
0049     T operator[](idx_t xx) const { return fArray[xx]; }
0050     T& operator[](idx_t xx) { return fArray[xx]; }
0051 
0052     const T& constAt(idx_t n, idx_t i, idx_t j) const { return fArray[(i * D2 + j) * N + n]; }
0053 
0054     T& At(idx_t n, idx_t i, idx_t j) { return fArray[(i * D2 + j) * N + n]; }
0055 
0056     T& operator()(idx_t n, idx_t i, idx_t j) { return fArray[(i * D2 + j) * N + n]; }
0057     const T& operator()(idx_t n, idx_t i, idx_t j) const { return fArray[(i * D2 + j) * N + n]; }
0058 
0059     Matriplex& operator=(T t) {
0060       for (idx_t i = 0; i < kTotSize; ++i)
0061         fArray[i] = t;
0062       return *this;
0063     }
0064 
0065     Matriplex& operator+=(T t) {
0066       for (idx_t i = 0; i < kTotSize; ++i)
0067         fArray[i] += t;
0068       return *this;
0069     }
0070 
0071     Matriplex& operator-=(T t) {
0072       for (idx_t i = 0; i < kTotSize; ++i)
0073         fArray[i] -= t;
0074       return *this;
0075     }
0076 
0077     Matriplex& operator*=(T t) {
0078       for (idx_t i = 0; i < kTotSize; ++i)
0079         fArray[i] *= t;
0080       return *this;
0081     }
0082 
0083     Matriplex& operator/=(T t) {
0084       for (idx_t i = 0; i < kTotSize; ++i)
0085         fArray[i] /= t;
0086       return *this;
0087     }
0088 
0089     Matriplex& operator+=(const Matriplex& a) {
0090       for (idx_t i = 0; i < kTotSize; ++i)
0091         fArray[i] += a.fArray[i];
0092       return *this;
0093     }
0094 
0095     Matriplex& operator-=(const Matriplex& a) {
0096       for (idx_t i = 0; i < kTotSize; ++i)
0097         fArray[i] -= a.fArray[i];
0098       return *this;
0099     }
0100 
0101     Matriplex& operator*=(const Matriplex& a) {
0102       for (idx_t i = 0; i < kTotSize; ++i)
0103         fArray[i] *= a.fArray[i];
0104       return *this;
0105     }
0106 
0107     Matriplex& operator/=(const Matriplex& a) {
0108       for (idx_t i = 0; i < kTotSize; ++i)
0109         fArray[i] /= a.fArray[i];
0110       return *this;
0111     }
0112 
0113     Matriplex operator-() {
0114       Matriplex t;
0115       for (idx_t i = 0; i < kTotSize; ++i)
0116         t.fArray[i] = -fArray[i];
0117       return t;
0118     }
0119 
0120     Matriplex& abs(const Matriplex& a) {
0121       for (idx_t i = 0; i < kTotSize; ++i)
0122         fArray[i] = std::abs(a.fArray[i]);
0123       return *this;
0124     }
0125     Matriplex& abs() {
0126       for (idx_t i = 0; i < kTotSize; ++i)
0127         fArray[i] = std::abs(fArray[i]);
0128       return *this;
0129     }
0130 
0131     Matriplex& sqrt(const Matriplex& a) {
0132       for (idx_t i = 0; i < kTotSize; ++i)
0133         fArray[i] = std::sqrt(a.fArray[i]);
0134       return *this;
0135     }
0136     Matriplex& sqrt() {
0137       for (idx_t i = 0; i < kTotSize; ++i)
0138         fArray[i] = std::sqrt(fArray[i]);
0139       return *this;
0140     }
0141 
0142     Matriplex& sqr(const Matriplex& a) {
0143       for (idx_t i = 0; i < kTotSize; ++i)
0144         fArray[i] = a.fArray[i] * a.fArray[i];
0145       return *this;
0146     }
0147     Matriplex& sqr() {
0148       for (idx_t i = 0; i < kTotSize; ++i)
0149         fArray[i] = fArray[i] * fArray[i];
0150       return *this;
0151     }
0152 
0153     Matriplex& hypot(const Matriplex& a, const Matriplex& b) {
0154       for (idx_t i = 0; i < kTotSize; ++i) {
0155         fArray[i] = a.fArray[i] * a.fArray[i] + b.fArray[i] * b.fArray[i];
0156       }
0157       return sqrt();
0158     }
0159 
0160     Matriplex& sin(const Matriplex& a) {
0161       for (idx_t i = 0; i < kTotSize; ++i)
0162         fArray[i] = std::sin(a.fArray[i]);
0163       return *this;
0164     }
0165     Matriplex& sin() {
0166       for (idx_t i = 0; i < kTotSize; ++i)
0167         fArray[i] = std::sin(fArray[i]);
0168       return *this;
0169     }
0170 
0171     Matriplex& cos(const Matriplex& a) {
0172       for (idx_t i = 0; i < kTotSize; ++i)
0173         fArray[i] = std::cos(a.fArray[i]);
0174       return *this;
0175     }
0176     Matriplex& cos() {
0177       for (idx_t i = 0; i < kTotSize; ++i)
0178         fArray[i] = std::cos(fArray[i]);
0179       return *this;
0180     }
0181 
0182     Matriplex& tan(const Matriplex& a) {
0183       for (idx_t i = 0; i < kTotSize; ++i)
0184         fArray[i] = std::tan(a.fArray[i]);
0185       return *this;
0186     }
0187     Matriplex& tan() {
0188       for (idx_t i = 0; i < kTotSize; ++i)
0189         fArray[i] = std::tan(fArray[i]);
0190       return *this;
0191     }
0192 
0193     //---------------------------------------------------------
0194 
0195     void copySlot(idx_t n, const Matriplex& m) {
0196       for (idx_t i = n; i < kTotSize; i += N) {
0197         fArray[i] = m.fArray[i];
0198       }
0199     }
0200 
0201     void copyIn(idx_t n, const T* arr) {
0202       for (idx_t i = n; i < kTotSize; i += N) {
0203         fArray[i] = *(arr++);
0204       }
0205     }
0206 
0207     void copyIn(idx_t n, const Matriplex& m, idx_t in) {
0208       for (idx_t i = n; i < kTotSize; i += N, in += N) {
0209         fArray[i] = m[in];
0210       }
0211     }
0212 
0213     void copy(idx_t n, idx_t in) {
0214       for (idx_t i = n; i < kTotSize; i += N, in += N) {
0215         fArray[i] = fArray[in];
0216       }
0217     }
0218 
0219 #if defined(AVX512_INTRINSICS)
0220 
0221     template <typename U>
0222     void slurpIn(const T* arr, __m512i& vi, const U&, const int N_proc = N) {
0223       //_mm512_prefetch_i32gather_ps(vi, arr, 1, _MM_HINT_T0);
0224 
0225       const __m512 src = {0};
0226       const __mmask16 k = N_proc == N ? -1 : (1 << N_proc) - 1;
0227 
0228       for (int i = 0; i < kSize; ++i, ++arr) {
0229         //_mm512_prefetch_i32gather_ps(vi, arr+2, 1, _MM_HINT_NTA);
0230 
0231         __m512 reg = _mm512_mask_i32gather_ps(src, k, vi, arr, sizeof(U));
0232         _mm512_mask_store_ps(&fArray[i * N], k, reg);
0233       }
0234     }
0235 
0236     // Experimental methods, slurpIn() seems to be at least as fast.
0237     // See comments in mkFit/MkFitter.cc MkFitter::addBestHit().
0238     void ChewIn(const char* arr, int off, int vi[N], const char* tmp, __m512i& ui) {
0239       // This is a hack ... we know sizeof(Hit) = 64 = cache line = vector width.
0240 
0241       for (int i = 0; i < N; ++i) {
0242         __m512 reg = _mm512_load_ps(arr + vi[i]);
0243         _mm512_store_ps((void*)(tmp + 64 * i), reg);
0244       }
0245 
0246       for (int i = 0; i < kSize; ++i) {
0247         __m512 reg = _mm512_i32gather_ps(ui, tmp + off + i * sizeof(T), 1);
0248         _mm512_store_ps(&fArray[i * N], reg);
0249       }
0250     }
0251 
0252     void Contaginate(const char* arr, int vi[N], const char* tmp) {
0253       // This is a hack ... we know sizeof(Hit) = 64 = cache line = vector width.
0254 
0255       for (int i = 0; i < N; ++i) {
0256         __m512 reg = _mm512_load_ps(arr + vi[i]);
0257         _mm512_store_ps((void*)(tmp + 64 * i), reg);
0258       }
0259     }
0260 
0261     void Plexify(const char* tmp, __m512i& ui) {
0262       for (int i = 0; i < kSize; ++i) {
0263         __m512 reg = _mm512_i32gather_ps(ui, tmp + i * sizeof(T), 1);
0264         _mm512_store_ps(&fArray[i * N], reg);
0265       }
0266     }
0267 
0268 #elif defined(AVX2_INTRINSICS)
0269 
0270     template <typename U>
0271     void slurpIn(const T* arr, __m256i& vi, const U&, const int N_proc = N) {
0272       // Casts to float* needed to "support" also T=HitOnTrack.
0273       // Note that sizeof(float) == sizeof(HitOnTrack) == 4.
0274 
0275       const __m256 src = {0};
0276 
0277       __m256i k = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
0278       __m256i k_sel = _mm256_set1_epi32(N_proc);
0279       __m256i k_master = _mm256_cmpgt_epi32(k_sel, k);
0280 
0281       k = k_master;
0282       for (int i = 0; i < kSize; ++i, ++arr) {
0283         __m256 reg = _mm256_mask_i32gather_ps(src, (float*)arr, vi, (__m256)k, sizeof(U));
0284         // Restore mask (docs say gather clears it but it doesn't seem to).
0285         k = k_master;
0286         _mm256_maskstore_ps((float*)&fArray[i * N], k, reg);
0287       }
0288     }
0289 
0290 #else
0291 
0292     void slurpIn(const T* arr, int vi[N], const int N_proc = N) {
0293       // Separate N_proc == N case (gains about 7% in fit test).
0294       if (N_proc == N) {
0295         for (int i = 0; i < kSize; ++i) {
0296           for (int j = 0; j < N; ++j) {
0297             fArray[i * N + j] = *(arr + i + vi[j]);
0298           }
0299         }
0300       } else {
0301         for (int i = 0; i < kSize; ++i) {
0302           for (int j = 0; j < N_proc; ++j) {
0303             fArray[i * N + j] = *(arr + i + vi[j]);
0304           }
0305         }
0306       }
0307     }
0308 
0309 #endif
0310 
0311     void copyOut(idx_t n, T* arr) const {
0312       for (idx_t i = n; i < kTotSize; i += N) {
0313         *(arr++) = fArray[i];
0314       }
0315     }
0316 
0317     Matriplex<T, 1, 1, N> ReduceFixedIJ(idx_t i, idx_t j) const {
0318       Matriplex<T, 1, 1, N> t;
0319       for (idx_t n = 0; n < N; ++n) {
0320         t[n] = constAt(n, i, j);
0321       }
0322       return t;
0323     }
0324   };
0325 
0326   template <typename T, idx_t D1, idx_t D2, idx_t N>
0327   using MPlex = Matriplex<T, D1, D2, N>;
0328 
0329   //==============================================================================
0330   // Operators
0331   //==============================================================================
0332 
0333   template <typename T, idx_t D1, idx_t D2, idx_t N>
0334   MPlex<T, D1, D2, N> operator+(const MPlex<T, D1, D2, N>& a, const MPlex<T, D1, D2, N>& b) {
0335     MPlex<T, D1, D2, N> t = a;
0336     t += b;
0337     return t;
0338   }
0339 
0340   template <typename T, idx_t D1, idx_t D2, idx_t N>
0341   MPlex<T, D1, D2, N> operator-(const MPlex<T, D1, D2, N>& a, const MPlex<T, D1, D2, N>& b) {
0342     MPlex<T, D1, D2, N> t = a;
0343     t -= b;
0344     return t;
0345   }
0346 
0347   template <typename T, idx_t D1, idx_t D2, idx_t N>
0348   MPlex<T, D1, D2, N> operator*(const MPlex<T, D1, D2, N>& a, const MPlex<T, D1, D2, N>& b) {
0349     MPlex<T, D1, D2, N> t = a;
0350     t *= b;
0351     return t;
0352   }
0353 
0354   template <typename T, idx_t D1, idx_t D2, idx_t N>
0355   MPlex<T, D1, D2, N> operator/(const MPlex<T, D1, D2, N>& a, const MPlex<T, D1, D2, N>& b) {
0356     MPlex<T, D1, D2, N> t = a;
0357     t /= b;
0358     return t;
0359   }
0360 
0361   template <typename T, idx_t D1, idx_t D2, idx_t N>
0362   MPlex<T, D1, D2, N> operator+(const MPlex<T, D1, D2, N>& a, T b) {
0363     MPlex<T, D1, D2, N> t = a;
0364     t += b;
0365     return t;
0366   }
0367 
0368   template <typename T, idx_t D1, idx_t D2, idx_t N>
0369   MPlex<T, D1, D2, N> operator-(const MPlex<T, D1, D2, N>& a, T b) {
0370     MPlex<T, D1, D2, N> t = a;
0371     t -= b;
0372     return t;
0373   }
0374 
0375   template <typename T, idx_t D1, idx_t D2, idx_t N>
0376   MPlex<T, D1, D2, N> operator*(const MPlex<T, D1, D2, N>& a, T b) {
0377     MPlex<T, D1, D2, N> t = a;
0378     t *= b;
0379     return t;
0380   }
0381 
0382   template <typename T, idx_t D1, idx_t D2, idx_t N>
0383   MPlex<T, D1, D2, N> operator/(const MPlex<T, D1, D2, N>& a, T b) {
0384     MPlex<T, D1, D2, N> t = a;
0385     t /= b;
0386     return t;
0387   }
0388 
0389   template <typename T, idx_t D1, idx_t D2, idx_t N>
0390   MPlex<T, D1, D2, N> operator+(T a, const MPlex<T, D1, D2, N>& b) {
0391     MPlex<T, D1, D2, N> t = a;
0392     t += b;
0393     return t;
0394   }
0395 
0396   template <typename T, idx_t D1, idx_t D2, idx_t N>
0397   MPlex<T, D1, D2, N> operator-(T a, const MPlex<T, D1, D2, N>& b) {
0398     MPlex<T, D1, D2, N> t = a;
0399     t -= b;
0400     return t;
0401   }
0402 
0403   template <typename T, idx_t D1, idx_t D2, idx_t N>
0404   MPlex<T, D1, D2, N> operator*(T a, const MPlex<T, D1, D2, N>& b) {
0405     MPlex<T, D1, D2, N> t = a;
0406     t *= b;
0407     return t;
0408   }
0409 
0410   template <typename T, idx_t D1, idx_t D2, idx_t N>
0411   MPlex<T, D1, D2, N> operator/(T a, const MPlex<T, D1, D2, N>& b) {
0412     MPlex<T, D1, D2, N> t = a;
0413     t /= b;
0414     return t;
0415   }
0416 
0417   template <typename T, idx_t D1, idx_t D2, idx_t N>
0418   MPlex<T, D1, D2, N> abs(const MPlex<T, D1, D2, N>& a) {
0419     MPlex<T, D1, D2, N> t;
0420     return t.abs(a);
0421   }
0422 
0423   template <typename T, idx_t D1, idx_t D2, idx_t N>
0424   MPlex<T, D1, D2, N> sqrt(const MPlex<T, D1, D2, N>& a) {
0425     MPlex<T, D1, D2, N> t;
0426     return t.sqrt(a);
0427   }
0428 
0429   template <typename T, idx_t D1, idx_t D2, idx_t N>
0430   MPlex<T, D1, D2, N> sqr(const MPlex<T, D1, D2, N>& a) {
0431     MPlex<T, D1, D2, N> t;
0432     return t.sqr(a);
0433   }
0434 
0435   template <typename T, idx_t D1, idx_t D2, idx_t N>
0436   MPlex<T, D1, D2, N> hypot(const MPlex<T, D1, D2, N>& a, const MPlex<T, D1, D2, N>& b) {
0437     MPlex<T, D1, D2, N> t;
0438     return t.hypot(a, b);
0439   }
0440 
0441   template <typename T, idx_t D1, idx_t D2, idx_t N>
0442   MPlex<T, D1, D2, N> sin(const MPlex<T, D1, D2, N>& a) {
0443     MPlex<T, D1, D2, N> t;
0444     return t.sin(a);
0445   }
0446 
0447   template <typename T, idx_t D1, idx_t D2, idx_t N>
0448   MPlex<T, D1, D2, N> cos(const MPlex<T, D1, D2, N>& a) {
0449     MPlex<T, D1, D2, N> t;
0450     return t.cos(a);
0451   }
0452 
0453   template <typename T, idx_t D1, idx_t D2, idx_t N>
0454   void sincos(const MPlex<T, D1, D2, N>& a, MPlex<T, D1, D2, N>& s, MPlex<T, D1, D2, N>& c) {
0455     for (idx_t i = 0; i < a.kTotSize; ++i) {
0456       s.fArray[i] = std::sin(a.fArray[i]);
0457       c.fArray[i] = std::cos(a.fArray[i]);
0458     }
0459   }
0460 
0461   template <typename T, idx_t D1, idx_t D2, idx_t N>
0462   MPlex<T, D1, D2, N> tan(const MPlex<T, D1, D2, N>& a) {
0463     MPlex<T, D1, D2, N> t;
0464     return t.tan(a);
0465   }
0466 
0467   template <typename T, idx_t D1, idx_t D2, idx_t N>
0468   void min_max(const MPlex<T, D1, D2, N>& a,
0469                const MPlex<T, D1, D2, N>& b,
0470                MPlex<T, D1, D2, N>& min,
0471                MPlex<T, D1, D2, N>& max) {
0472     for (idx_t i = 0; i < a.kTotSize; ++i) {
0473       min.fArray[i] = std::min(a.fArray[i], b.fArray[i]);
0474       max.fArray[i] = std::max(a.fArray[i], b.fArray[i]);
0475     }
0476   }
0477 
0478   template <typename T, idx_t D1, idx_t D2, idx_t N>
0479   MPlex<T, D1, D2, N> min(const MPlex<T, D1, D2, N>& a, const MPlex<T, D1, D2, N>& b) {
0480     MPlex<T, D1, D2, N> t;
0481     for (idx_t i = 0; i < a.kTotSize; ++i) {
0482       t.fArray[i] = std::min(a.fArray[i], b.fArray[i]);
0483     }
0484     return t;
0485   }
0486 
0487   template <typename T, idx_t D1, idx_t D2, idx_t N>
0488   MPlex<T, D1, D2, N> max(const MPlex<T, D1, D2, N>& a, const MPlex<T, D1, D2, N>& b) {
0489     MPlex<T, D1, D2, N> t;
0490     for (idx_t i = 0; i < a.kTotSize; ++i) {
0491       t.fArray[i] = std::max(a.fArray[i], b.fArray[i]);
0492     }
0493     return t;
0494   }
0495 
0496   //==============================================================================
0497   // Multiplications
0498   //==============================================================================
0499 
0500   template <typename T, idx_t D1, idx_t D2, idx_t D3, idx_t N>
0501   void multiplyGeneral(const MPlex<T, D1, D2, N>& A, const MPlex<T, D2, D3, N>& B, MPlex<T, D1, D3, N>& C) {
0502     for (idx_t i = 0; i < D1; ++i) {
0503       for (idx_t j = 0; j < D3; ++j) {
0504         const idx_t ijo = N * (i * D3 + j);
0505 
0506 #pragma omp simd
0507         for (idx_t n = 0; n < N; ++n) {
0508           C.fArray[ijo + n] = 0;
0509         }
0510 
0511         for (idx_t k = 0; k < D2; ++k) {
0512           const idx_t iko = N * (i * D2 + k);
0513           const idx_t kjo = N * (k * D3 + j);
0514 
0515 #pragma omp simd
0516           for (idx_t n = 0; n < N; ++n) {
0517             C.fArray[ijo + n] += A.fArray[iko + n] * B.fArray[kjo + n];
0518           }
0519         }
0520       }
0521     }
0522   }
0523 
0524   //------------------------------------------------------------------------------
0525 
0526   template <typename T, idx_t D, idx_t N>
0527   struct MultiplyCls {
0528     static void multiply(const MPlex<T, D, D, N>& A, const MPlex<T, D, D, N>& B, MPlex<T, D, D, N>& C) {
0529       throw std::runtime_error("general multiplication not supported, well, call multiplyGeneral()");
0530     }
0531   };
0532 
  /// Hand-unrolled C = A * B for 3x3 matrices, computed independently for each of the
  /// N slots. Element (i, j) is stored at index (i * 3 + j) * N + n, so x[k * N + n]
  /// below is flat element k of slot n.
  /// NOTE: C must not alias A or B. c[0 * N + n] is written before a[0 * N + n] (for c[1])
  /// and b[0 * N + n] (for c[3]) are read again.
  template <typename T, idx_t N>
  struct MultiplyCls<T, 3, N> {
    static void multiply(const MPlex<T, 3, 3, N>& A, const MPlex<T, 3, 3, N>& B, MPlex<T, 3, 3, N>& C) {
      // ASSUME_ALIGNED is an alignment hint macro defined outside this file
      // (presumably MatriplexCommon.h) — 64-byte alignment is promised for all three arrays.
      const T* a = A.fArray;
      ASSUME_ALIGNED(a, 64);
      const T* b = B.fArray;
      ASSUME_ALIGNED(b, 64);
      T* c = C.fArray;
      ASSUME_ALIGNED(c, 64);

#pragma omp simd
      for (idx_t n = 0; n < N; ++n) {
        // Row 0 of C
        c[0 * N + n] = a[0 * N + n] * b[0 * N + n] + a[1 * N + n] * b[3 * N + n] + a[2 * N + n] * b[6 * N + n];
        c[1 * N + n] = a[0 * N + n] * b[1 * N + n] + a[1 * N + n] * b[4 * N + n] + a[2 * N + n] * b[7 * N + n];
        c[2 * N + n] = a[0 * N + n] * b[2 * N + n] + a[1 * N + n] * b[5 * N + n] + a[2 * N + n] * b[8 * N + n];
        // Row 1 of C
        c[3 * N + n] = a[3 * N + n] * b[0 * N + n] + a[4 * N + n] * b[3 * N + n] + a[5 * N + n] * b[6 * N + n];
        c[4 * N + n] = a[3 * N + n] * b[1 * N + n] + a[4 * N + n] * b[4 * N + n] + a[5 * N + n] * b[7 * N + n];
        c[5 * N + n] = a[3 * N + n] * b[2 * N + n] + a[4 * N + n] * b[5 * N + n] + a[5 * N + n] * b[8 * N + n];
        // Row 2 of C
        c[6 * N + n] = a[6 * N + n] * b[0 * N + n] + a[7 * N + n] * b[3 * N + n] + a[8 * N + n] * b[6 * N + n];
        c[7 * N + n] = a[6 * N + n] * b[1 * N + n] + a[7 * N + n] * b[4 * N + n] + a[8 * N + n] * b[7 * N + n];
        c[8 * N + n] = a[6 * N + n] * b[2 * N + n] + a[7 * N + n] * b[5 * N + n] + a[8 * N + n] * b[8 * N + n];
      }
    }
  };
0557 
  /// Hand-unrolled C = A * B for 6x6 matrices, computed independently for each of the
  /// N slots. Element (i, j) is stored at index (i * 6 + j) * N + n, so x[k * N + n]
  /// below is flat element k of slot n; c[(6 * i + j)] sums a[6 * i + k] * b[6 * k + j].
  /// NOTE: C must not alias A or B. Early c elements are written before the same-index
  /// a/b elements are read again for later c elements.
  template <typename T, idx_t N>
  struct MultiplyCls<T, 6, N> {
    static void multiply(const MPlex<T, 6, 6, N>& A, const MPlex<T, 6, 6, N>& B, MPlex<T, 6, 6, N>& C) {
      // ASSUME_ALIGNED is an alignment hint macro defined outside this file
      // (presumably MatriplexCommon.h) — 64-byte alignment is promised for all three arrays.
      const T* a = A.fArray;
      ASSUME_ALIGNED(a, 64);
      const T* b = B.fArray;
      ASSUME_ALIGNED(b, 64);
      T* c = C.fArray;
      ASSUME_ALIGNED(c, 64);
#pragma omp simd
      for (idx_t n = 0; n < N; ++n) {
        // Row 0 of C
        c[0 * N + n] = a[0 * N + n] * b[0 * N + n] + a[1 * N + n] * b[6 * N + n] + a[2 * N + n] * b[12 * N + n] +
                       a[3 * N + n] * b[18 * N + n] + a[4 * N + n] * b[24 * N + n] + a[5 * N + n] * b[30 * N + n];
        c[1 * N + n] = a[0 * N + n] * b[1 * N + n] + a[1 * N + n] * b[7 * N + n] + a[2 * N + n] * b[13 * N + n] +
                       a[3 * N + n] * b[19 * N + n] + a[4 * N + n] * b[25 * N + n] + a[5 * N + n] * b[31 * N + n];
        c[2 * N + n] = a[0 * N + n] * b[2 * N + n] + a[1 * N + n] * b[8 * N + n] + a[2 * N + n] * b[14 * N + n] +
                       a[3 * N + n] * b[20 * N + n] + a[4 * N + n] * b[26 * N + n] + a[5 * N + n] * b[32 * N + n];
        c[3 * N + n] = a[0 * N + n] * b[3 * N + n] + a[1 * N + n] * b[9 * N + n] + a[2 * N + n] * b[15 * N + n] +
                       a[3 * N + n] * b[21 * N + n] + a[4 * N + n] * b[27 * N + n] + a[5 * N + n] * b[33 * N + n];
        c[4 * N + n] = a[0 * N + n] * b[4 * N + n] + a[1 * N + n] * b[10 * N + n] + a[2 * N + n] * b[16 * N + n] +
                       a[3 * N + n] * b[22 * N + n] + a[4 * N + n] * b[28 * N + n] + a[5 * N + n] * b[34 * N + n];
        c[5 * N + n] = a[0 * N + n] * b[5 * N + n] + a[1 * N + n] * b[11 * N + n] + a[2 * N + n] * b[17 * N + n] +
                       a[3 * N + n] * b[23 * N + n] + a[4 * N + n] * b[29 * N + n] + a[5 * N + n] * b[35 * N + n];
        // Row 1 of C
        c[6 * N + n] = a[6 * N + n] * b[0 * N + n] + a[7 * N + n] * b[6 * N + n] + a[8 * N + n] * b[12 * N + n] +
                       a[9 * N + n] * b[18 * N + n] + a[10 * N + n] * b[24 * N + n] + a[11 * N + n] * b[30 * N + n];
        c[7 * N + n] = a[6 * N + n] * b[1 * N + n] + a[7 * N + n] * b[7 * N + n] + a[8 * N + n] * b[13 * N + n] +
                       a[9 * N + n] * b[19 * N + n] + a[10 * N + n] * b[25 * N + n] + a[11 * N + n] * b[31 * N + n];
        c[8 * N + n] = a[6 * N + n] * b[2 * N + n] + a[7 * N + n] * b[8 * N + n] + a[8 * N + n] * b[14 * N + n] +
                       a[9 * N + n] * b[20 * N + n] + a[10 * N + n] * b[26 * N + n] + a[11 * N + n] * b[32 * N + n];
        c[9 * N + n] = a[6 * N + n] * b[3 * N + n] + a[7 * N + n] * b[9 * N + n] + a[8 * N + n] * b[15 * N + n] +
                       a[9 * N + n] * b[21 * N + n] + a[10 * N + n] * b[27 * N + n] + a[11 * N + n] * b[33 * N + n];
        c[10 * N + n] = a[6 * N + n] * b[4 * N + n] + a[7 * N + n] * b[10 * N + n] + a[8 * N + n] * b[16 * N + n] +
                        a[9 * N + n] * b[22 * N + n] + a[10 * N + n] * b[28 * N + n] + a[11 * N + n] * b[34 * N + n];
        c[11 * N + n] = a[6 * N + n] * b[5 * N + n] + a[7 * N + n] * b[11 * N + n] + a[8 * N + n] * b[17 * N + n] +
                        a[9 * N + n] * b[23 * N + n] + a[10 * N + n] * b[29 * N + n] + a[11 * N + n] * b[35 * N + n];
        // Row 2 of C
        c[12 * N + n] = a[12 * N + n] * b[0 * N + n] + a[13 * N + n] * b[6 * N + n] + a[14 * N + n] * b[12 * N + n] +
                        a[15 * N + n] * b[18 * N + n] + a[16 * N + n] * b[24 * N + n] + a[17 * N + n] * b[30 * N + n];
        c[13 * N + n] = a[12 * N + n] * b[1 * N + n] + a[13 * N + n] * b[7 * N + n] + a[14 * N + n] * b[13 * N + n] +
                        a[15 * N + n] * b[19 * N + n] + a[16 * N + n] * b[25 * N + n] + a[17 * N + n] * b[31 * N + n];
        c[14 * N + n] = a[12 * N + n] * b[2 * N + n] + a[13 * N + n] * b[8 * N + n] + a[14 * N + n] * b[14 * N + n] +
                        a[15 * N + n] * b[20 * N + n] + a[16 * N + n] * b[26 * N + n] + a[17 * N + n] * b[32 * N + n];
        c[15 * N + n] = a[12 * N + n] * b[3 * N + n] + a[13 * N + n] * b[9 * N + n] + a[14 * N + n] * b[15 * N + n] +
                        a[15 * N + n] * b[21 * N + n] + a[16 * N + n] * b[27 * N + n] + a[17 * N + n] * b[33 * N + n];
        c[16 * N + n] = a[12 * N + n] * b[4 * N + n] + a[13 * N + n] * b[10 * N + n] + a[14 * N + n] * b[16 * N + n] +
                        a[15 * N + n] * b[22 * N + n] + a[16 * N + n] * b[28 * N + n] + a[17 * N + n] * b[34 * N + n];
        c[17 * N + n] = a[12 * N + n] * b[5 * N + n] + a[13 * N + n] * b[11 * N + n] + a[14 * N + n] * b[17 * N + n] +
                        a[15 * N + n] * b[23 * N + n] + a[16 * N + n] * b[29 * N + n] + a[17 * N + n] * b[35 * N + n];
        // Row 3 of C
        c[18 * N + n] = a[18 * N + n] * b[0 * N + n] + a[19 * N + n] * b[6 * N + n] + a[20 * N + n] * b[12 * N + n] +
                        a[21 * N + n] * b[18 * N + n] + a[22 * N + n] * b[24 * N + n] + a[23 * N + n] * b[30 * N + n];
        c[19 * N + n] = a[18 * N + n] * b[1 * N + n] + a[19 * N + n] * b[7 * N + n] + a[20 * N + n] * b[13 * N + n] +
                        a[21 * N + n] * b[19 * N + n] + a[22 * N + n] * b[25 * N + n] + a[23 * N + n] * b[31 * N + n];
        c[20 * N + n] = a[18 * N + n] * b[2 * N + n] + a[19 * N + n] * b[8 * N + n] + a[20 * N + n] * b[14 * N + n] +
                        a[21 * N + n] * b[20 * N + n] + a[22 * N + n] * b[26 * N + n] + a[23 * N + n] * b[32 * N + n];
        c[21 * N + n] = a[18 * N + n] * b[3 * N + n] + a[19 * N + n] * b[9 * N + n] + a[20 * N + n] * b[15 * N + n] +
                        a[21 * N + n] * b[21 * N + n] + a[22 * N + n] * b[27 * N + n] + a[23 * N + n] * b[33 * N + n];
        c[22 * N + n] = a[18 * N + n] * b[4 * N + n] + a[19 * N + n] * b[10 * N + n] + a[20 * N + n] * b[16 * N + n] +
                        a[21 * N + n] * b[22 * N + n] + a[22 * N + n] * b[28 * N + n] + a[23 * N + n] * b[34 * N + n];
        c[23 * N + n] = a[18 * N + n] * b[5 * N + n] + a[19 * N + n] * b[11 * N + n] + a[20 * N + n] * b[17 * N + n] +
                        a[21 * N + n] * b[23 * N + n] + a[22 * N + n] * b[29 * N + n] + a[23 * N + n] * b[35 * N + n];
        // Row 4 of C
        c[24 * N + n] = a[24 * N + n] * b[0 * N + n] + a[25 * N + n] * b[6 * N + n] + a[26 * N + n] * b[12 * N + n] +
                        a[27 * N + n] * b[18 * N + n] + a[28 * N + n] * b[24 * N + n] + a[29 * N + n] * b[30 * N + n];
        c[25 * N + n] = a[24 * N + n] * b[1 * N + n] + a[25 * N + n] * b[7 * N + n] + a[26 * N + n] * b[13 * N + n] +
                        a[27 * N + n] * b[19 * N + n] + a[28 * N + n] * b[25 * N + n] + a[29 * N + n] * b[31 * N + n];
        c[26 * N + n] = a[24 * N + n] * b[2 * N + n] + a[25 * N + n] * b[8 * N + n] + a[26 * N + n] * b[14 * N + n] +
                        a[27 * N + n] * b[20 * N + n] + a[28 * N + n] * b[26 * N + n] + a[29 * N + n] * b[32 * N + n];
        c[27 * N + n] = a[24 * N + n] * b[3 * N + n] + a[25 * N + n] * b[9 * N + n] + a[26 * N + n] * b[15 * N + n] +
                        a[27 * N + n] * b[21 * N + n] + a[28 * N + n] * b[27 * N + n] + a[29 * N + n] * b[33 * N + n];
        c[28 * N + n] = a[24 * N + n] * b[4 * N + n] + a[25 * N + n] * b[10 * N + n] + a[26 * N + n] * b[16 * N + n] +
                        a[27 * N + n] * b[22 * N + n] + a[28 * N + n] * b[28 * N + n] + a[29 * N + n] * b[34 * N + n];
        c[29 * N + n] = a[24 * N + n] * b[5 * N + n] + a[25 * N + n] * b[11 * N + n] + a[26 * N + n] * b[17 * N + n] +
                        a[27 * N + n] * b[23 * N + n] + a[28 * N + n] * b[29 * N + n] + a[29 * N + n] * b[35 * N + n];
        // Row 5 of C
        c[30 * N + n] = a[30 * N + n] * b[0 * N + n] + a[31 * N + n] * b[6 * N + n] + a[32 * N + n] * b[12 * N + n] +
                        a[33 * N + n] * b[18 * N + n] + a[34 * N + n] * b[24 * N + n] + a[35 * N + n] * b[30 * N + n];
        c[31 * N + n] = a[30 * N + n] * b[1 * N + n] + a[31 * N + n] * b[7 * N + n] + a[32 * N + n] * b[13 * N + n] +
                        a[33 * N + n] * b[19 * N + n] + a[34 * N + n] * b[25 * N + n] + a[35 * N + n] * b[31 * N + n];
        c[32 * N + n] = a[30 * N + n] * b[2 * N + n] + a[31 * N + n] * b[8 * N + n] + a[32 * N + n] * b[14 * N + n] +
                        a[33 * N + n] * b[20 * N + n] + a[34 * N + n] * b[26 * N + n] + a[35 * N + n] * b[32 * N + n];
        c[33 * N + n] = a[30 * N + n] * b[3 * N + n] + a[31 * N + n] * b[9 * N + n] + a[32 * N + n] * b[15 * N + n] +
                        a[33 * N + n] * b[21 * N + n] + a[34 * N + n] * b[27 * N + n] + a[35 * N + n] * b[33 * N + n];
        c[34 * N + n] = a[30 * N + n] * b[4 * N + n] + a[31 * N + n] * b[10 * N + n] + a[32 * N + n] * b[16 * N + n] +
                        a[33 * N + n] * b[22 * N + n] + a[34 * N + n] * b[28 * N + n] + a[35 * N + n] * b[34 * N + n];
        c[35 * N + n] = a[30 * N + n] * b[5 * N + n] + a[31 * N + n] * b[11 * N + n] + a[32 * N + n] * b[17 * N + n] +
                        a[33 * N + n] * b[23 * N + n] + a[34 * N + n] * b[29 * N + n] + a[35 * N + n] * b[35 * N + n];
      }
    }
  };
0644 
0645   template <typename T, idx_t D, idx_t N>
0646   void multiply(const MPlex<T, D, D, N>& A, const MPlex<T, D, D, N>& B, MPlex<T, D, D, N>& C) {
0647 #ifdef DEBUG
0648     printf("Multipl %d %d\n", D, N);
0649 #endif
0650 
0651     MultiplyCls<T, D, N>::multiply(A, B, C);
0652   }
0653 
0654   //==============================================================================
0655   // Cramer inversion
0656   //==============================================================================
0657 
0658   template <typename T, idx_t D, idx_t N>
0659   struct CramerInverter {
0660     static void invert(MPlex<T, D, D, N>& A, double* determ = nullptr) {
0661       throw std::runtime_error("general cramer inversion not supported");
0662     }
0663   };
0664 
0665   template <typename T, idx_t N>
0666   struct CramerInverter<T, 2, N> {
0667     static void invert(MPlex<T, 2, 2, N>& A, double* determ = nullptr) {
0668       typedef T TT;
0669 
0670       T* a = A.fArray;
0671       ASSUME_ALIGNED(a, 64);
0672 
0673 #pragma omp simd
0674       for (idx_t n = 0; n < N; ++n) {
0675         // Force determinant calculation in double precision.
0676         const double det = (double)a[0 * N + n] * a[3 * N + n] - (double)a[2 * N + n] * a[1 * N + n];
0677         if (determ)
0678           determ[n] = det;
0679 
0680         const TT s = TT(1) / det;
0681         const TT tmp = s * a[3 * N + n];
0682         a[1 * N + n] *= -s;
0683         a[2 * N + n] *= -s;
0684         a[3 * N + n] = s * a[0 * N + n];
0685         a[0 * N + n] = tmp;
0686       }
0687     }
0688   };
0689 
  /// In-place inversion of N 3x3 matrices by Cramer's rule: inverse = adjugate / det,
  /// where the adjugate is the transpose of the cofactor matrix.
  /// If determ is non-null it must hold N doubles; determ[n] receives slot n's determinant.
  /// There is no singularity check: with det == 0 the scale factor s is not finite
  /// (presumably inf/NaN under IEEE arithmetic).
  template <typename T, idx_t N>
  struct CramerInverter<T, 3, N> {
    static void invert(MPlex<T, 3, 3, N>& A, double* determ = nullptr) {
      typedef T TT;

      T* a = A.fArray;
      ASSUME_ALIGNED(a, 64);

#pragma omp simd
      for (idx_t n = 0; n < N; ++n) {
        // Cofactors cij of element (i, j); a[k * N + n] is flat element k = 3 * i + j.
        // All are computed before any element of a is overwritten below.
        const TT c00 = a[4 * N + n] * a[8 * N + n] - a[5 * N + n] * a[7 * N + n];
        const TT c01 = a[5 * N + n] * a[6 * N + n] - a[3 * N + n] * a[8 * N + n];
        const TT c02 = a[3 * N + n] * a[7 * N + n] - a[4 * N + n] * a[6 * N + n];
        const TT c10 = a[7 * N + n] * a[2 * N + n] - a[8 * N + n] * a[1 * N + n];
        const TT c11 = a[8 * N + n] * a[0 * N + n] - a[6 * N + n] * a[2 * N + n];
        const TT c12 = a[6 * N + n] * a[1 * N + n] - a[7 * N + n] * a[0 * N + n];
        const TT c20 = a[1 * N + n] * a[5 * N + n] - a[2 * N + n] * a[4 * N + n];
        const TT c21 = a[2 * N + n] * a[3 * N + n] - a[0 * N + n] * a[5 * N + n];
        const TT c22 = a[0 * N + n] * a[4 * N + n] - a[1 * N + n] * a[3 * N + n];

        // Force determinant calculation in double precision.
        // (Expansion along row 0 using the cofactors above.)
        const double det = (double)a[0 * N + n] * c00 + (double)a[1 * N + n] * c01 + (double)a[2 * N + n] * c02;
        if (determ)
          determ[n] = det;

        // Write the transposed cofactors (adjugate) scaled by 1/det.
        const TT s = TT(1) / det;
        a[0 * N + n] = s * c00;
        a[1 * N + n] = s * c10;
        a[2 * N + n] = s * c20;
        a[3 * N + n] = s * c01;
        a[4 * N + n] = s * c11;
        a[5 * N + n] = s * c21;
        a[6 * N + n] = s * c02;
        a[7 * N + n] = s * c12;
        a[8 * N + n] = s * c22;
      }
    }
  };
0728 
0729   template <typename T, idx_t D, idx_t N>
0730   void invertCramer(MPlex<T, D, D, N>& A, double* determ = nullptr) {
0731     CramerInverter<T, D, N>::invert(A, determ);
0732   }
0733 
0734   //==============================================================================
0735   // Cholesky inversion
0736   //==============================================================================
0737 
0738   template <typename T, idx_t D, idx_t N>
0739   struct CholeskyInverter {
0740     static void invert(MPlex<T, D, D, N>& A) { throw std::runtime_error("general cholesky inversion not supported"); }
0741   };
0742 
  /// In-place inversion of N symmetric 3x3 matrices via Cholesky decomposition A = L * L^T.
  /// Only the lower triangle is read (flat elements 0, 3, 4, 6, 7, 8); the full symmetric
  /// inverse is written back.
  template <typename T, idx_t N>
  struct CholeskyInverter<T, 3, N> {
    // Note: this only works on symmetric matrices.
    // Optimized version for positive definite matrices, no checks.
    static void invert(MPlex<T, 3, 3, N>& A) {
      typedef T TT;

      T* a = A.fArray;
      ASSUME_ALIGNED(a, 64);

#pragma omp simd
      for (idx_t n = 0; n < N; ++n) {
        // Decomposition. Afterwards: l0 = 1/L00, l1 = L10, l2 = 1/L11,
        // l3 = L20, l4 = L21, l5 = 1/L22 (reciprocals kept for the diagonal).
        TT l0 = std::sqrt(T(1) / a[0 * N + n]);
        TT l1 = a[3 * N + n] * l0;
        TT l2 = a[4 * N + n] - l1 * l1;
        l2 = std::sqrt(T(1) / l2);
        TT l3 = a[6 * N + n] * l0;
        TT l4 = (a[7 * N + n] - l1 * l3) * l2;
        TT l5 = a[8 * N + n] - (l3 * l3 + l4 * l4);
        l5 = std::sqrt(T(1) / l5);

        // decomposition done

        // Invert L in place: l0..l5 now hold the lower-triangular L^-1
        // (l0 = [0][0], l1 = [1][0], l2 = [1][1], l3 = [2][0], l4 = [2][1], l5 = [2][2]).
        l3 = (l1 * l4 * l2 - l3) * l0 * l5;
        l1 = -l1 * l0 * l2;
        l4 = -l4 * l2 * l5;

        // A^-1 = (L^-1)^T * L^-1, written to both triangles.
        a[0 * N + n] = l3 * l3 + l1 * l1 + l0 * l0;
        a[1 * N + n] = a[3 * N + n] = l3 * l4 + l1 * l2;
        a[4 * N + n] = l4 * l4 + l2 * l2;
        a[2 * N + n] = a[6 * N + n] = l3 * l5;
        a[5 * N + n] = a[7 * N + n] = l4 * l5;
        a[8 * N + n] = l5 * l5;

        // m(2,x) are all zero if anything went wrong at l5.
        // all zero, if anything went wrong already for l0 or l2.
        // NOTE(review): this version performs no pivot checks (see above); a non-positive
        // pivot makes std::sqrt yield NaN/inf rather than zeros, so the two lines above
        // may be stale — confirm.
      }
    }
  };
0782 
0783   template <typename T, idx_t D, idx_t N>
0784   void invertCholesky(MPlex<T, D, D, N>& A) {
0785     CholeskyInverter<T, D, N>::invert(A);
0786   }
0787 
0788 }  // namespace Matriplex
0789 
0790 #endif