File indexing completed on 2024-10-17 22:59:03
0001 #ifndef RecoTracker_MkFitCore_src_Matriplex_Matriplex_h
0002 #define RecoTracker_MkFitCore_src_Matriplex_Matriplex_h
0003
0004 #include "MatriplexCommon.h"
0005
0006 #ifdef MPLEX_VDT
0007
0008 #ifdef MPLEX_VDT_USE_STD
0009
namespace std {
  // Fallback helpers used when MPLEX_VDT_USE_STD is defined: the VDT_INVOKE
  // macro then expands to std::<func>(...), so these two names must exist.
  // NOTE(review): declaring new templates inside namespace std is formally
  // undefined behaviour per the standard ([namespace.std]).

  /// Reciprocal square root, 1 / sqrt(x), via the standard library.
  template <typename T>
  T isqrt(T x) {
    const auto root = std::sqrt(x);
    return T(1.0) / root;
  }

  /// Writes sin(a) into s and cos(a) into c (s is assigned before c).
  template <typename T>
  void sincos(T a, T& s, T& c) {
    const auto sine = std::sin(a);
    const auto cosine = std::cos(a);
    s = sine;
    c = cosine;
  }
}  // namespace std
0021 #else
0022 #include "vdt/sqrt.h"
0023 #include "vdt/sin.h"
0024 #include "vdt/cos.h"
0025 #include "vdt/tan.h"
0026 #include "vdt/atan2.h"
0027 #endif
0028 #endif
0029
0030 namespace Matriplex {
0031
0032
0033
0034 template <typename T, idx_t D1, idx_t D2, idx_t N>
0035 class __attribute__((aligned(MPLEX_ALIGN))) Matriplex {
0036 public:
0037 typedef T value_type;
0038
0039
0040 static constexpr int kRows = D1;
0041
0042 static constexpr int kCols = D2;
0043
0044 static constexpr int kSize = D1 * D2;
0045
0046 static constexpr int kTotSize = N * kSize;
0047
0048 T fArray[kTotSize];
0049
0050 Matriplex() {}
0051 Matriplex(T v) { setVal(v); }
0052
0053 idx_t plexSize() const { return N; }
0054
0055 void setVal(T v) {
0056 for (idx_t i = 0; i < kTotSize; ++i) {
0057 fArray[i] = v;
0058 }
0059 }
0060
0061 void add(const Matriplex& v) {
0062 for (idx_t i = 0; i < kTotSize; ++i) {
0063 fArray[i] += v.fArray[i];
0064 }
0065 }
0066
0067 void scale(T scale) {
0068 for (idx_t i = 0; i < kTotSize; ++i) {
0069 fArray[i] *= scale;
0070 }
0071 }
0072
0073 Matriplex& negate() {
0074 for (idx_t i = 0; i < kTotSize; ++i) {
0075 fArray[i] = -fArray[i];
0076 }
0077 return *this;
0078 }
0079
0080 template <typename TT>
0081 Matriplex& negate_if_ltz(const Matriplex<TT, D1, D2, N>& sign) {
0082 for (idx_t i = 0; i < kTotSize; ++i) {
0083 if (sign.fArray[i] < 0)
0084 fArray[i] = -fArray[i];
0085 }
0086 return *this;
0087 }
0088
0089 T operator[](idx_t xx) const { return fArray[xx]; }
0090 T& operator[](idx_t xx) { return fArray[xx]; }
0091
0092 const T& constAt(idx_t n, idx_t i, idx_t j) const { return fArray[(i * D2 + j) * N + n]; }
0093
0094 T& At(idx_t n, idx_t i, idx_t j) { return fArray[(i * D2 + j) * N + n]; }
0095
0096 T& operator()(idx_t n, idx_t i, idx_t j) { return fArray[(i * D2 + j) * N + n]; }
0097 const T& operator()(idx_t n, idx_t i, idx_t j) const { return fArray[(i * D2 + j) * N + n]; }
0098
0099
0100
0101 using QReduced = Matriplex<T, 1, 1, N>;
0102
0103 QReduced ReduceFixedIJ(idx_t i, idx_t j) const {
0104 QReduced t;
0105 for (idx_t n = 0; n < N; ++n) {
0106 t[n] = constAt(n, i, j);
0107 }
0108 return t;
0109 }
0110 QReduced rij(idx_t i, idx_t j) const { return ReduceFixedIJ(i, j); }
0111 QReduced operator()(idx_t i, idx_t j) const { return ReduceFixedIJ(i, j); }
0112
0113 struct QAssigner {
0114 Matriplex& m_matriplex;
0115 const int m_i, m_j;
0116
0117 QAssigner(Matriplex& m, int i, int j) : m_matriplex(m), m_i(i), m_j(j) {}
0118 Matriplex& operator=(const QReduced& qvec) {
0119 for (idx_t n = 0; n < N; ++n) {
0120 m_matriplex(n, m_i, m_j) = qvec[n];
0121 }
0122 return m_matriplex;
0123 }
0124 Matriplex& operator=(T qscalar) {
0125 for (idx_t n = 0; n < N; ++n) {
0126 m_matriplex(n, m_i, m_j) = qscalar;
0127 }
0128 return m_matriplex;
0129 }
0130 };
0131
0132 QAssigner AssignFixedIJ(idx_t i, idx_t j) { return QAssigner(*this, i, j); }
0133 QAssigner aij(idx_t i, idx_t j) { return AssignFixedIJ(i, j); }
0134
0135
0136
0137 Matriplex& operator=(T t) {
0138 for (idx_t i = 0; i < kTotSize; ++i)
0139 fArray[i] = t;
0140 return *this;
0141 }
0142
0143 Matriplex& operator+=(T t) {
0144 for (idx_t i = 0; i < kTotSize; ++i)
0145 fArray[i] += t;
0146 return *this;
0147 }
0148
0149 Matriplex& operator-=(T t) {
0150 for (idx_t i = 0; i < kTotSize; ++i)
0151 fArray[i] -= t;
0152 return *this;
0153 }
0154
0155 Matriplex& operator*=(T t) {
0156 for (idx_t i = 0; i < kTotSize; ++i)
0157 fArray[i] *= t;
0158 return *this;
0159 }
0160
0161 Matriplex& operator/=(T t) {
0162 for (idx_t i = 0; i < kTotSize; ++i)
0163 fArray[i] /= t;
0164 return *this;
0165 }
0166
0167 Matriplex& operator+=(const Matriplex& a) {
0168 for (idx_t i = 0; i < kTotSize; ++i)
0169 fArray[i] += a.fArray[i];
0170 return *this;
0171 }
0172
0173 Matriplex& operator-=(const Matriplex& a) {
0174 for (idx_t i = 0; i < kTotSize; ++i)
0175 fArray[i] -= a.fArray[i];
0176 return *this;
0177 }
0178
0179 Matriplex& operator*=(const Matriplex& a) {
0180 for (idx_t i = 0; i < kTotSize; ++i)
0181 fArray[i] *= a.fArray[i];
0182 return *this;
0183 }
0184
0185 Matriplex& operator/=(const Matriplex& a) {
0186 for (idx_t i = 0; i < kTotSize; ++i)
0187 fArray[i] /= a.fArray[i];
0188 return *this;
0189 }
0190
0191 Matriplex operator-() {
0192 Matriplex t;
0193 for (idx_t i = 0; i < kTotSize; ++i)
0194 t.fArray[i] = -fArray[i];
0195 return t;
0196 }
0197
0198 Matriplex& abs(const Matriplex& a) {
0199 for (idx_t i = 0; i < kTotSize; ++i)
0200 fArray[i] = std::abs(a.fArray[i]);
0201 return *this;
0202 }
0203 Matriplex& abs() {
0204 for (idx_t i = 0; i < kTotSize; ++i)
0205 fArray[i] = std::abs(fArray[i]);
0206 return *this;
0207 }
0208
0209 Matriplex& sqr(const Matriplex& a) {
0210 for (idx_t i = 0; i < kTotSize; ++i)
0211 fArray[i] = a.fArray[i] * a.fArray[i];
0212 return *this;
0213 }
0214 Matriplex& sqr() {
0215 for (idx_t i = 0; i < kTotSize; ++i)
0216 fArray[i] = fArray[i] * fArray[i];
0217 return *this;
0218 }
0219
0220
0221
0222
0223 Matriplex& sqrt(const Matriplex& a) {
0224 for (idx_t i = 0; i < kTotSize; ++i)
0225 fArray[i] = std::sqrt(a.fArray[i]);
0226 return *this;
0227 }
0228 Matriplex& sqrt() {
0229 for (idx_t i = 0; i < kTotSize; ++i)
0230 fArray[i] = std::sqrt(fArray[i]);
0231 return *this;
0232 }
0233
0234 Matriplex& hypot(const Matriplex& a, const Matriplex& b) {
0235 for (idx_t i = 0; i < kTotSize; ++i) {
0236 fArray[i] = a.fArray[i] * a.fArray[i] + b.fArray[i] * b.fArray[i];
0237 }
0238 return sqrt();
0239 }
0240
0241 Matriplex& sin(const Matriplex& a) {
0242 for (idx_t i = 0; i < kTotSize; ++i)
0243 fArray[i] = std::sin(a.fArray[i]);
0244 return *this;
0245 }
0246 Matriplex& sin() {
0247 for (idx_t i = 0; i < kTotSize; ++i)
0248 fArray[i] = std::sin(fArray[i]);
0249 return *this;
0250 }
0251
0252 Matriplex& cos(const Matriplex& a) {
0253 for (idx_t i = 0; i < kTotSize; ++i)
0254 fArray[i] = std::cos(a.fArray[i]);
0255 return *this;
0256 }
0257 Matriplex& cos() {
0258 for (idx_t i = 0; i < kTotSize; ++i)
0259 fArray[i] = std::cos(fArray[i]);
0260 return *this;
0261 }
0262
0263 Matriplex& tan(const Matriplex& a) {
0264 for (idx_t i = 0; i < kTotSize; ++i)
0265 fArray[i] = std::tan(a.fArray[i]);
0266 return *this;
0267 }
0268 Matriplex& tan() {
0269 for (idx_t i = 0; i < kTotSize; ++i)
0270 fArray[i] = std::tan(fArray[i]);
0271 return *this;
0272 }
0273
0274 Matriplex& atan2(const Matriplex& y, const Matriplex& x) {
0275 for (idx_t i = 0; i < kTotSize; ++i)
0276 fArray[i] = std::atan2(y.fArray[i], x.fArray[i]);
0277 return *this;
0278 }
0279
0280
0281
0282
0283 #ifdef MPLEX_VDT
0284
0285 #define ASS fArray[i] =
0286 #define ARR fArray[i]
0287 #define A_ARR a.fArray[i]
0288
0289 #ifdef MPLEX_VDT_USE_STD
0290 #define VDT_INVOKE(_ass_, _func_, ...) \
0291 for (idx_t i = 0; i < kTotSize; ++i) \
0292 _ass_ std::_func_(__VA_ARGS__);
0293 #else
0294 #define VDT_INVOKE(_ass_, _func_, ...) \
0295 for (idx_t i = 0; i < kTotSize; ++i) \
0296 if constexpr (std::is_same<T, float>()) \
0297 _ass_ vdt::fast_##_func_##f(__VA_ARGS__); \
0298 else \
0299 _ass_ vdt::fast_##_func_(__VA_ARGS__);
0300 #endif
0301
0302 Matriplex& fast_isqrt(const Matriplex& a) {
0303 VDT_INVOKE(ASS, isqrt, A_ARR);
0304 return *this;
0305 }
0306 Matriplex& fast_isqrt() {
0307 VDT_INVOKE(ASS, isqrt, ARR);
0308 return *this;
0309 }
0310
0311 Matriplex& fast_sin(const Matriplex& a) {
0312 VDT_INVOKE(ASS, sin, A_ARR);
0313 return *this;
0314 }
0315 Matriplex& fast_sin() {
0316 VDT_INVOKE(ASS, sin, ARR);
0317 return *this;
0318 }
0319
0320 Matriplex& fast_cos(const Matriplex& a) {
0321 VDT_INVOKE(ASS, cos, A_ARR);
0322 return *this;
0323 }
0324 Matriplex& fast_cos() {
0325 VDT_INVOKE(ASS, cos, ARR);
0326 return *this;
0327 }
0328
0329 void fast_sincos(Matriplex& s, Matriplex& c) const { VDT_INVOKE(, sincos, ARR, s.fArray[i], c.fArray[i]); }
0330
0331 Matriplex& fast_tan(const Matriplex& a) {
0332 VDT_INVOKE(ASS, tan, A_ARR);
0333 return *this;
0334 }
0335 Matriplex& fast_tan() {
0336 VDT_INVOKE(ASS, tan, ARR);
0337 return *this;
0338 }
0339
0340 Matriplex& fast_atan2(const Matriplex& y, const Matriplex& x) {
0341 VDT_INVOKE(ASS, atan2, y.fArray[i], x.fArray[i]);
0342 return *this;
0343 }
0344
0345 #undef VDT_INVOKE
0346
0347 #undef ASS
0348 #undef ARR
0349 #undef A_ARR
0350 #endif
0351
0352 void sincos4(Matriplex& s, Matriplex& c) const {
0353 for (idx_t i = 0; i < kTotSize; ++i)
0354 internal::sincos4(fArray[i], s.fArray[i], c.fArray[i]);
0355 }
0356
0357
0358
0359 void copySlot(idx_t n, const Matriplex& m) {
0360 for (idx_t i = n; i < kTotSize; i += N) {
0361 fArray[i] = m.fArray[i];
0362 }
0363 }
0364
0365 void copyIn(idx_t n, const T* arr) {
0366 for (idx_t i = n; i < kTotSize; i += N) {
0367 fArray[i] = *(arr++);
0368 }
0369 }
0370
0371 void copyIn(idx_t n, const Matriplex& m, idx_t in) {
0372 for (idx_t i = n; i < kTotSize; i += N, in += N) {
0373 fArray[i] = m[in];
0374 }
0375 }
0376
0377 void copy(idx_t n, idx_t in) {
0378 for (idx_t i = n; i < kTotSize; i += N, in += N) {
0379 fArray[i] = fArray[in];
0380 }
0381 }
0382
0383 #if defined(AVX512_INTRINSICS)
0384
0385 template <typename U>
0386 void slurpIn(const T* arr, __m512i& vi, const U&, const int N_proc = N) {
0387
0388
0389 const __m512 src = {0};
0390 const __mmask16 k = N_proc == N ? -1 : (1 << N_proc) - 1;
0391
0392 for (int i = 0; i < kSize; ++i, ++arr) {
0393
0394
0395 __m512 reg = _mm512_mask_i32gather_ps(src, k, vi, arr, sizeof(U));
0396 _mm512_mask_store_ps(&fArray[i * N], k, reg);
0397 }
0398 }
0399
0400
0401
0402 void ChewIn(const char* arr, int off, int vi[N], const char* tmp, __m512i& ui) {
0403
0404
0405 for (int i = 0; i < N; ++i) {
0406 __m512 reg = _mm512_load_ps(arr + vi[i]);
0407 _mm512_store_ps((void*)(tmp + 64 * i), reg);
0408 }
0409
0410 for (int i = 0; i < kSize; ++i) {
0411 __m512 reg = _mm512_i32gather_ps(ui, tmp + off + i * sizeof(T), 1);
0412 _mm512_store_ps(&fArray[i * N], reg);
0413 }
0414 }
0415
0416 void Contaginate(const char* arr, int vi[N], const char* tmp) {
0417
0418
0419 for (int i = 0; i < N; ++i) {
0420 __m512 reg = _mm512_load_ps(arr + vi[i]);
0421 _mm512_store_ps((void*)(tmp + 64 * i), reg);
0422 }
0423 }
0424
0425 void Plexify(const char* tmp, __m512i& ui) {
0426 for (int i = 0; i < kSize; ++i) {
0427 __m512 reg = _mm512_i32gather_ps(ui, tmp + i * sizeof(T), 1);
0428 _mm512_store_ps(&fArray[i * N], reg);
0429 }
0430 }
0431
0432 #elif defined(AVX2_INTRINSICS)
0433
0434 template <typename U>
0435 void slurpIn(const T* arr, __m256i& vi, const U&, const int N_proc = N) {
0436
0437
0438
0439 const __m256 src = {0};
0440
0441 __m256i k = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
0442 __m256i k_sel = _mm256_set1_epi32(N_proc);
0443 __m256i k_master = _mm256_cmpgt_epi32(k_sel, k);
0444
0445 k = k_master;
0446 for (int i = 0; i < kSize; ++i, ++arr) {
0447 __m256 reg = _mm256_mask_i32gather_ps(src, (float*)arr, vi, (__m256)k, sizeof(U));
0448
0449 k = k_master;
0450 _mm256_maskstore_ps((float*)&fArray[i * N], k, reg);
0451 }
0452 }
0453
0454 #else
0455
0456 void slurpIn(const T* arr, int vi[N], const int N_proc = N) {
0457
0458 if (N_proc == N) {
0459 for (int i = 0; i < kSize; ++i) {
0460 for (int j = 0; j < N; ++j) {
0461 fArray[i * N + j] = *(arr + i + vi[j]);
0462 }
0463 }
0464 } else {
0465 for (int i = 0; i < kSize; ++i) {
0466 for (int j = 0; j < N_proc; ++j) {
0467 fArray[i * N + j] = *(arr + i + vi[j]);
0468 }
0469 }
0470 }
0471 }
0472
0473 #endif
0474
0475 void copyOut(idx_t n, T* arr) const {
0476 for (idx_t i = n; i < kTotSize; i += N) {
0477 *(arr++) = fArray[i];
0478 }
0479 }
0480 };
0481
0482 template <typename T, idx_t D1, idx_t D2, idx_t N>
0483 using MPlex = Matriplex<T, D1, D2, N>;
0484
0485
0486
0487
0488
0489 template <typename T, idx_t D1, idx_t D2, idx_t N>
0490 MPlex<T, D1, D2, N> operator-(const MPlex<T, D1, D2, N>& a) {
0491 MPlex<T, D1, D2, N> t = a;
0492 t.negate();
0493 return t;
0494 }
0495
0496 template <typename T, idx_t D1, idx_t D2, idx_t N>
0497 MPlex<T, D1, D2, N> negate(const MPlex<T, D1, D2, N>& a) {
0498 MPlex<T, D1, D2, N> t = a;
0499 t.negate();
0500 return t;
0501 }
0502
0503 template <typename T, typename TT, idx_t D1, idx_t D2, idx_t N>
0504 MPlex<T, D1, D2, N> negate_if_ltz(const MPlex<T, D1, D2, N>& a, const MPlex<TT, D1, D2, N>& sign) {
0505 MPlex<T, D1, D2, N> t = a;
0506 t.negate_if_ltz(sign);
0507 return t;
0508 }
0509
0510 template <typename T, idx_t D1, idx_t D2, idx_t N>
0511 MPlex<T, D1, D2, N> operator+(const MPlex<T, D1, D2, N>& a, const MPlex<T, D1, D2, N>& b) {
0512 MPlex<T, D1, D2, N> t = a;
0513 t += b;
0514 return t;
0515 }
0516
0517 template <typename T, idx_t D1, idx_t D2, idx_t N>
0518 MPlex<T, D1, D2, N> operator-(const MPlex<T, D1, D2, N>& a, const MPlex<T, D1, D2, N>& b) {
0519 MPlex<T, D1, D2, N> t = a;
0520 t -= b;
0521 return t;
0522 }
0523
0524 template <typename T, idx_t D1, idx_t D2, idx_t N>
0525 MPlex<T, D1, D2, N> operator*(const MPlex<T, D1, D2, N>& a, const MPlex<T, D1, D2, N>& b) {
0526 MPlex<T, D1, D2, N> t = a;
0527 t *= b;
0528 return t;
0529 }
0530
0531 template <typename T, idx_t D1, idx_t D2, idx_t N>
0532 MPlex<T, D1, D2, N> operator/(const MPlex<T, D1, D2, N>& a, const MPlex<T, D1, D2, N>& b) {
0533 MPlex<T, D1, D2, N> t = a;
0534 t /= b;
0535 return t;
0536 }
0537
0538 template <typename T, idx_t D1, idx_t D2, idx_t N>
0539 MPlex<T, D1, D2, N> operator+(const MPlex<T, D1, D2, N>& a, T b) {
0540 MPlex<T, D1, D2, N> t = a;
0541 t += b;
0542 return t;
0543 }
0544
0545 template <typename T, idx_t D1, idx_t D2, idx_t N>
0546 MPlex<T, D1, D2, N> operator-(const MPlex<T, D1, D2, N>& a, T b) {
0547 MPlex<T, D1, D2, N> t = a;
0548 t -= b;
0549 return t;
0550 }
0551
0552 template <typename T, idx_t D1, idx_t D2, idx_t N>
0553 MPlex<T, D1, D2, N> operator*(const MPlex<T, D1, D2, N>& a, T b) {
0554 MPlex<T, D1, D2, N> t = a;
0555 t *= b;
0556 return t;
0557 }
0558
0559 template <typename T, idx_t D1, idx_t D2, idx_t N>
0560 MPlex<T, D1, D2, N> operator/(const MPlex<T, D1, D2, N>& a, T b) {
0561 MPlex<T, D1, D2, N> t = a;
0562 t /= b;
0563 return t;
0564 }
0565
0566 template <typename T, idx_t D1, idx_t D2, idx_t N>
0567 MPlex<T, D1, D2, N> operator+(T a, const MPlex<T, D1, D2, N>& b) {
0568 MPlex<T, D1, D2, N> t = a;
0569 t += b;
0570 return t;
0571 }
0572
0573 template <typename T, idx_t D1, idx_t D2, idx_t N>
0574 MPlex<T, D1, D2, N> operator-(T a, const MPlex<T, D1, D2, N>& b) {
0575 MPlex<T, D1, D2, N> t = a;
0576 t -= b;
0577 return t;
0578 }
0579
0580 template <typename T, idx_t D1, idx_t D2, idx_t N>
0581 MPlex<T, D1, D2, N> operator*(T a, const MPlex<T, D1, D2, N>& b) {
0582 MPlex<T, D1, D2, N> t = a;
0583 t *= b;
0584 return t;
0585 }
0586
0587 template <typename T, idx_t D1, idx_t D2, idx_t N>
0588 MPlex<T, D1, D2, N> operator/(T a, const MPlex<T, D1, D2, N>& b) {
0589 MPlex<T, D1, D2, N> t = a;
0590 t /= b;
0591 return t;
0592 }
0593
0594 template <typename T, idx_t D1, idx_t D2, idx_t N>
0595 MPlex<T, D1, D2, N> abs(const MPlex<T, D1, D2, N>& a) {
0596 MPlex<T, D1, D2, N> t;
0597 return t.abs(a);
0598 }
0599
0600 template <typename T, idx_t D1, idx_t D2, idx_t N>
0601 MPlex<T, D1, D2, N> sqr(const MPlex<T, D1, D2, N>& a) {
0602 MPlex<T, D1, D2, N> t;
0603 return t.sqr(a);
0604 }
0605
0606
0607
0608
0609 template <typename T, idx_t D1, idx_t D2, idx_t N>
0610 MPlex<T, D1, D2, N> sqrt(const MPlex<T, D1, D2, N>& a) {
0611 MPlex<T, D1, D2, N> t;
0612 return t.sqrt(a);
0613 }
0614
0615 template <typename T, idx_t D1, idx_t D2, idx_t N>
0616 MPlex<T, D1, D2, N> hypot(const MPlex<T, D1, D2, N>& a, const MPlex<T, D1, D2, N>& b) {
0617 MPlex<T, D1, D2, N> t;
0618 return t.hypot(a, b);
0619 }
0620
0621 template <typename T, idx_t D1, idx_t D2, idx_t N>
0622 MPlex<T, D1, D2, N> sin(const MPlex<T, D1, D2, N>& a) {
0623 MPlex<T, D1, D2, N> t;
0624 return t.sin(a);
0625 }
0626
0627 template <typename T, idx_t D1, idx_t D2, idx_t N>
0628 MPlex<T, D1, D2, N> cos(const MPlex<T, D1, D2, N>& a) {
0629 MPlex<T, D1, D2, N> t;
0630 return t.cos(a);
0631 }
0632
0633 template <typename T, idx_t D1, idx_t D2, idx_t N>
0634 void sincos(const MPlex<T, D1, D2, N>& a, MPlex<T, D1, D2, N>& s, MPlex<T, D1, D2, N>& c) {
0635 for (idx_t i = 0; i < a.kTotSize; ++i) {
0636 s.fArray[i] = std::sin(a.fArray[i]);
0637 c.fArray[i] = std::cos(a.fArray[i]);
0638 }
0639 }
0640
0641 template <typename T, idx_t D1, idx_t D2, idx_t N>
0642 MPlex<T, D1, D2, N> tan(const MPlex<T, D1, D2, N>& a) {
0643 MPlex<T, D1, D2, N> t;
0644 return t.tan(a);
0645 }
0646
0647 template <typename T, idx_t D1, idx_t D2, idx_t N>
0648 MPlex<T, D1, D2, N> atan2(const MPlex<T, D1, D2, N>& y, const MPlex<T, D1, D2, N>& x) {
0649 MPlex<T, D1, D2, N> t;
0650 return t.atan2(y, x);
0651 }
0652
0653
0654
0655
#ifdef MPLEX_VDT
// Free-function wrappers over the Matriplex::fast_* members (VDT or std
// back-end depending on MPLEX_VDT_USE_STD). Each returns a new plex.

template <typename T, idx_t D1, idx_t D2, idx_t N>
MPlex<T, D1, D2, N> fast_isqrt(const MPlex<T, D1, D2, N>& a) {
  MPlex<T, D1, D2, N> result;
  result.fast_isqrt(a);
  return result;
}

template <typename T, idx_t D1, idx_t D2, idx_t N>
MPlex<T, D1, D2, N> fast_sin(const MPlex<T, D1, D2, N>& a) {
  MPlex<T, D1, D2, N> result;
  result.fast_sin(a);
  return result;
}

template <typename T, idx_t D1, idx_t D2, idx_t N>
MPlex<T, D1, D2, N> fast_cos(const MPlex<T, D1, D2, N>& a) {
  MPlex<T, D1, D2, N> result;
  result.fast_cos(a);
  return result;
}

/// Writes the fast sine/cosine of every element of a into s and c.
template <typename T, idx_t D1, idx_t D2, idx_t N>
void fast_sincos(const MPlex<T, D1, D2, N>& a, MPlex<T, D1, D2, N>& s, MPlex<T, D1, D2, N>& c) {
  return a.fast_sincos(s, c);
}

template <typename T, idx_t D1, idx_t D2, idx_t N>
MPlex<T, D1, D2, N> fast_tan(const MPlex<T, D1, D2, N>& a) {
  MPlex<T, D1, D2, N> result;
  result.fast_tan(a);
  return result;
}

template <typename T, idx_t D1, idx_t D2, idx_t N>
MPlex<T, D1, D2, N> fast_atan2(const MPlex<T, D1, D2, N>& y, const MPlex<T, D1, D2, N>& x) {
  MPlex<T, D1, D2, N> result;
  result.fast_atan2(y, x);
  return result;
}

#endif
0694
0695 template <typename T, idx_t D1, idx_t D2, idx_t N>
0696 void sincos4(const MPlex<T, D1, D2, N>& a, MPlex<T, D1, D2, N>& s, MPlex<T, D1, D2, N>& c) {
0697 a.sincos4(s, c);
0698 }
0699
0700
0701
0702 template <typename T, idx_t D1, idx_t D2, idx_t N>
0703 void min_max(const MPlex<T, D1, D2, N>& a,
0704 const MPlex<T, D1, D2, N>& b,
0705 MPlex<T, D1, D2, N>& min,
0706 MPlex<T, D1, D2, N>& max) {
0707 for (idx_t i = 0; i < a.kTotSize; ++i) {
0708 min.fArray[i] = std::min(a.fArray[i], b.fArray[i]);
0709 max.fArray[i] = std::max(a.fArray[i], b.fArray[i]);
0710 }
0711 }
0712
0713 template <typename T, idx_t D1, idx_t D2, idx_t N>
0714 MPlex<T, D1, D2, N> min(const MPlex<T, D1, D2, N>& a, const MPlex<T, D1, D2, N>& b) {
0715 MPlex<T, D1, D2, N> t;
0716 for (idx_t i = 0; i < a.kTotSize; ++i) {
0717 t.fArray[i] = std::min(a.fArray[i], b.fArray[i]);
0718 }
0719 return t;
0720 }
0721
0722 template <typename T, idx_t D1, idx_t D2, idx_t N>
0723 MPlex<T, D1, D2, N> max(const MPlex<T, D1, D2, N>& a, const MPlex<T, D1, D2, N>& b) {
0724 MPlex<T, D1, D2, N> t;
0725 for (idx_t i = 0; i < a.kTotSize; ++i) {
0726 t.fArray[i] = std::max(a.fArray[i], b.fArray[i]);
0727 }
0728 return t;
0729 }
0730
0731
0732
0733
0734
0735 template <typename T, idx_t D1, idx_t D2, idx_t D3, idx_t N>
0736 void multiplyGeneral(const MPlex<T, D1, D2, N>& A, const MPlex<T, D2, D3, N>& B, MPlex<T, D1, D3, N>& C) {
0737 for (idx_t i = 0; i < D1; ++i) {
0738 for (idx_t j = 0; j < D3; ++j) {
0739 const idx_t ijo = N * (i * D3 + j);
0740
0741 #pragma omp simd
0742 for (idx_t n = 0; n < N; ++n) {
0743 C.fArray[ijo + n] = 0;
0744 }
0745
0746 for (idx_t k = 0; k < D2; ++k) {
0747 const idx_t iko = N * (i * D2 + k);
0748 const idx_t kjo = N * (k * D3 + j);
0749
0750 #pragma omp simd
0751 for (idx_t n = 0; n < N; ++n) {
0752 C.fArray[ijo + n] += A.fArray[iko + n] * B.fArray[kjo + n];
0753 }
0754 }
0755 }
0756 }
0757 }
0758
0759
0760
0761 template <typename T, idx_t D, idx_t N>
0762 struct MultiplyCls {
0763 static void multiply(const MPlex<T, D, D, N>& A, const MPlex<T, D, D, N>& B, MPlex<T, D, D, N>& C) {
0764 throw std::runtime_error("general multiplication not supported, well, call multiplyGeneral()");
0765 }
0766 };
0767
0768 template <typename T, idx_t N>
0769 struct MultiplyCls<T, 3, N> {
0770 static void multiply(const MPlex<T, 3, 3, N>& A, const MPlex<T, 3, 3, N>& B, MPlex<T, 3, 3, N>& C) {
0771 const T* a = A.fArray;
0772 ASSUME_ALIGNED(a, 64);
0773 const T* b = B.fArray;
0774 ASSUME_ALIGNED(b, 64);
0775 T* c = C.fArray;
0776 ASSUME_ALIGNED(c, 64);
0777
0778 #pragma omp simd
0779 for (idx_t n = 0; n < N; ++n) {
0780 c[0 * N + n] = a[0 * N + n] * b[0 * N + n] + a[1 * N + n] * b[3 * N + n] + a[2 * N + n] * b[6 * N + n];
0781 c[1 * N + n] = a[0 * N + n] * b[1 * N + n] + a[1 * N + n] * b[4 * N + n] + a[2 * N + n] * b[7 * N + n];
0782 c[2 * N + n] = a[0 * N + n] * b[2 * N + n] + a[1 * N + n] * b[5 * N + n] + a[2 * N + n] * b[8 * N + n];
0783 c[3 * N + n] = a[3 * N + n] * b[0 * N + n] + a[4 * N + n] * b[3 * N + n] + a[5 * N + n] * b[6 * N + n];
0784 c[4 * N + n] = a[3 * N + n] * b[1 * N + n] + a[4 * N + n] * b[4 * N + n] + a[5 * N + n] * b[7 * N + n];
0785 c[5 * N + n] = a[3 * N + n] * b[2 * N + n] + a[4 * N + n] * b[5 * N + n] + a[5 * N + n] * b[8 * N + n];
0786 c[6 * N + n] = a[6 * N + n] * b[0 * N + n] + a[7 * N + n] * b[3 * N + n] + a[8 * N + n] * b[6 * N + n];
0787 c[7 * N + n] = a[6 * N + n] * b[1 * N + n] + a[7 * N + n] * b[4 * N + n] + a[8 * N + n] * b[7 * N + n];
0788 c[8 * N + n] = a[6 * N + n] * b[2 * N + n] + a[7 * N + n] * b[5 * N + n] + a[8 * N + n] * b[8 * N + n];
0789 }
0790 }
0791 };
0792
template <typename T, idx_t N>
struct MultiplyCls<T, 6, N> {
  // C = A * B for N 6x6 matrices: c[i][j] = sum_k a[i][k] * b[k][j], where
  // element (i, j) of slot n is stored at index (6 * i + j) * N + n.
  // Written fully unrolled; the only loop is the one over n, marked
  // `#pragma omp simd`.
  // NOTE(review): C is stored element by element while A and B are still
  // being read (c[0] is written before a[0] is read again for c[1]), so C must
  // not be the same object as A or B.
  static void multiply(const MPlex<T, 6, 6, N>& A, const MPlex<T, 6, 6, N>& B, MPlex<T, 6, 6, N>& C) {
    const T* a = A.fArray;
    ASSUME_ALIGNED(a, 64);
    const T* b = B.fArray;
    ASSUME_ALIGNED(b, 64);
    T* c = C.fArray;
    ASSUME_ALIGNED(c, 64);
#pragma omp simd
    for (idx_t n = 0; n < N; ++n) {
      // Row 0 of C.
      c[0 * N + n] = a[0 * N + n] * b[0 * N + n] + a[1 * N + n] * b[6 * N + n] + a[2 * N + n] * b[12 * N + n] +
                     a[3 * N + n] * b[18 * N + n] + a[4 * N + n] * b[24 * N + n] + a[5 * N + n] * b[30 * N + n];
      c[1 * N + n] = a[0 * N + n] * b[1 * N + n] + a[1 * N + n] * b[7 * N + n] + a[2 * N + n] * b[13 * N + n] +
                     a[3 * N + n] * b[19 * N + n] + a[4 * N + n] * b[25 * N + n] + a[5 * N + n] * b[31 * N + n];
      c[2 * N + n] = a[0 * N + n] * b[2 * N + n] + a[1 * N + n] * b[8 * N + n] + a[2 * N + n] * b[14 * N + n] +
                     a[3 * N + n] * b[20 * N + n] + a[4 * N + n] * b[26 * N + n] + a[5 * N + n] * b[32 * N + n];
      c[3 * N + n] = a[0 * N + n] * b[3 * N + n] + a[1 * N + n] * b[9 * N + n] + a[2 * N + n] * b[15 * N + n] +
                     a[3 * N + n] * b[21 * N + n] + a[4 * N + n] * b[27 * N + n] + a[5 * N + n] * b[33 * N + n];
      c[4 * N + n] = a[0 * N + n] * b[4 * N + n] + a[1 * N + n] * b[10 * N + n] + a[2 * N + n] * b[16 * N + n] +
                     a[3 * N + n] * b[22 * N + n] + a[4 * N + n] * b[28 * N + n] + a[5 * N + n] * b[34 * N + n];
      c[5 * N + n] = a[0 * N + n] * b[5 * N + n] + a[1 * N + n] * b[11 * N + n] + a[2 * N + n] * b[17 * N + n] +
                     a[3 * N + n] * b[23 * N + n] + a[4 * N + n] * b[29 * N + n] + a[5 * N + n] * b[35 * N + n];
      // Row 1 of C.
      c[6 * N + n] = a[6 * N + n] * b[0 * N + n] + a[7 * N + n] * b[6 * N + n] + a[8 * N + n] * b[12 * N + n] +
                     a[9 * N + n] * b[18 * N + n] + a[10 * N + n] * b[24 * N + n] + a[11 * N + n] * b[30 * N + n];
      c[7 * N + n] = a[6 * N + n] * b[1 * N + n] + a[7 * N + n] * b[7 * N + n] + a[8 * N + n] * b[13 * N + n] +
                     a[9 * N + n] * b[19 * N + n] + a[10 * N + n] * b[25 * N + n] + a[11 * N + n] * b[31 * N + n];
      c[8 * N + n] = a[6 * N + n] * b[2 * N + n] + a[7 * N + n] * b[8 * N + n] + a[8 * N + n] * b[14 * N + n] +
                     a[9 * N + n] * b[20 * N + n] + a[10 * N + n] * b[26 * N + n] + a[11 * N + n] * b[32 * N + n];
      c[9 * N + n] = a[6 * N + n] * b[3 * N + n] + a[7 * N + n] * b[9 * N + n] + a[8 * N + n] * b[15 * N + n] +
                     a[9 * N + n] * b[21 * N + n] + a[10 * N + n] * b[27 * N + n] + a[11 * N + n] * b[33 * N + n];
      c[10 * N + n] = a[6 * N + n] * b[4 * N + n] + a[7 * N + n] * b[10 * N + n] + a[8 * N + n] * b[16 * N + n] +
                      a[9 * N + n] * b[22 * N + n] + a[10 * N + n] * b[28 * N + n] + a[11 * N + n] * b[34 * N + n];
      c[11 * N + n] = a[6 * N + n] * b[5 * N + n] + a[7 * N + n] * b[11 * N + n] + a[8 * N + n] * b[17 * N + n] +
                      a[9 * N + n] * b[23 * N + n] + a[10 * N + n] * b[29 * N + n] + a[11 * N + n] * b[35 * N + n];
      // Row 2 of C.
      c[12 * N + n] = a[12 * N + n] * b[0 * N + n] + a[13 * N + n] * b[6 * N + n] + a[14 * N + n] * b[12 * N + n] +
                      a[15 * N + n] * b[18 * N + n] + a[16 * N + n] * b[24 * N + n] + a[17 * N + n] * b[30 * N + n];
      c[13 * N + n] = a[12 * N + n] * b[1 * N + n] + a[13 * N + n] * b[7 * N + n] + a[14 * N + n] * b[13 * N + n] +
                      a[15 * N + n] * b[19 * N + n] + a[16 * N + n] * b[25 * N + n] + a[17 * N + n] * b[31 * N + n];
      c[14 * N + n] = a[12 * N + n] * b[2 * N + n] + a[13 * N + n] * b[8 * N + n] + a[14 * N + n] * b[14 * N + n] +
                      a[15 * N + n] * b[20 * N + n] + a[16 * N + n] * b[26 * N + n] + a[17 * N + n] * b[32 * N + n];
      c[15 * N + n] = a[12 * N + n] * b[3 * N + n] + a[13 * N + n] * b[9 * N + n] + a[14 * N + n] * b[15 * N + n] +
                      a[15 * N + n] * b[21 * N + n] + a[16 * N + n] * b[27 * N + n] + a[17 * N + n] * b[33 * N + n];
      c[16 * N + n] = a[12 * N + n] * b[4 * N + n] + a[13 * N + n] * b[10 * N + n] + a[14 * N + n] * b[16 * N + n] +
                      a[15 * N + n] * b[22 * N + n] + a[16 * N + n] * b[28 * N + n] + a[17 * N + n] * b[34 * N + n];
      c[17 * N + n] = a[12 * N + n] * b[5 * N + n] + a[13 * N + n] * b[11 * N + n] + a[14 * N + n] * b[17 * N + n] +
                      a[15 * N + n] * b[23 * N + n] + a[16 * N + n] * b[29 * N + n] + a[17 * N + n] * b[35 * N + n];
      // Row 3 of C.
      c[18 * N + n] = a[18 * N + n] * b[0 * N + n] + a[19 * N + n] * b[6 * N + n] + a[20 * N + n] * b[12 * N + n] +
                      a[21 * N + n] * b[18 * N + n] + a[22 * N + n] * b[24 * N + n] + a[23 * N + n] * b[30 * N + n];
      c[19 * N + n] = a[18 * N + n] * b[1 * N + n] + a[19 * N + n] * b[7 * N + n] + a[20 * N + n] * b[13 * N + n] +
                      a[21 * N + n] * b[19 * N + n] + a[22 * N + n] * b[25 * N + n] + a[23 * N + n] * b[31 * N + n];
      c[20 * N + n] = a[18 * N + n] * b[2 * N + n] + a[19 * N + n] * b[8 * N + n] + a[20 * N + n] * b[14 * N + n] +
                      a[21 * N + n] * b[20 * N + n] + a[22 * N + n] * b[26 * N + n] + a[23 * N + n] * b[32 * N + n];
      c[21 * N + n] = a[18 * N + n] * b[3 * N + n] + a[19 * N + n] * b[9 * N + n] + a[20 * N + n] * b[15 * N + n] +
                      a[21 * N + n] * b[21 * N + n] + a[22 * N + n] * b[27 * N + n] + a[23 * N + n] * b[33 * N + n];
      c[22 * N + n] = a[18 * N + n] * b[4 * N + n] + a[19 * N + n] * b[10 * N + n] + a[20 * N + n] * b[16 * N + n] +
                      a[21 * N + n] * b[22 * N + n] + a[22 * N + n] * b[28 * N + n] + a[23 * N + n] * b[34 * N + n];
      c[23 * N + n] = a[18 * N + n] * b[5 * N + n] + a[19 * N + n] * b[11 * N + n] + a[20 * N + n] * b[17 * N + n] +
                      a[21 * N + n] * b[23 * N + n] + a[22 * N + n] * b[29 * N + n] + a[23 * N + n] * b[35 * N + n];
      // Row 4 of C.
      c[24 * N + n] = a[24 * N + n] * b[0 * N + n] + a[25 * N + n] * b[6 * N + n] + a[26 * N + n] * b[12 * N + n] +
                      a[27 * N + n] * b[18 * N + n] + a[28 * N + n] * b[24 * N + n] + a[29 * N + n] * b[30 * N + n];
      c[25 * N + n] = a[24 * N + n] * b[1 * N + n] + a[25 * N + n] * b[7 * N + n] + a[26 * N + n] * b[13 * N + n] +
                      a[27 * N + n] * b[19 * N + n] + a[28 * N + n] * b[25 * N + n] + a[29 * N + n] * b[31 * N + n];
      c[26 * N + n] = a[24 * N + n] * b[2 * N + n] + a[25 * N + n] * b[8 * N + n] + a[26 * N + n] * b[14 * N + n] +
                      a[27 * N + n] * b[20 * N + n] + a[28 * N + n] * b[26 * N + n] + a[29 * N + n] * b[32 * N + n];
      c[27 * N + n] = a[24 * N + n] * b[3 * N + n] + a[25 * N + n] * b[9 * N + n] + a[26 * N + n] * b[15 * N + n] +
                      a[27 * N + n] * b[21 * N + n] + a[28 * N + n] * b[27 * N + n] + a[29 * N + n] * b[33 * N + n];
      c[28 * N + n] = a[24 * N + n] * b[4 * N + n] + a[25 * N + n] * b[10 * N + n] + a[26 * N + n] * b[16 * N + n] +
                      a[27 * N + n] * b[22 * N + n] + a[28 * N + n] * b[28 * N + n] + a[29 * N + n] * b[34 * N + n];
      c[29 * N + n] = a[24 * N + n] * b[5 * N + n] + a[25 * N + n] * b[11 * N + n] + a[26 * N + n] * b[17 * N + n] +
                      a[27 * N + n] * b[23 * N + n] + a[28 * N + n] * b[29 * N + n] + a[29 * N + n] * b[35 * N + n];
      // Row 5 of C.
      c[30 * N + n] = a[30 * N + n] * b[0 * N + n] + a[31 * N + n] * b[6 * N + n] + a[32 * N + n] * b[12 * N + n] +
                      a[33 * N + n] * b[18 * N + n] + a[34 * N + n] * b[24 * N + n] + a[35 * N + n] * b[30 * N + n];
      c[31 * N + n] = a[30 * N + n] * b[1 * N + n] + a[31 * N + n] * b[7 * N + n] + a[32 * N + n] * b[13 * N + n] +
                      a[33 * N + n] * b[19 * N + n] + a[34 * N + n] * b[25 * N + n] + a[35 * N + n] * b[31 * N + n];
      c[32 * N + n] = a[30 * N + n] * b[2 * N + n] + a[31 * N + n] * b[8 * N + n] + a[32 * N + n] * b[14 * N + n] +
                      a[33 * N + n] * b[20 * N + n] + a[34 * N + n] * b[26 * N + n] + a[35 * N + n] * b[32 * N + n];
      c[33 * N + n] = a[30 * N + n] * b[3 * N + n] + a[31 * N + n] * b[9 * N + n] + a[32 * N + n] * b[15 * N + n] +
                      a[33 * N + n] * b[21 * N + n] + a[34 * N + n] * b[27 * N + n] + a[35 * N + n] * b[33 * N + n];
      c[34 * N + n] = a[30 * N + n] * b[4 * N + n] + a[31 * N + n] * b[10 * N + n] + a[32 * N + n] * b[16 * N + n] +
                      a[33 * N + n] * b[22 * N + n] + a[34 * N + n] * b[28 * N + n] + a[35 * N + n] * b[34 * N + n];
      c[35 * N + n] = a[30 * N + n] * b[5 * N + n] + a[31 * N + n] * b[11 * N + n] + a[32 * N + n] * b[17 * N + n] +
                      a[33 * N + n] * b[23 * N + n] + a[34 * N + n] * b[29 * N + n] + a[35 * N + n] * b[35 * N + n];
    }
  }
};
0879
0880 template <typename T, idx_t D, idx_t N>
0881 void multiply(const MPlex<T, D, D, N>& A, const MPlex<T, D, D, N>& B, MPlex<T, D, D, N>& C) {
0882 #ifdef DEBUG
0883 printf("Multipl %d %d\n", D, N);
0884 #endif
0885
0886 MultiplyCls<T, D, N>::multiply(A, B, C);
0887 }
0888
0889
0890
0891
0892
0893 template <typename T, idx_t D, idx_t N>
0894 struct CramerInverter {
0895 static void invert(MPlex<T, D, D, N>& A, double* determ = nullptr) {
0896 throw std::runtime_error("general cramer inversion not supported");
0897 }
0898 };
0899
0900 template <typename T, idx_t N>
0901 struct CramerInverter<T, 2, N> {
0902 static void invert(MPlex<T, 2, 2, N>& A, double* determ = nullptr) {
0903 typedef T TT;
0904
0905 T* a = A.fArray;
0906 ASSUME_ALIGNED(a, 64);
0907
0908 #pragma omp simd
0909 for (idx_t n = 0; n < N; ++n) {
0910
0911 const double det = (double)a[0 * N + n] * a[3 * N + n] - (double)a[2 * N + n] * a[1 * N + n];
0912 if (determ)
0913 determ[n] = det;
0914
0915 const TT s = TT(1) / det;
0916 const TT tmp = s * a[3 * N + n];
0917 a[1 * N + n] *= -s;
0918 a[2 * N + n] *= -s;
0919 a[3 * N + n] = s * a[0 * N + n];
0920 a[0 * N + n] = tmp;
0921 }
0922 }
0923 };
0924
0925 template <typename T, idx_t N>
0926 struct CramerInverter<T, 3, N> {
0927 static void invert(MPlex<T, 3, 3, N>& A, double* determ = nullptr) {
0928 typedef T TT;
0929
0930 T* a = A.fArray;
0931 ASSUME_ALIGNED(a, 64);
0932
0933 #pragma omp simd
0934 for (idx_t n = 0; n < N; ++n) {
0935 const TT c00 = a[4 * N + n] * a[8 * N + n] - a[5 * N + n] * a[7 * N + n];
0936 const TT c01 = a[5 * N + n] * a[6 * N + n] - a[3 * N + n] * a[8 * N + n];
0937 const TT c02 = a[3 * N + n] * a[7 * N + n] - a[4 * N + n] * a[6 * N + n];
0938 const TT c10 = a[7 * N + n] * a[2 * N + n] - a[8 * N + n] * a[1 * N + n];
0939 const TT c11 = a[8 * N + n] * a[0 * N + n] - a[6 * N + n] * a[2 * N + n];
0940 const TT c12 = a[6 * N + n] * a[1 * N + n] - a[7 * N + n] * a[0 * N + n];
0941 const TT c20 = a[1 * N + n] * a[5 * N + n] - a[2 * N + n] * a[4 * N + n];
0942 const TT c21 = a[2 * N + n] * a[3 * N + n] - a[0 * N + n] * a[5 * N + n];
0943 const TT c22 = a[0 * N + n] * a[4 * N + n] - a[1 * N + n] * a[3 * N + n];
0944
0945
0946 const double det = (double)a[0 * N + n] * c00 + (double)a[1 * N + n] * c01 + (double)a[2 * N + n] * c02;
0947 if (determ)
0948 determ[n] = det;
0949
0950 const TT s = TT(1) / det;
0951 a[0 * N + n] = s * c00;
0952 a[1 * N + n] = s * c10;
0953 a[2 * N + n] = s * c20;
0954 a[3 * N + n] = s * c01;
0955 a[4 * N + n] = s * c11;
0956 a[5 * N + n] = s * c21;
0957 a[6 * N + n] = s * c02;
0958 a[7 * N + n] = s * c12;
0959 a[8 * N + n] = s * c22;
0960 }
0961 }
0962 };
0963
0964 template <typename T, idx_t D, idx_t N>
0965 void invertCramer(MPlex<T, D, D, N>& A, double* determ = nullptr) {
0966 CramerInverter<T, D, N>::invert(A, determ);
0967 }
0968
0969
0970
0971
0972
0973 template <typename T, idx_t D, idx_t N>
0974 struct CholeskyInverter {
0975 static void invert(MPlex<T, D, D, N>& A) { throw std::runtime_error("general cholesky inversion not supported"); }
0976 };
0977
template <typename T, idx_t N>
struct CholeskyInverter<T, 3, N> {
  // In-place inversion of N symmetric 3x3 matrices through the Cholesky
  // factorization A = L * L^T, followed by A^-1 = (L^-1)^T * L^-1.
  //
  // Only the lower triangle of each input matrix is read (row-major elements
  // 0, 3, 4, 6, 7, 8); the full symmetric inverse is written back.
  // No positive-definiteness check is made: a zero or negative pivot produces
  // inf or NaN (1/0, or std::sqrt of a negative value), which propagates into
  // the result.

  static void invert(MPlex<T, 3, 3, N>& A) {
    typedef T TT;

    T* a = A.fArray;
    ASSUME_ALIGNED(a, 64);

#pragma omp simd
    for (idx_t n = 0; n < N; ++n) {
      // Factorization. l0, l2, l5 hold the reciprocals of L's diagonal
      // (1/L00, 1/L11, 1/L22); l1, l3, l4 hold L10, L20, L21.
      TT l0 = std::sqrt(T(1) / a[0 * N + n]);
      TT l1 = a[3 * N + n] * l0;
      TT l2 = a[4 * N + n] - l1 * l1;
      l2 = std::sqrt(T(1) / l2);
      TT l3 = a[6 * N + n] * l0;
      TT l4 = (a[7 * N + n] - l1 * l3) * l2;
      TT l5 = a[8 * N + n] - (l3 * l3 + l4 * l4);
      l5 = std::sqrt(T(1) / l5);

      // Turn L into M = L^-1 (also lower triangular). The diagonal of M is
      // already l0, l2, l5. The new l3 (M20) reads the old l1 (L10) and
      // l4 (L21), so it must be computed before l1 and l4 are overwritten.
      l3 = (l1 * l4 * l2 - l3) * l0 * l5;
      l1 = -l1 * l0 * l2;
      l4 = -l4 * l2 * l5;

      // A^-1 = M^T * M; symmetric, so each off-diagonal value is stored twice.
      a[0 * N + n] = l3 * l3 + l1 * l1 + l0 * l0;
      a[1 * N + n] = a[3 * N + n] = l3 * l4 + l1 * l2;
      a[4 * N + n] = l4 * l4 + l2 * l2;
      a[2 * N + n] = a[6 * N + n] = l3 * l5;
      a[5 * N + n] = a[7 * N + n] = l4 * l5;
      a[8 * N + n] = l5 * l5;
    }
  }
};
1017
1018 template <typename T, idx_t D, idx_t N>
1019 void invertCholesky(MPlex<T, D, D, N>& A) {
1020 CholeskyInverter<T, D, N>::invert(A);
1021 }
1022
1023 }
1024
1025 #endif