File indexing completed on 2024-04-06 12:28:18
0001 #ifndef RecoTracker_MkFitCore_src_Matriplex_Matriplex_h
0002 #define RecoTracker_MkFitCore_src_Matriplex_Matriplex_h
0003
0004 #include "MatriplexCommon.h"
0005
0006 namespace Matriplex {
0007
0008
0009
0010 template <typename T, idx_t D1, idx_t D2, idx_t N>
0011 class __attribute__((aligned(MPLEX_ALIGN))) Matriplex {
0012 public:
0013 typedef T value_type;
0014
0015
0016 static constexpr int kRows = D1;
0017
0018 static constexpr int kCols = D2;
0019
0020 static constexpr int kSize = D1 * D2;
0021
0022 static constexpr int kTotSize = N * kSize;
0023
0024 T fArray[kTotSize];
0025
0026 Matriplex() {}
0027 Matriplex(T v) { setVal(v); }
0028
0029 idx_t plexSize() const { return N; }
0030
0031 void setVal(T v) {
0032 for (idx_t i = 0; i < kTotSize; ++i) {
0033 fArray[i] = v;
0034 }
0035 }
0036
0037 void add(const Matriplex& v) {
0038 for (idx_t i = 0; i < kTotSize; ++i) {
0039 fArray[i] += v.fArray[i];
0040 }
0041 }
0042
0043 void scale(T scale) {
0044 for (idx_t i = 0; i < kTotSize; ++i) {
0045 fArray[i] *= scale;
0046 }
0047 }
0048
0049 T operator[](idx_t xx) const { return fArray[xx]; }
0050 T& operator[](idx_t xx) { return fArray[xx]; }
0051
0052 const T& constAt(idx_t n, idx_t i, idx_t j) const { return fArray[(i * D2 + j) * N + n]; }
0053
0054 T& At(idx_t n, idx_t i, idx_t j) { return fArray[(i * D2 + j) * N + n]; }
0055
0056 T& operator()(idx_t n, idx_t i, idx_t j) { return fArray[(i * D2 + j) * N + n]; }
0057 const T& operator()(idx_t n, idx_t i, idx_t j) const { return fArray[(i * D2 + j) * N + n]; }
0058
0059 Matriplex& operator=(T t) {
0060 for (idx_t i = 0; i < kTotSize; ++i)
0061 fArray[i] = t;
0062 return *this;
0063 }
0064
0065 Matriplex& operator+=(T t) {
0066 for (idx_t i = 0; i < kTotSize; ++i)
0067 fArray[i] += t;
0068 return *this;
0069 }
0070
0071 Matriplex& operator-=(T t) {
0072 for (idx_t i = 0; i < kTotSize; ++i)
0073 fArray[i] -= t;
0074 return *this;
0075 }
0076
0077 Matriplex& operator*=(T t) {
0078 for (idx_t i = 0; i < kTotSize; ++i)
0079 fArray[i] *= t;
0080 return *this;
0081 }
0082
0083 Matriplex& operator/=(T t) {
0084 for (idx_t i = 0; i < kTotSize; ++i)
0085 fArray[i] /= t;
0086 return *this;
0087 }
0088
0089 Matriplex& operator+=(const Matriplex& a) {
0090 for (idx_t i = 0; i < kTotSize; ++i)
0091 fArray[i] += a.fArray[i];
0092 return *this;
0093 }
0094
0095 Matriplex& operator-=(const Matriplex& a) {
0096 for (idx_t i = 0; i < kTotSize; ++i)
0097 fArray[i] -= a.fArray[i];
0098 return *this;
0099 }
0100
0101 Matriplex& operator*=(const Matriplex& a) {
0102 for (idx_t i = 0; i < kTotSize; ++i)
0103 fArray[i] *= a.fArray[i];
0104 return *this;
0105 }
0106
0107 Matriplex& operator/=(const Matriplex& a) {
0108 for (idx_t i = 0; i < kTotSize; ++i)
0109 fArray[i] /= a.fArray[i];
0110 return *this;
0111 }
0112
0113 Matriplex operator-() {
0114 Matriplex t;
0115 for (idx_t i = 0; i < kTotSize; ++i)
0116 t.fArray[i] = -fArray[i];
0117 return t;
0118 }
0119
0120 Matriplex& abs(const Matriplex& a) {
0121 for (idx_t i = 0; i < kTotSize; ++i)
0122 fArray[i] = std::abs(a.fArray[i]);
0123 return *this;
0124 }
0125 Matriplex& abs() {
0126 for (idx_t i = 0; i < kTotSize; ++i)
0127 fArray[i] = std::abs(fArray[i]);
0128 return *this;
0129 }
0130
0131 Matriplex& sqrt(const Matriplex& a) {
0132 for (idx_t i = 0; i < kTotSize; ++i)
0133 fArray[i] = std::sqrt(a.fArray[i]);
0134 return *this;
0135 }
0136 Matriplex& sqrt() {
0137 for (idx_t i = 0; i < kTotSize; ++i)
0138 fArray[i] = std::sqrt(fArray[i]);
0139 return *this;
0140 }
0141
0142 Matriplex& sqr(const Matriplex& a) {
0143 for (idx_t i = 0; i < kTotSize; ++i)
0144 fArray[i] = a.fArray[i] * a.fArray[i];
0145 return *this;
0146 }
0147 Matriplex& sqr() {
0148 for (idx_t i = 0; i < kTotSize; ++i)
0149 fArray[i] = fArray[i] * fArray[i];
0150 return *this;
0151 }
0152
0153 Matriplex& hypot(const Matriplex& a, const Matriplex& b) {
0154 for (idx_t i = 0; i < kTotSize; ++i) {
0155 fArray[i] = a.fArray[i] * a.fArray[i] + b.fArray[i] * b.fArray[i];
0156 }
0157 return sqrt();
0158 }
0159
0160 Matriplex& sin(const Matriplex& a) {
0161 for (idx_t i = 0; i < kTotSize; ++i)
0162 fArray[i] = std::sin(a.fArray[i]);
0163 return *this;
0164 }
0165 Matriplex& sin() {
0166 for (idx_t i = 0; i < kTotSize; ++i)
0167 fArray[i] = std::sin(fArray[i]);
0168 return *this;
0169 }
0170
0171 Matriplex& cos(const Matriplex& a) {
0172 for (idx_t i = 0; i < kTotSize; ++i)
0173 fArray[i] = std::cos(a.fArray[i]);
0174 return *this;
0175 }
0176 Matriplex& cos() {
0177 for (idx_t i = 0; i < kTotSize; ++i)
0178 fArray[i] = std::cos(fArray[i]);
0179 return *this;
0180 }
0181
0182 Matriplex& tan(const Matriplex& a) {
0183 for (idx_t i = 0; i < kTotSize; ++i)
0184 fArray[i] = std::tan(a.fArray[i]);
0185 return *this;
0186 }
0187 Matriplex& tan() {
0188 for (idx_t i = 0; i < kTotSize; ++i)
0189 fArray[i] = std::tan(fArray[i]);
0190 return *this;
0191 }
0192
0193
0194
0195 void copySlot(idx_t n, const Matriplex& m) {
0196 for (idx_t i = n; i < kTotSize; i += N) {
0197 fArray[i] = m.fArray[i];
0198 }
0199 }
0200
0201 void copyIn(idx_t n, const T* arr) {
0202 for (idx_t i = n; i < kTotSize; i += N) {
0203 fArray[i] = *(arr++);
0204 }
0205 }
0206
0207 void copyIn(idx_t n, const Matriplex& m, idx_t in) {
0208 for (idx_t i = n; i < kTotSize; i += N, in += N) {
0209 fArray[i] = m[in];
0210 }
0211 }
0212
0213 void copy(idx_t n, idx_t in) {
0214 for (idx_t i = n; i < kTotSize; i += N, in += N) {
0215 fArray[i] = fArray[in];
0216 }
0217 }
0218
0219 #if defined(AVX512_INTRINSICS)
0220
0221 template <typename U>
0222 void slurpIn(const T* arr, __m512i& vi, const U&, const int N_proc = N) {
0223
0224
0225 const __m512 src = {0};
0226 const __mmask16 k = N_proc == N ? -1 : (1 << N_proc) - 1;
0227
0228 for (int i = 0; i < kSize; ++i, ++arr) {
0229
0230
0231 __m512 reg = _mm512_mask_i32gather_ps(src, k, vi, arr, sizeof(U));
0232 _mm512_mask_store_ps(&fArray[i * N], k, reg);
0233 }
0234 }
0235
0236
0237
0238 void ChewIn(const char* arr, int off, int vi[N], const char* tmp, __m512i& ui) {
0239
0240
0241 for (int i = 0; i < N; ++i) {
0242 __m512 reg = _mm512_load_ps(arr + vi[i]);
0243 _mm512_store_ps((void*)(tmp + 64 * i), reg);
0244 }
0245
0246 for (int i = 0; i < kSize; ++i) {
0247 __m512 reg = _mm512_i32gather_ps(ui, tmp + off + i * sizeof(T), 1);
0248 _mm512_store_ps(&fArray[i * N], reg);
0249 }
0250 }
0251
0252 void Contaginate(const char* arr, int vi[N], const char* tmp) {
0253
0254
0255 for (int i = 0; i < N; ++i) {
0256 __m512 reg = _mm512_load_ps(arr + vi[i]);
0257 _mm512_store_ps((void*)(tmp + 64 * i), reg);
0258 }
0259 }
0260
0261 void Plexify(const char* tmp, __m512i& ui) {
0262 for (int i = 0; i < kSize; ++i) {
0263 __m512 reg = _mm512_i32gather_ps(ui, tmp + i * sizeof(T), 1);
0264 _mm512_store_ps(&fArray[i * N], reg);
0265 }
0266 }
0267
0268 #elif defined(AVX2_INTRINSICS)
0269
0270 template <typename U>
0271 void slurpIn(const T* arr, __m256i& vi, const U&, const int N_proc = N) {
0272
0273
0274
0275 const __m256 src = {0};
0276
0277 __m256i k = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
0278 __m256i k_sel = _mm256_set1_epi32(N_proc);
0279 __m256i k_master = _mm256_cmpgt_epi32(k_sel, k);
0280
0281 k = k_master;
0282 for (int i = 0; i < kSize; ++i, ++arr) {
0283 __m256 reg = _mm256_mask_i32gather_ps(src, (float*)arr, vi, (__m256)k, sizeof(U));
0284
0285 k = k_master;
0286 _mm256_maskstore_ps((float*)&fArray[i * N], k, reg);
0287 }
0288 }
0289
0290 #else
0291
0292 void slurpIn(const T* arr, int vi[N], const int N_proc = N) {
0293
0294 if (N_proc == N) {
0295 for (int i = 0; i < kSize; ++i) {
0296 for (int j = 0; j < N; ++j) {
0297 fArray[i * N + j] = *(arr + i + vi[j]);
0298 }
0299 }
0300 } else {
0301 for (int i = 0; i < kSize; ++i) {
0302 for (int j = 0; j < N_proc; ++j) {
0303 fArray[i * N + j] = *(arr + i + vi[j]);
0304 }
0305 }
0306 }
0307 }
0308
0309 #endif
0310
0311 void copyOut(idx_t n, T* arr) const {
0312 for (idx_t i = n; i < kTotSize; i += N) {
0313 *(arr++) = fArray[i];
0314 }
0315 }
0316
0317 Matriplex<T, 1, 1, N> ReduceFixedIJ(idx_t i, idx_t j) const {
0318 Matriplex<T, 1, 1, N> t;
0319 for (idx_t n = 0; n < N; ++n) {
0320 t[n] = constAt(n, i, j);
0321 }
0322 return t;
0323 }
0324 };
0325
0326 template <typename T, idx_t D1, idx_t D2, idx_t N>
0327 using MPlex = Matriplex<T, D1, D2, N>;
0328
0329
0330
0331
0332
0333 template <typename T, idx_t D1, idx_t D2, idx_t N>
0334 MPlex<T, D1, D2, N> operator+(const MPlex<T, D1, D2, N>& a, const MPlex<T, D1, D2, N>& b) {
0335 MPlex<T, D1, D2, N> t = a;
0336 t += b;
0337 return t;
0338 }
0339
0340 template <typename T, idx_t D1, idx_t D2, idx_t N>
0341 MPlex<T, D1, D2, N> operator-(const MPlex<T, D1, D2, N>& a, const MPlex<T, D1, D2, N>& b) {
0342 MPlex<T, D1, D2, N> t = a;
0343 t -= b;
0344 return t;
0345 }
0346
0347 template <typename T, idx_t D1, idx_t D2, idx_t N>
0348 MPlex<T, D1, D2, N> operator*(const MPlex<T, D1, D2, N>& a, const MPlex<T, D1, D2, N>& b) {
0349 MPlex<T, D1, D2, N> t = a;
0350 t *= b;
0351 return t;
0352 }
0353
0354 template <typename T, idx_t D1, idx_t D2, idx_t N>
0355 MPlex<T, D1, D2, N> operator/(const MPlex<T, D1, D2, N>& a, const MPlex<T, D1, D2, N>& b) {
0356 MPlex<T, D1, D2, N> t = a;
0357 t /= b;
0358 return t;
0359 }
0360
0361 template <typename T, idx_t D1, idx_t D2, idx_t N>
0362 MPlex<T, D1, D2, N> operator+(const MPlex<T, D1, D2, N>& a, T b) {
0363 MPlex<T, D1, D2, N> t = a;
0364 t += b;
0365 return t;
0366 }
0367
0368 template <typename T, idx_t D1, idx_t D2, idx_t N>
0369 MPlex<T, D1, D2, N> operator-(const MPlex<T, D1, D2, N>& a, T b) {
0370 MPlex<T, D1, D2, N> t = a;
0371 t -= b;
0372 return t;
0373 }
0374
0375 template <typename T, idx_t D1, idx_t D2, idx_t N>
0376 MPlex<T, D1, D2, N> operator*(const MPlex<T, D1, D2, N>& a, T b) {
0377 MPlex<T, D1, D2, N> t = a;
0378 t *= b;
0379 return t;
0380 }
0381
0382 template <typename T, idx_t D1, idx_t D2, idx_t N>
0383 MPlex<T, D1, D2, N> operator/(const MPlex<T, D1, D2, N>& a, T b) {
0384 MPlex<T, D1, D2, N> t = a;
0385 t /= b;
0386 return t;
0387 }
0388
0389 template <typename T, idx_t D1, idx_t D2, idx_t N>
0390 MPlex<T, D1, D2, N> operator+(T a, const MPlex<T, D1, D2, N>& b) {
0391 MPlex<T, D1, D2, N> t = a;
0392 t += b;
0393 return t;
0394 }
0395
0396 template <typename T, idx_t D1, idx_t D2, idx_t N>
0397 MPlex<T, D1, D2, N> operator-(T a, const MPlex<T, D1, D2, N>& b) {
0398 MPlex<T, D1, D2, N> t = a;
0399 t -= b;
0400 return t;
0401 }
0402
0403 template <typename T, idx_t D1, idx_t D2, idx_t N>
0404 MPlex<T, D1, D2, N> operator*(T a, const MPlex<T, D1, D2, N>& b) {
0405 MPlex<T, D1, D2, N> t = a;
0406 t *= b;
0407 return t;
0408 }
0409
0410 template <typename T, idx_t D1, idx_t D2, idx_t N>
0411 MPlex<T, D1, D2, N> operator/(T a, const MPlex<T, D1, D2, N>& b) {
0412 MPlex<T, D1, D2, N> t = a;
0413 t /= b;
0414 return t;
0415 }
0416
0417 template <typename T, idx_t D1, idx_t D2, idx_t N>
0418 MPlex<T, D1, D2, N> abs(const MPlex<T, D1, D2, N>& a) {
0419 MPlex<T, D1, D2, N> t;
0420 return t.abs(a);
0421 }
0422
0423 template <typename T, idx_t D1, idx_t D2, idx_t N>
0424 MPlex<T, D1, D2, N> sqrt(const MPlex<T, D1, D2, N>& a) {
0425 MPlex<T, D1, D2, N> t;
0426 return t.sqrt(a);
0427 }
0428
0429 template <typename T, idx_t D1, idx_t D2, idx_t N>
0430 MPlex<T, D1, D2, N> sqr(const MPlex<T, D1, D2, N>& a) {
0431 MPlex<T, D1, D2, N> t;
0432 return t.sqr(a);
0433 }
0434
0435 template <typename T, idx_t D1, idx_t D2, idx_t N>
0436 MPlex<T, D1, D2, N> hypot(const MPlex<T, D1, D2, N>& a, const MPlex<T, D1, D2, N>& b) {
0437 MPlex<T, D1, D2, N> t;
0438 return t.hypot(a, b);
0439 }
0440
0441 template <typename T, idx_t D1, idx_t D2, idx_t N>
0442 MPlex<T, D1, D2, N> sin(const MPlex<T, D1, D2, N>& a) {
0443 MPlex<T, D1, D2, N> t;
0444 return t.sin(a);
0445 }
0446
0447 template <typename T, idx_t D1, idx_t D2, idx_t N>
0448 MPlex<T, D1, D2, N> cos(const MPlex<T, D1, D2, N>& a) {
0449 MPlex<T, D1, D2, N> t;
0450 return t.cos(a);
0451 }
0452
0453 template <typename T, idx_t D1, idx_t D2, idx_t N>
0454 void sincos(const MPlex<T, D1, D2, N>& a, MPlex<T, D1, D2, N>& s, MPlex<T, D1, D2, N>& c) {
0455 for (idx_t i = 0; i < a.kTotSize; ++i) {
0456 s.fArray[i] = std::sin(a.fArray[i]);
0457 c.fArray[i] = std::cos(a.fArray[i]);
0458 }
0459 }
0460
0461 template <typename T, idx_t D1, idx_t D2, idx_t N>
0462 MPlex<T, D1, D2, N> tan(const MPlex<T, D1, D2, N>& a) {
0463 MPlex<T, D1, D2, N> t;
0464 return t.tan(a);
0465 }
0466
0467 template <typename T, idx_t D1, idx_t D2, idx_t N>
0468 void min_max(const MPlex<T, D1, D2, N>& a,
0469 const MPlex<T, D1, D2, N>& b,
0470 MPlex<T, D1, D2, N>& min,
0471 MPlex<T, D1, D2, N>& max) {
0472 for (idx_t i = 0; i < a.kTotSize; ++i) {
0473 min.fArray[i] = std::min(a.fArray[i], b.fArray[i]);
0474 max.fArray[i] = std::max(a.fArray[i], b.fArray[i]);
0475 }
0476 }
0477
0478 template <typename T, idx_t D1, idx_t D2, idx_t N>
0479 MPlex<T, D1, D2, N> min(const MPlex<T, D1, D2, N>& a, const MPlex<T, D1, D2, N>& b) {
0480 MPlex<T, D1, D2, N> t;
0481 for (idx_t i = 0; i < a.kTotSize; ++i) {
0482 t.fArray[i] = std::min(a.fArray[i], b.fArray[i]);
0483 }
0484 return t;
0485 }
0486
0487 template <typename T, idx_t D1, idx_t D2, idx_t N>
0488 MPlex<T, D1, D2, N> max(const MPlex<T, D1, D2, N>& a, const MPlex<T, D1, D2, N>& b) {
0489 MPlex<T, D1, D2, N> t;
0490 for (idx_t i = 0; i < a.kTotSize; ++i) {
0491 t.fArray[i] = std::max(a.fArray[i], b.fArray[i]);
0492 }
0493 return t;
0494 }
0495
0496
0497
0498
0499
0500 template <typename T, idx_t D1, idx_t D2, idx_t D3, idx_t N>
0501 void multiplyGeneral(const MPlex<T, D1, D2, N>& A, const MPlex<T, D2, D3, N>& B, MPlex<T, D1, D3, N>& C) {
0502 for (idx_t i = 0; i < D1; ++i) {
0503 for (idx_t j = 0; j < D3; ++j) {
0504 const idx_t ijo = N * (i * D3 + j);
0505
0506 #pragma omp simd
0507 for (idx_t n = 0; n < N; ++n) {
0508 C.fArray[ijo + n] = 0;
0509 }
0510
0511 for (idx_t k = 0; k < D2; ++k) {
0512 const idx_t iko = N * (i * D2 + k);
0513 const idx_t kjo = N * (k * D3 + j);
0514
0515 #pragma omp simd
0516 for (idx_t n = 0; n < N; ++n) {
0517 C.fArray[ijo + n] += A.fArray[iko + n] * B.fArray[kjo + n];
0518 }
0519 }
0520 }
0521 }
0522 }
0523
0524
0525
0526 template <typename T, idx_t D, idx_t N>
0527 struct MultiplyCls {
0528 static void multiply(const MPlex<T, D, D, N>& A, const MPlex<T, D, D, N>& B, MPlex<T, D, D, N>& C) {
0529 throw std::runtime_error("general multiplication not supported, well, call multiplyGeneral()");
0530 }
0531 };
0532
/// 3x3 specialization of MultiplyCls: fully unrolled C[n] = A[n] * B[n] for
/// every slot n. Element (i, j) of a slot lives at flat index (3 * i + j) * N + n,
/// so c[k * N + n] below is C(k / 3, k % 3) of slot n.
///
/// NOTE(review): C must not be the same object as A or B. For example,
/// c[0 * N + n] is written before a[0 * N + n] is read again for c[1 * N + n],
/// and before b[0 * N + n] is read again for c[3 * N + n].
template <typename T, idx_t N>
struct MultiplyCls<T, 3, N> {
  static void multiply(const MPlex<T, 3, 3, N>& A, const MPlex<T, 3, 3, N>& B, MPlex<T, 3, 3, N>& C) {
    // Alignment hints for the vectorizer (assumes 64-byte-aligned fArray).
    const T* a = A.fArray;
    ASSUME_ALIGNED(a, 64);
    const T* b = B.fArray;
    ASSUME_ALIGNED(b, 64);
    T* c = C.fArray;
    ASSUME_ALIGNED(c, 64);

#pragma omp simd
    for (idx_t n = 0; n < N; ++n) {
      // Row 0 of C.
      c[0 * N + n] = a[0 * N + n] * b[0 * N + n] + a[1 * N + n] * b[3 * N + n] + a[2 * N + n] * b[6 * N + n];
      c[1 * N + n] = a[0 * N + n] * b[1 * N + n] + a[1 * N + n] * b[4 * N + n] + a[2 * N + n] * b[7 * N + n];
      c[2 * N + n] = a[0 * N + n] * b[2 * N + n] + a[1 * N + n] * b[5 * N + n] + a[2 * N + n] * b[8 * N + n];
      // Row 1 of C.
      c[3 * N + n] = a[3 * N + n] * b[0 * N + n] + a[4 * N + n] * b[3 * N + n] + a[5 * N + n] * b[6 * N + n];
      c[4 * N + n] = a[3 * N + n] * b[1 * N + n] + a[4 * N + n] * b[4 * N + n] + a[5 * N + n] * b[7 * N + n];
      c[5 * N + n] = a[3 * N + n] * b[2 * N + n] + a[4 * N + n] * b[5 * N + n] + a[5 * N + n] * b[8 * N + n];
      // Row 2 of C.
      c[6 * N + n] = a[6 * N + n] * b[0 * N + n] + a[7 * N + n] * b[3 * N + n] + a[8 * N + n] * b[6 * N + n];
      c[7 * N + n] = a[6 * N + n] * b[1 * N + n] + a[7 * N + n] * b[4 * N + n] + a[8 * N + n] * b[7 * N + n];
      c[8 * N + n] = a[6 * N + n] * b[2 * N + n] + a[7 * N + n] * b[5 * N + n] + a[8 * N + n] * b[8 * N + n];
    }
  }
};
0557
/// 6x6 specialization of MultiplyCls: fully unrolled C[n] = A[n] * B[n] for
/// every slot n. Element (i, j) of a slot lives at flat index (6 * i + j) * N + n,
/// so each line below computes c[(6 * r + j) * N + n] = sum_k a[(6 * r + k) * N + n] * b[(6 * k + j) * N + n].
///
/// NOTE(review): C must not be the same object as A or B. For example,
/// c[0 * N + n] is written before a[0 * N + n] is read again for c[1 * N + n],
/// and before b[0 * N + n] is read again for c[6 * N + n].
template <typename T, idx_t N>
struct MultiplyCls<T, 6, N> {
  static void multiply(const MPlex<T, 6, 6, N>& A, const MPlex<T, 6, 6, N>& B, MPlex<T, 6, 6, N>& C) {
    // Alignment hints for the vectorizer (assumes 64-byte-aligned fArray).
    const T* a = A.fArray;
    ASSUME_ALIGNED(a, 64);
    const T* b = B.fArray;
    ASSUME_ALIGNED(b, 64);
    T* c = C.fArray;
    ASSUME_ALIGNED(c, 64);
#pragma omp simd
    for (idx_t n = 0; n < N; ++n) {
      // Row 0 of C.
      c[0 * N + n] = a[0 * N + n] * b[0 * N + n] + a[1 * N + n] * b[6 * N + n] + a[2 * N + n] * b[12 * N + n] +
                     a[3 * N + n] * b[18 * N + n] + a[4 * N + n] * b[24 * N + n] + a[5 * N + n] * b[30 * N + n];
      c[1 * N + n] = a[0 * N + n] * b[1 * N + n] + a[1 * N + n] * b[7 * N + n] + a[2 * N + n] * b[13 * N + n] +
                     a[3 * N + n] * b[19 * N + n] + a[4 * N + n] * b[25 * N + n] + a[5 * N + n] * b[31 * N + n];
      c[2 * N + n] = a[0 * N + n] * b[2 * N + n] + a[1 * N + n] * b[8 * N + n] + a[2 * N + n] * b[14 * N + n] +
                     a[3 * N + n] * b[20 * N + n] + a[4 * N + n] * b[26 * N + n] + a[5 * N + n] * b[32 * N + n];
      c[3 * N + n] = a[0 * N + n] * b[3 * N + n] + a[1 * N + n] * b[9 * N + n] + a[2 * N + n] * b[15 * N + n] +
                     a[3 * N + n] * b[21 * N + n] + a[4 * N + n] * b[27 * N + n] + a[5 * N + n] * b[33 * N + n];
      c[4 * N + n] = a[0 * N + n] * b[4 * N + n] + a[1 * N + n] * b[10 * N + n] + a[2 * N + n] * b[16 * N + n] +
                     a[3 * N + n] * b[22 * N + n] + a[4 * N + n] * b[28 * N + n] + a[5 * N + n] * b[34 * N + n];
      c[5 * N + n] = a[0 * N + n] * b[5 * N + n] + a[1 * N + n] * b[11 * N + n] + a[2 * N + n] * b[17 * N + n] +
                     a[3 * N + n] * b[23 * N + n] + a[4 * N + n] * b[29 * N + n] + a[5 * N + n] * b[35 * N + n];
      // Row 1 of C.
      c[6 * N + n] = a[6 * N + n] * b[0 * N + n] + a[7 * N + n] * b[6 * N + n] + a[8 * N + n] * b[12 * N + n] +
                     a[9 * N + n] * b[18 * N + n] + a[10 * N + n] * b[24 * N + n] + a[11 * N + n] * b[30 * N + n];
      c[7 * N + n] = a[6 * N + n] * b[1 * N + n] + a[7 * N + n] * b[7 * N + n] + a[8 * N + n] * b[13 * N + n] +
                     a[9 * N + n] * b[19 * N + n] + a[10 * N + n] * b[25 * N + n] + a[11 * N + n] * b[31 * N + n];
      c[8 * N + n] = a[6 * N + n] * b[2 * N + n] + a[7 * N + n] * b[8 * N + n] + a[8 * N + n] * b[14 * N + n] +
                     a[9 * N + n] * b[20 * N + n] + a[10 * N + n] * b[26 * N + n] + a[11 * N + n] * b[32 * N + n];
      c[9 * N + n] = a[6 * N + n] * b[3 * N + n] + a[7 * N + n] * b[9 * N + n] + a[8 * N + n] * b[15 * N + n] +
                     a[9 * N + n] * b[21 * N + n] + a[10 * N + n] * b[27 * N + n] + a[11 * N + n] * b[33 * N + n];
      c[10 * N + n] = a[6 * N + n] * b[4 * N + n] + a[7 * N + n] * b[10 * N + n] + a[8 * N + n] * b[16 * N + n] +
                      a[9 * N + n] * b[22 * N + n] + a[10 * N + n] * b[28 * N + n] + a[11 * N + n] * b[34 * N + n];
      c[11 * N + n] = a[6 * N + n] * b[5 * N + n] + a[7 * N + n] * b[11 * N + n] + a[8 * N + n] * b[17 * N + n] +
                      a[9 * N + n] * b[23 * N + n] + a[10 * N + n] * b[29 * N + n] + a[11 * N + n] * b[35 * N + n];
      // Row 2 of C.
      c[12 * N + n] = a[12 * N + n] * b[0 * N + n] + a[13 * N + n] * b[6 * N + n] + a[14 * N + n] * b[12 * N + n] +
                      a[15 * N + n] * b[18 * N + n] + a[16 * N + n] * b[24 * N + n] + a[17 * N + n] * b[30 * N + n];
      c[13 * N + n] = a[12 * N + n] * b[1 * N + n] + a[13 * N + n] * b[7 * N + n] + a[14 * N + n] * b[13 * N + n] +
                      a[15 * N + n] * b[19 * N + n] + a[16 * N + n] * b[25 * N + n] + a[17 * N + n] * b[31 * N + n];
      c[14 * N + n] = a[12 * N + n] * b[2 * N + n] + a[13 * N + n] * b[8 * N + n] + a[14 * N + n] * b[14 * N + n] +
                      a[15 * N + n] * b[20 * N + n] + a[16 * N + n] * b[26 * N + n] + a[17 * N + n] * b[32 * N + n];
      c[15 * N + n] = a[12 * N + n] * b[3 * N + n] + a[13 * N + n] * b[9 * N + n] + a[14 * N + n] * b[15 * N + n] +
                      a[15 * N + n] * b[21 * N + n] + a[16 * N + n] * b[27 * N + n] + a[17 * N + n] * b[33 * N + n];
      c[16 * N + n] = a[12 * N + n] * b[4 * N + n] + a[13 * N + n] * b[10 * N + n] + a[14 * N + n] * b[16 * N + n] +
                      a[15 * N + n] * b[22 * N + n] + a[16 * N + n] * b[28 * N + n] + a[17 * N + n] * b[34 * N + n];
      c[17 * N + n] = a[12 * N + n] * b[5 * N + n] + a[13 * N + n] * b[11 * N + n] + a[14 * N + n] * b[17 * N + n] +
                      a[15 * N + n] * b[23 * N + n] + a[16 * N + n] * b[29 * N + n] + a[17 * N + n] * b[35 * N + n];
      // Row 3 of C.
      c[18 * N + n] = a[18 * N + n] * b[0 * N + n] + a[19 * N + n] * b[6 * N + n] + a[20 * N + n] * b[12 * N + n] +
                      a[21 * N + n] * b[18 * N + n] + a[22 * N + n] * b[24 * N + n] + a[23 * N + n] * b[30 * N + n];
      c[19 * N + n] = a[18 * N + n] * b[1 * N + n] + a[19 * N + n] * b[7 * N + n] + a[20 * N + n] * b[13 * N + n] +
                      a[21 * N + n] * b[19 * N + n] + a[22 * N + n] * b[25 * N + n] + a[23 * N + n] * b[31 * N + n];
      c[20 * N + n] = a[18 * N + n] * b[2 * N + n] + a[19 * N + n] * b[8 * N + n] + a[20 * N + n] * b[14 * N + n] +
                      a[21 * N + n] * b[20 * N + n] + a[22 * N + n] * b[26 * N + n] + a[23 * N + n] * b[32 * N + n];
      c[21 * N + n] = a[18 * N + n] * b[3 * N + n] + a[19 * N + n] * b[9 * N + n] + a[20 * N + n] * b[15 * N + n] +
                      a[21 * N + n] * b[21 * N + n] + a[22 * N + n] * b[27 * N + n] + a[23 * N + n] * b[33 * N + n];
      c[22 * N + n] = a[18 * N + n] * b[4 * N + n] + a[19 * N + n] * b[10 * N + n] + a[20 * N + n] * b[16 * N + n] +
                      a[21 * N + n] * b[22 * N + n] + a[22 * N + n] * b[28 * N + n] + a[23 * N + n] * b[34 * N + n];
      c[23 * N + n] = a[18 * N + n] * b[5 * N + n] + a[19 * N + n] * b[11 * N + n] + a[20 * N + n] * b[17 * N + n] +
                      a[21 * N + n] * b[23 * N + n] + a[22 * N + n] * b[29 * N + n] + a[23 * N + n] * b[35 * N + n];
      // Row 4 of C.
      c[24 * N + n] = a[24 * N + n] * b[0 * N + n] + a[25 * N + n] * b[6 * N + n] + a[26 * N + n] * b[12 * N + n] +
                      a[27 * N + n] * b[18 * N + n] + a[28 * N + n] * b[24 * N + n] + a[29 * N + n] * b[30 * N + n];
      c[25 * N + n] = a[24 * N + n] * b[1 * N + n] + a[25 * N + n] * b[7 * N + n] + a[26 * N + n] * b[13 * N + n] +
                      a[27 * N + n] * b[19 * N + n] + a[28 * N + n] * b[25 * N + n] + a[29 * N + n] * b[31 * N + n];
      c[26 * N + n] = a[24 * N + n] * b[2 * N + n] + a[25 * N + n] * b[8 * N + n] + a[26 * N + n] * b[14 * N + n] +
                      a[27 * N + n] * b[20 * N + n] + a[28 * N + n] * b[26 * N + n] + a[29 * N + n] * b[32 * N + n];
      c[27 * N + n] = a[24 * N + n] * b[3 * N + n] + a[25 * N + n] * b[9 * N + n] + a[26 * N + n] * b[15 * N + n] +
                      a[27 * N + n] * b[21 * N + n] + a[28 * N + n] * b[27 * N + n] + a[29 * N + n] * b[33 * N + n];
      c[28 * N + n] = a[24 * N + n] * b[4 * N + n] + a[25 * N + n] * b[10 * N + n] + a[26 * N + n] * b[16 * N + n] +
                      a[27 * N + n] * b[22 * N + n] + a[28 * N + n] * b[28 * N + n] + a[29 * N + n] * b[34 * N + n];
      c[29 * N + n] = a[24 * N + n] * b[5 * N + n] + a[25 * N + n] * b[11 * N + n] + a[26 * N + n] * b[17 * N + n] +
                      a[27 * N + n] * b[23 * N + n] + a[28 * N + n] * b[29 * N + n] + a[29 * N + n] * b[35 * N + n];
      // Row 5 of C.
      c[30 * N + n] = a[30 * N + n] * b[0 * N + n] + a[31 * N + n] * b[6 * N + n] + a[32 * N + n] * b[12 * N + n] +
                      a[33 * N + n] * b[18 * N + n] + a[34 * N + n] * b[24 * N + n] + a[35 * N + n] * b[30 * N + n];
      c[31 * N + n] = a[30 * N + n] * b[1 * N + n] + a[31 * N + n] * b[7 * N + n] + a[32 * N + n] * b[13 * N + n] +
                      a[33 * N + n] * b[19 * N + n] + a[34 * N + n] * b[25 * N + n] + a[35 * N + n] * b[31 * N + n];
      c[32 * N + n] = a[30 * N + n] * b[2 * N + n] + a[31 * N + n] * b[8 * N + n] + a[32 * N + n] * b[14 * N + n] +
                      a[33 * N + n] * b[20 * N + n] + a[34 * N + n] * b[26 * N + n] + a[35 * N + n] * b[32 * N + n];
      c[33 * N + n] = a[30 * N + n] * b[3 * N + n] + a[31 * N + n] * b[9 * N + n] + a[32 * N + n] * b[15 * N + n] +
                      a[33 * N + n] * b[21 * N + n] + a[34 * N + n] * b[27 * N + n] + a[35 * N + n] * b[33 * N + n];
      c[34 * N + n] = a[30 * N + n] * b[4 * N + n] + a[31 * N + n] * b[10 * N + n] + a[32 * N + n] * b[16 * N + n] +
                      a[33 * N + n] * b[22 * N + n] + a[34 * N + n] * b[28 * N + n] + a[35 * N + n] * b[34 * N + n];
      c[35 * N + n] = a[30 * N + n] * b[5 * N + n] + a[31 * N + n] * b[11 * N + n] + a[32 * N + n] * b[17 * N + n] +
                      a[33 * N + n] * b[23 * N + n] + a[34 * N + n] * b[29 * N + n] + a[35 * N + n] * b[35 * N + n];
    }
  }
};
0644
0645 template <typename T, idx_t D, idx_t N>
0646 void multiply(const MPlex<T, D, D, N>& A, const MPlex<T, D, D, N>& B, MPlex<T, D, D, N>& C) {
0647 #ifdef DEBUG
0648 printf("Multipl %d %d\n", D, N);
0649 #endif
0650
0651 MultiplyCls<T, D, N>::multiply(A, B, C);
0652 }
0653
0654
0655
0656
0657
0658 template <typename T, idx_t D, idx_t N>
0659 struct CramerInverter {
0660 static void invert(MPlex<T, D, D, N>& A, double* determ = nullptr) {
0661 throw std::runtime_error("general cramer inversion not supported");
0662 }
0663 };
0664
0665 template <typename T, idx_t N>
0666 struct CramerInverter<T, 2, N> {
0667 static void invert(MPlex<T, 2, 2, N>& A, double* determ = nullptr) {
0668 typedef T TT;
0669
0670 T* a = A.fArray;
0671 ASSUME_ALIGNED(a, 64);
0672
0673 #pragma omp simd
0674 for (idx_t n = 0; n < N; ++n) {
0675
0676 const double det = (double)a[0 * N + n] * a[3 * N + n] - (double)a[2 * N + n] * a[1 * N + n];
0677 if (determ)
0678 determ[n] = det;
0679
0680 const TT s = TT(1) / det;
0681 const TT tmp = s * a[3 * N + n];
0682 a[1 * N + n] *= -s;
0683 a[2 * N + n] *= -s;
0684 a[3 * N + n] = s * a[0 * N + n];
0685 a[0 * N + n] = tmp;
0686 }
0687 }
0688 };
0689
/// 3x3 specialization of CramerInverter: inverts each slot in place as
/// adj(A) / det(A).
///
/// cIJ below is the signed cofactor of element (I, J). For example,
/// c01 = A12*A20 - A10*A22 in (row, col) notation. All nine cofactors are
/// computed from the original values before anything is written, so the
/// in-place update is safe. The determinant is expanded along row 0 in
/// double, and its reciprocal is then rounded to T. If determ is non-null,
/// determ[n] receives the determinant of slot n. A singular slot (det == 0)
/// is not detected; the result then contains inf/NaN.
template <typename T, idx_t N>
struct CramerInverter<T, 3, N> {
  static void invert(MPlex<T, 3, 3, N>& A, double* determ = nullptr) {
    typedef T TT;

    T* a = A.fArray;
    ASSUME_ALIGNED(a, 64);

#pragma omp simd
    for (idx_t n = 0; n < N; ++n) {
      // Flat index k * N + n addresses element (k / 3, k % 3) of slot n.
      const TT c00 = a[4 * N + n] * a[8 * N + n] - a[5 * N + n] * a[7 * N + n];
      const TT c01 = a[5 * N + n] * a[6 * N + n] - a[3 * N + n] * a[8 * N + n];
      const TT c02 = a[3 * N + n] * a[7 * N + n] - a[4 * N + n] * a[6 * N + n];
      const TT c10 = a[7 * N + n] * a[2 * N + n] - a[8 * N + n] * a[1 * N + n];
      const TT c11 = a[8 * N + n] * a[0 * N + n] - a[6 * N + n] * a[2 * N + n];
      const TT c12 = a[6 * N + n] * a[1 * N + n] - a[7 * N + n] * a[0 * N + n];
      const TT c20 = a[1 * N + n] * a[5 * N + n] - a[2 * N + n] * a[4 * N + n];
      const TT c21 = a[2 * N + n] * a[3 * N + n] - a[0 * N + n] * a[5 * N + n];
      const TT c22 = a[0 * N + n] * a[4 * N + n] - a[1 * N + n] * a[3 * N + n];

      // Laplace expansion along row 0, accumulated in double.
      const double det = (double)a[0 * N + n] * c00 + (double)a[1 * N + n] * c01 + (double)a[2 * N + n] * c02;
      if (determ)
        determ[n] = det;

      // Inverse = transpose of the cofactor matrix, scaled by 1/det.
      const TT s = TT(1) / det;
      a[0 * N + n] = s * c00;
      a[1 * N + n] = s * c10;
      a[2 * N + n] = s * c20;
      a[3 * N + n] = s * c01;
      a[4 * N + n] = s * c11;
      a[5 * N + n] = s * c21;
      a[6 * N + n] = s * c02;
      a[7 * N + n] = s * c12;
      a[8 * N + n] = s * c22;
    }
  }
};
0727 };
0728
0729 template <typename T, idx_t D, idx_t N>
0730 void invertCramer(MPlex<T, D, D, N>& A, double* determ = nullptr) {
0731 CramerInverter<T, D, N>::invert(A, determ);
0732 }
0733
0734
0735
0736
0737
0738 template <typename T, idx_t D, idx_t N>
0739 struct CholeskyInverter {
0740 static void invert(MPlex<T, D, D, N>& A) { throw std::runtime_error("general cholesky inversion not supported"); }
0741 };
0742
/// 3x3 specialization of CholeskyInverter: inverts each slot in place,
/// assuming it is symmetric positive definite.
///
/// Only the lower triangle is read: flat indices 0, 3, 4, 6, 7, 8, i.e.
/// (0,0), (1,0), (1,1), (2,0), (2,1), (2,2). The full symmetric inverse is
/// written back. Method: factor A = L L^T, invert the lower-triangular L,
/// then A^-1 = (L^-1)^T L^-1. Positive-definiteness is not checked; a
/// non-positive pivot makes a sqrt argument negative or infinite, which
/// yields NaN/inf.
template <typename T, idx_t N>
struct CholeskyInverter<T, 3, N> {
  static void invert(MPlex<T, 3, 3, N>& A) {
    typedef T TT;

    T* a = A.fArray;
    ASSUME_ALIGNED(a, 64);

#pragma omp simd
    for (idx_t n = 0; n < N; ++n) {
      // Cholesky factor L, keeping reciprocals of its diagonal:
      //   l0 = 1/L00, l1 = L10, l2 = 1/L11, l3 = L20, l4 = L21, l5 = 1/L22.
      TT l0 = std::sqrt(T(1) / a[0 * N + n]);
      TT l1 = a[3 * N + n] * l0;
      TT l2 = a[4 * N + n] - l1 * l1;
      l2 = std::sqrt(T(1) / l2);
      TT l3 = a[6 * N + n] * l0;
      TT l4 = (a[7 * N + n] - l1 * l3) * l2;
      TT l5 = a[8 * N + n] - (l3 * l3 + l4 * l4);
      l5 = std::sqrt(T(1) / l5);

      // Off-diagonal entries of L^-1 (its diagonal is l0, l2, l5).
      // l3 must be updated first: it uses the not-yet-overwritten l1 and l4.
      l3 = (l1 * l4 * l2 - l3) * l0 * l5;
      l1 = -l1 * l0 * l2;
      l4 = -l4 * l2 * l5;

      // A^-1 = (L^-1)^T L^-1, written to both triangles.
      a[0 * N + n] = l3 * l3 + l1 * l1 + l0 * l0;
      a[1 * N + n] = a[3 * N + n] = l3 * l4 + l1 * l2;
      a[4 * N + n] = l4 * l4 + l2 * l2;
      a[2 * N + n] = a[6 * N + n] = l3 * l5;
      a[5 * N + n] = a[7 * N + n] = l4 * l5;
      a[8 * N + n] = l5 * l5;
    }
  }
};
0781 };
0782
0783 template <typename T, idx_t D, idx_t N>
0784 void invertCholesky(MPlex<T, D, D, N>& A) {
0785 CholeskyInverter<T, D, N>::invert(A);
0786 }
0787
0788 }
0789
0790 #endif