File indexing completed on 2022-02-21 23:14:15
0001 #ifndef RecoTracker_MkFitCore_src_Matriplex_Matriplex_h
0002 #define RecoTracker_MkFitCore_src_Matriplex_Matriplex_h
0003
0004 #include "MatriplexCommon.h"
0005
0006 namespace Matriplex {
0007
0008
0009
0010 template <typename T, idx_t D1, idx_t D2, idx_t N>
0011 class Matriplex {
0012 public:
0013 typedef T value_type;
0014
0015
0016 static constexpr int kRows = D1;
0017
0018 static constexpr int kCols = D2;
0019
0020 static constexpr int kSize = D1 * D2;
0021
0022 static constexpr int kTotSize = N * kSize;
0023
0024 T fArray[kTotSize] __attribute__((aligned(64)));
0025
0026 Matriplex() {}
0027 Matriplex(T v) { setVal(v); }
0028
0029 idx_t plexSize() const { return N; }
0030
0031 void setVal(T v) {
0032 for (idx_t i = 0; i < kTotSize; ++i) {
0033 fArray[i] = v;
0034 }
0035 }
0036
0037 void add(const Matriplex& v) {
0038 for (idx_t i = 0; i < kTotSize; ++i) {
0039 fArray[i] += v.fArray[i];
0040 }
0041 }
0042
0043 void scale(T scale) {
0044 for (idx_t i = 0; i < kTotSize; ++i) {
0045 fArray[i] *= scale;
0046 }
0047 }
0048
0049 T operator[](idx_t xx) const { return fArray[xx]; }
0050 T& operator[](idx_t xx) { return fArray[xx]; }
0051
0052 const T& constAt(idx_t n, idx_t i, idx_t j) const { return fArray[(i * D2 + j) * N + n]; }
0053
0054 T& At(idx_t n, idx_t i, idx_t j) { return fArray[(i * D2 + j) * N + n]; }
0055
0056 T& operator()(idx_t n, idx_t i, idx_t j) { return fArray[(i * D2 + j) * N + n]; }
0057 const T& operator()(idx_t n, idx_t i, idx_t j) const { return fArray[(i * D2 + j) * N + n]; }
0058
0059 Matriplex& operator=(const Matriplex& m) {
0060 memcpy(fArray, m.fArray, sizeof(T) * kTotSize);
0061 return *this;
0062 }
0063
0064 void copySlot(idx_t n, const Matriplex& m) {
0065 for (idx_t i = n; i < kTotSize; i += N) {
0066 fArray[i] = m.fArray[i];
0067 }
0068 }
0069
0070 void copyIn(idx_t n, const T* arr) {
0071 for (idx_t i = n; i < kTotSize; i += N) {
0072 fArray[i] = *(arr++);
0073 }
0074 }
0075
0076 void copyIn(idx_t n, const Matriplex& m, idx_t in) {
0077 for (idx_t i = n; i < kTotSize; i += N, in += N) {
0078 fArray[i] = m[in];
0079 }
0080 }
0081
0082 void copy(idx_t n, idx_t in) {
0083 for (idx_t i = n; i < kTotSize; i += N, in += N) {
0084 fArray[i] = fArray[in];
0085 }
0086 }
0087
0088 #if defined(AVX512_INTRINSICS)
0089
0090 template <typename U>
0091 void slurpIn(const T* arr, __m512i& vi, const U&, const int N_proc = N) {
0092
0093
0094 const __m512 src = {0};
0095 const __mmask16 k = N_proc == N ? -1 : (1 << N_proc) - 1;
0096
0097 for (int i = 0; i < kSize; ++i, ++arr) {
0098
0099
0100 __m512 reg = _mm512_mask_i32gather_ps(src, k, vi, arr, sizeof(U));
0101 _mm512_mask_store_ps(&fArray[i * N], k, reg);
0102 }
0103 }
0104
0105
0106
0107 void ChewIn(const char* arr, int off, int vi[N], const char* tmp, __m512i& ui) {
0108
0109
0110 for (int i = 0; i < N; ++i) {
0111 __m512 reg = _mm512_load_ps(arr + vi[i]);
0112 _mm512_store_ps((void*)(tmp + 64 * i), reg);
0113 }
0114
0115 for (int i = 0; i < kSize; ++i) {
0116 __m512 reg = _mm512_i32gather_ps(ui, tmp + off + i * sizeof(T), 1);
0117 _mm512_store_ps(&fArray[i * N], reg);
0118 }
0119 }
0120
0121 void Contaginate(const char* arr, int vi[N], const char* tmp) {
0122
0123
0124 for (int i = 0; i < N; ++i) {
0125 __m512 reg = _mm512_load_ps(arr + vi[i]);
0126 _mm512_store_ps((void*)(tmp + 64 * i), reg);
0127 }
0128 }
0129
0130 void Plexify(const char* tmp, __m512i& ui) {
0131 for (int i = 0; i < kSize; ++i) {
0132 __m512 reg = _mm512_i32gather_ps(ui, tmp + i * sizeof(T), 1);
0133 _mm512_store_ps(&fArray[i * N], reg);
0134 }
0135 }
0136
0137 #elif defined(AVX2_INTRINSICS)
0138
0139 template <typename U>
0140 void slurpIn(const T* arr, __m256i& vi, const U&, const int N_proc = N) {
0141
0142
0143
0144 const __m256 src = {0};
0145
0146 __m256i k = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
0147 __m256i k_sel = _mm256_set1_epi32(N_proc);
0148 __m256i k_master = _mm256_cmpgt_epi32(k_sel, k);
0149
0150 k = k_master;
0151 for (int i = 0; i < kSize; ++i, ++arr) {
0152 __m256 reg = _mm256_mask_i32gather_ps(src, (float*)arr, vi, (__m256)k, sizeof(U));
0153
0154 k = k_master;
0155 _mm256_maskstore_ps((float*)&fArray[i * N], k, reg);
0156 }
0157 }
0158
0159 #else
0160
0161 void slurpIn(const T* arr, int vi[N], const int N_proc = N) {
0162
0163 if (N_proc == N) {
0164 for (int i = 0; i < kSize; ++i) {
0165 for (int j = 0; j < N; ++j) {
0166 fArray[i * N + j] = *(arr + i + vi[j]);
0167 }
0168 }
0169 } else {
0170 for (int i = 0; i < kSize; ++i) {
0171 for (int j = 0; j < N_proc; ++j) {
0172 fArray[i * N + j] = *(arr + i + vi[j]);
0173 }
0174 }
0175 }
0176 }
0177
0178 #endif
0179
0180 void copyOut(idx_t n, T* arr) const {
0181 for (idx_t i = n; i < kTotSize; i += N) {
0182 *(arr++) = fArray[i];
0183 }
0184 }
0185 };
0186
0187 template <typename T, idx_t D1, idx_t D2, idx_t N>
0188 using MPlex = Matriplex<T, D1, D2, N>;
0189
0190
0191
0192
0193
0194 template <typename T, idx_t D1, idx_t D2, idx_t D3, idx_t N>
0195 void multiplyGeneral(const MPlex<T, D1, D2, N>& A, const MPlex<T, D2, D3, N>& B, MPlex<T, D1, D3, N>& C) {
0196 for (idx_t i = 0; i < D1; ++i) {
0197 for (idx_t j = 0; j < D3; ++j) {
0198 const idx_t ijo = N * (i * D3 + j);
0199
0200 #pragma omp simd
0201 for (idx_t n = 0; n < N; ++n) {
0202 C.fArray[ijo + n] = 0;
0203 }
0204
0205 for (idx_t k = 0; k < D2; ++k) {
0206 const idx_t iko = N * (i * D2 + k);
0207 const idx_t kjo = N * (k * D3 + j);
0208
0209 #pragma omp simd
0210 for (idx_t n = 0; n < N; ++n) {
0211 C.fArray[ijo + n] += A.fArray[iko + n] * B.fArray[kjo + n];
0212 }
0213 }
0214 }
0215 }
0216 }
0217
0218
0219
0220 template <typename T, idx_t D, idx_t N>
0221 struct MultiplyCls {
0222 static void multiply(const MPlex<T, D, D, N>& A, const MPlex<T, D, D, N>& B, MPlex<T, D, D, N>& C) {
0223 throw std::runtime_error("general multiplication not supported, well, call multiplyGeneral()");
0224 }
0225 };
0226
0227 template <typename T, idx_t N>
0228 struct MultiplyCls<T, 3, N> {
0229 static void multiply(const MPlex<T, 3, 3, N>& A, const MPlex<T, 3, 3, N>& B, MPlex<T, 3, 3, N>& C) {
0230 const T* a = A.fArray;
0231 ASSUME_ALIGNED(a, 64);
0232 const T* b = B.fArray;
0233 ASSUME_ALIGNED(b, 64);
0234 T* c = C.fArray;
0235 ASSUME_ALIGNED(c, 64);
0236
0237 #pragma omp simd
0238 for (idx_t n = 0; n < N; ++n) {
0239 c[0 * N + n] = a[0 * N + n] * b[0 * N + n] + a[1 * N + n] * b[3 * N + n] + a[2 * N + n] * b[6 * N + n];
0240 c[1 * N + n] = a[0 * N + n] * b[1 * N + n] + a[1 * N + n] * b[4 * N + n] + a[2 * N + n] * b[7 * N + n];
0241 c[2 * N + n] = a[0 * N + n] * b[2 * N + n] + a[1 * N + n] * b[5 * N + n] + a[2 * N + n] * b[8 * N + n];
0242 c[3 * N + n] = a[3 * N + n] * b[0 * N + n] + a[4 * N + n] * b[3 * N + n] + a[5 * N + n] * b[6 * N + n];
0243 c[4 * N + n] = a[3 * N + n] * b[1 * N + n] + a[4 * N + n] * b[4 * N + n] + a[5 * N + n] * b[7 * N + n];
0244 c[5 * N + n] = a[3 * N + n] * b[2 * N + n] + a[4 * N + n] * b[5 * N + n] + a[5 * N + n] * b[8 * N + n];
0245 c[6 * N + n] = a[6 * N + n] * b[0 * N + n] + a[7 * N + n] * b[3 * N + n] + a[8 * N + n] * b[6 * N + n];
0246 c[7 * N + n] = a[6 * N + n] * b[1 * N + n] + a[7 * N + n] * b[4 * N + n] + a[8 * N + n] * b[7 * N + n];
0247 c[8 * N + n] = a[6 * N + n] * b[2 * N + n] + a[7 * N + n] * b[5 * N + n] + a[8 * N + n] * b[8 * N + n];
0248 }
0249 }
0250 };
0251
// Hand-unrolled 6x6 kernel: C = A * B for all N slots.
// Element k (row-major, k = 6 * row + col) of every slot occupies the strip
// [k * N, k * N + N) of fArray, so c[(6 * r + col) * N + n] is the sum over
// k = 0..5 of a[(6 * r + k) * N + n] * b[(6 * k + col) * N + n].
// C is written while A and B are still being read, so C should not be the
// same object as A or B.
template <typename T, idx_t N>
struct MultiplyCls<T, 6, N> {
  static void multiply(const MPlex<T, 6, 6, N>& A, const MPlex<T, 6, 6, N>& B, MPlex<T, 6, 6, N>& C) {
    const T* a = A.fArray;
    ASSUME_ALIGNED(a, 64);
    const T* b = B.fArray;
    ASSUME_ALIGNED(b, 64);
    T* c = C.fArray;
    ASSUME_ALIGNED(c, 64);
#pragma omp simd
    for (idx_t n = 0; n < N; ++n) {
      // Row 0 of C.
      c[0 * N + n] = a[0 * N + n] * b[0 * N + n] + a[1 * N + n] * b[6 * N + n] + a[2 * N + n] * b[12 * N + n] +
                     a[3 * N + n] * b[18 * N + n] + a[4 * N + n] * b[24 * N + n] + a[5 * N + n] * b[30 * N + n];
      c[1 * N + n] = a[0 * N + n] * b[1 * N + n] + a[1 * N + n] * b[7 * N + n] + a[2 * N + n] * b[13 * N + n] +
                     a[3 * N + n] * b[19 * N + n] + a[4 * N + n] * b[25 * N + n] + a[5 * N + n] * b[31 * N + n];
      c[2 * N + n] = a[0 * N + n] * b[2 * N + n] + a[1 * N + n] * b[8 * N + n] + a[2 * N + n] * b[14 * N + n] +
                     a[3 * N + n] * b[20 * N + n] + a[4 * N + n] * b[26 * N + n] + a[5 * N + n] * b[32 * N + n];
      c[3 * N + n] = a[0 * N + n] * b[3 * N + n] + a[1 * N + n] * b[9 * N + n] + a[2 * N + n] * b[15 * N + n] +
                     a[3 * N + n] * b[21 * N + n] + a[4 * N + n] * b[27 * N + n] + a[5 * N + n] * b[33 * N + n];
      c[4 * N + n] = a[0 * N + n] * b[4 * N + n] + a[1 * N + n] * b[10 * N + n] + a[2 * N + n] * b[16 * N + n] +
                     a[3 * N + n] * b[22 * N + n] + a[4 * N + n] * b[28 * N + n] + a[5 * N + n] * b[34 * N + n];
      c[5 * N + n] = a[0 * N + n] * b[5 * N + n] + a[1 * N + n] * b[11 * N + n] + a[2 * N + n] * b[17 * N + n] +
                     a[3 * N + n] * b[23 * N + n] + a[4 * N + n] * b[29 * N + n] + a[5 * N + n] * b[35 * N + n];
      // Row 1 of C.
      c[6 * N + n] = a[6 * N + n] * b[0 * N + n] + a[7 * N + n] * b[6 * N + n] + a[8 * N + n] * b[12 * N + n] +
                     a[9 * N + n] * b[18 * N + n] + a[10 * N + n] * b[24 * N + n] + a[11 * N + n] * b[30 * N + n];
      c[7 * N + n] = a[6 * N + n] * b[1 * N + n] + a[7 * N + n] * b[7 * N + n] + a[8 * N + n] * b[13 * N + n] +
                     a[9 * N + n] * b[19 * N + n] + a[10 * N + n] * b[25 * N + n] + a[11 * N + n] * b[31 * N + n];
      c[8 * N + n] = a[6 * N + n] * b[2 * N + n] + a[7 * N + n] * b[8 * N + n] + a[8 * N + n] * b[14 * N + n] +
                     a[9 * N + n] * b[20 * N + n] + a[10 * N + n] * b[26 * N + n] + a[11 * N + n] * b[32 * N + n];
      c[9 * N + n] = a[6 * N + n] * b[3 * N + n] + a[7 * N + n] * b[9 * N + n] + a[8 * N + n] * b[15 * N + n] +
                     a[9 * N + n] * b[21 * N + n] + a[10 * N + n] * b[27 * N + n] + a[11 * N + n] * b[33 * N + n];
      c[10 * N + n] = a[6 * N + n] * b[4 * N + n] + a[7 * N + n] * b[10 * N + n] + a[8 * N + n] * b[16 * N + n] +
                      a[9 * N + n] * b[22 * N + n] + a[10 * N + n] * b[28 * N + n] + a[11 * N + n] * b[34 * N + n];
      c[11 * N + n] = a[6 * N + n] * b[5 * N + n] + a[7 * N + n] * b[11 * N + n] + a[8 * N + n] * b[17 * N + n] +
                      a[9 * N + n] * b[23 * N + n] + a[10 * N + n] * b[29 * N + n] + a[11 * N + n] * b[35 * N + n];
      // Row 2 of C.
      c[12 * N + n] = a[12 * N + n] * b[0 * N + n] + a[13 * N + n] * b[6 * N + n] + a[14 * N + n] * b[12 * N + n] +
                      a[15 * N + n] * b[18 * N + n] + a[16 * N + n] * b[24 * N + n] + a[17 * N + n] * b[30 * N + n];
      c[13 * N + n] = a[12 * N + n] * b[1 * N + n] + a[13 * N + n] * b[7 * N + n] + a[14 * N + n] * b[13 * N + n] +
                      a[15 * N + n] * b[19 * N + n] + a[16 * N + n] * b[25 * N + n] + a[17 * N + n] * b[31 * N + n];
      c[14 * N + n] = a[12 * N + n] * b[2 * N + n] + a[13 * N + n] * b[8 * N + n] + a[14 * N + n] * b[14 * N + n] +
                      a[15 * N + n] * b[20 * N + n] + a[16 * N + n] * b[26 * N + n] + a[17 * N + n] * b[32 * N + n];
      c[15 * N + n] = a[12 * N + n] * b[3 * N + n] + a[13 * N + n] * b[9 * N + n] + a[14 * N + n] * b[15 * N + n] +
                      a[15 * N + n] * b[21 * N + n] + a[16 * N + n] * b[27 * N + n] + a[17 * N + n] * b[33 * N + n];
      c[16 * N + n] = a[12 * N + n] * b[4 * N + n] + a[13 * N + n] * b[10 * N + n] + a[14 * N + n] * b[16 * N + n] +
                      a[15 * N + n] * b[22 * N + n] + a[16 * N + n] * b[28 * N + n] + a[17 * N + n] * b[34 * N + n];
      c[17 * N + n] = a[12 * N + n] * b[5 * N + n] + a[13 * N + n] * b[11 * N + n] + a[14 * N + n] * b[17 * N + n] +
                      a[15 * N + n] * b[23 * N + n] + a[16 * N + n] * b[29 * N + n] + a[17 * N + n] * b[35 * N + n];
      // Row 3 of C.
      c[18 * N + n] = a[18 * N + n] * b[0 * N + n] + a[19 * N + n] * b[6 * N + n] + a[20 * N + n] * b[12 * N + n] +
                      a[21 * N + n] * b[18 * N + n] + a[22 * N + n] * b[24 * N + n] + a[23 * N + n] * b[30 * N + n];
      c[19 * N + n] = a[18 * N + n] * b[1 * N + n] + a[19 * N + n] * b[7 * N + n] + a[20 * N + n] * b[13 * N + n] +
                      a[21 * N + n] * b[19 * N + n] + a[22 * N + n] * b[25 * N + n] + a[23 * N + n] * b[31 * N + n];
      c[20 * N + n] = a[18 * N + n] * b[2 * N + n] + a[19 * N + n] * b[8 * N + n] + a[20 * N + n] * b[14 * N + n] +
                      a[21 * N + n] * b[20 * N + n] + a[22 * N + n] * b[26 * N + n] + a[23 * N + n] * b[32 * N + n];
      c[21 * N + n] = a[18 * N + n] * b[3 * N + n] + a[19 * N + n] * b[9 * N + n] + a[20 * N + n] * b[15 * N + n] +
                      a[21 * N + n] * b[21 * N + n] + a[22 * N + n] * b[27 * N + n] + a[23 * N + n] * b[33 * N + n];
      c[22 * N + n] = a[18 * N + n] * b[4 * N + n] + a[19 * N + n] * b[10 * N + n] + a[20 * N + n] * b[16 * N + n] +
                      a[21 * N + n] * b[22 * N + n] + a[22 * N + n] * b[28 * N + n] + a[23 * N + n] * b[34 * N + n];
      c[23 * N + n] = a[18 * N + n] * b[5 * N + n] + a[19 * N + n] * b[11 * N + n] + a[20 * N + n] * b[17 * N + n] +
                      a[21 * N + n] * b[23 * N + n] + a[22 * N + n] * b[29 * N + n] + a[23 * N + n] * b[35 * N + n];
      // Row 4 of C.
      c[24 * N + n] = a[24 * N + n] * b[0 * N + n] + a[25 * N + n] * b[6 * N + n] + a[26 * N + n] * b[12 * N + n] +
                      a[27 * N + n] * b[18 * N + n] + a[28 * N + n] * b[24 * N + n] + a[29 * N + n] * b[30 * N + n];
      c[25 * N + n] = a[24 * N + n] * b[1 * N + n] + a[25 * N + n] * b[7 * N + n] + a[26 * N + n] * b[13 * N + n] +
                      a[27 * N + n] * b[19 * N + n] + a[28 * N + n] * b[25 * N + n] + a[29 * N + n] * b[31 * N + n];
      c[26 * N + n] = a[24 * N + n] * b[2 * N + n] + a[25 * N + n] * b[8 * N + n] + a[26 * N + n] * b[14 * N + n] +
                      a[27 * N + n] * b[20 * N + n] + a[28 * N + n] * b[26 * N + n] + a[29 * N + n] * b[32 * N + n];
      c[27 * N + n] = a[24 * N + n] * b[3 * N + n] + a[25 * N + n] * b[9 * N + n] + a[26 * N + n] * b[15 * N + n] +
                      a[27 * N + n] * b[21 * N + n] + a[28 * N + n] * b[27 * N + n] + a[29 * N + n] * b[33 * N + n];
      c[28 * N + n] = a[24 * N + n] * b[4 * N + n] + a[25 * N + n] * b[10 * N + n] + a[26 * N + n] * b[16 * N + n] +
                      a[27 * N + n] * b[22 * N + n] + a[28 * N + n] * b[28 * N + n] + a[29 * N + n] * b[34 * N + n];
      c[29 * N + n] = a[24 * N + n] * b[5 * N + n] + a[25 * N + n] * b[11 * N + n] + a[26 * N + n] * b[17 * N + n] +
                      a[27 * N + n] * b[23 * N + n] + a[28 * N + n] * b[29 * N + n] + a[29 * N + n] * b[35 * N + n];
      // Row 5 of C.
      c[30 * N + n] = a[30 * N + n] * b[0 * N + n] + a[31 * N + n] * b[6 * N + n] + a[32 * N + n] * b[12 * N + n] +
                      a[33 * N + n] * b[18 * N + n] + a[34 * N + n] * b[24 * N + n] + a[35 * N + n] * b[30 * N + n];
      c[31 * N + n] = a[30 * N + n] * b[1 * N + n] + a[31 * N + n] * b[7 * N + n] + a[32 * N + n] * b[13 * N + n] +
                      a[33 * N + n] * b[19 * N + n] + a[34 * N + n] * b[25 * N + n] + a[35 * N + n] * b[31 * N + n];
      c[32 * N + n] = a[30 * N + n] * b[2 * N + n] + a[31 * N + n] * b[8 * N + n] + a[32 * N + n] * b[14 * N + n] +
                      a[33 * N + n] * b[20 * N + n] + a[34 * N + n] * b[26 * N + n] + a[35 * N + n] * b[32 * N + n];
      c[33 * N + n] = a[30 * N + n] * b[3 * N + n] + a[31 * N + n] * b[9 * N + n] + a[32 * N + n] * b[15 * N + n] +
                      a[33 * N + n] * b[21 * N + n] + a[34 * N + n] * b[27 * N + n] + a[35 * N + n] * b[33 * N + n];
      c[34 * N + n] = a[30 * N + n] * b[4 * N + n] + a[31 * N + n] * b[10 * N + n] + a[32 * N + n] * b[16 * N + n] +
                      a[33 * N + n] * b[22 * N + n] + a[34 * N + n] * b[28 * N + n] + a[35 * N + n] * b[34 * N + n];
      c[35 * N + n] = a[30 * N + n] * b[5 * N + n] + a[31 * N + n] * b[11 * N + n] + a[32 * N + n] * b[17 * N + n] +
                      a[33 * N + n] * b[23 * N + n] + a[34 * N + n] * b[29 * N + n] + a[35 * N + n] * b[35 * N + n];
    }
  }
};
0338
0339 template <typename T, idx_t D, idx_t N>
0340 void multiply(const MPlex<T, D, D, N>& A, const MPlex<T, D, D, N>& B, MPlex<T, D, D, N>& C) {
0341 #ifdef DEBUG
0342 printf("Multipl %d %d\n", D, N);
0343 #endif
0344
0345 MultiplyCls<T, D, N>::multiply(A, B, C);
0346 }
0347
0348
0349
0350
0351
0352 template <typename T, idx_t D, idx_t N>
0353 struct CramerInverter {
0354 static void invert(MPlex<T, D, D, N>& A, double* determ = nullptr) {
0355 throw std::runtime_error("general cramer inversion not supported");
0356 }
0357 };
0358
0359 template <typename T, idx_t N>
0360 struct CramerInverter<T, 2, N> {
0361 static void invert(MPlex<T, 2, 2, N>& A, double* determ = nullptr) {
0362 typedef T TT;
0363
0364 T* a = A.fArray;
0365 ASSUME_ALIGNED(a, 64);
0366
0367 #pragma omp simd
0368 for (idx_t n = 0; n < N; ++n) {
0369
0370 const double det = (double)a[0 * N + n] * a[3 * N + n] - (double)a[2 * N + n] * a[1 * N + n];
0371 if (determ)
0372 determ[n] = det;
0373
0374 const TT s = TT(1) / det;
0375 const TT tmp = s * a[3 * N + n];
0376 a[1 * N + n] *= -s;
0377 a[2 * N + n] *= -s;
0378 a[3 * N + n] = s * a[0 * N + n];
0379 a[0 * N + n] = tmp;
0380 }
0381 }
0382 };
0383
0384 template <typename T, idx_t N>
0385 struct CramerInverter<T, 3, N> {
0386 static void invert(MPlex<T, 3, 3, N>& A, double* determ = nullptr) {
0387 typedef T TT;
0388
0389 T* a = A.fArray;
0390 ASSUME_ALIGNED(a, 64);
0391
0392 #pragma omp simd
0393 for (idx_t n = 0; n < N; ++n) {
0394 const TT c00 = a[4 * N + n] * a[8 * N + n] - a[5 * N + n] * a[7 * N + n];
0395 const TT c01 = a[5 * N + n] * a[6 * N + n] - a[3 * N + n] * a[8 * N + n];
0396 const TT c02 = a[3 * N + n] * a[7 * N + n] - a[4 * N + n] * a[6 * N + n];
0397 const TT c10 = a[7 * N + n] * a[2 * N + n] - a[8 * N + n] * a[1 * N + n];
0398 const TT c11 = a[8 * N + n] * a[0 * N + n] - a[6 * N + n] * a[2 * N + n];
0399 const TT c12 = a[6 * N + n] * a[1 * N + n] - a[7 * N + n] * a[0 * N + n];
0400 const TT c20 = a[1 * N + n] * a[5 * N + n] - a[2 * N + n] * a[4 * N + n];
0401 const TT c21 = a[2 * N + n] * a[3 * N + n] - a[0 * N + n] * a[5 * N + n];
0402 const TT c22 = a[0 * N + n] * a[4 * N + n] - a[1 * N + n] * a[3 * N + n];
0403
0404
0405 const double det = (double)a[0 * N + n] * c00 + (double)a[1 * N + n] * c01 + (double)a[2 * N + n] * c02;
0406 if (determ)
0407 determ[n] = det;
0408
0409 const TT s = TT(1) / det;
0410 a[0 * N + n] = s * c00;
0411 a[1 * N + n] = s * c10;
0412 a[2 * N + n] = s * c20;
0413 a[3 * N + n] = s * c01;
0414 a[4 * N + n] = s * c11;
0415 a[5 * N + n] = s * c21;
0416 a[6 * N + n] = s * c02;
0417 a[7 * N + n] = s * c12;
0418 a[8 * N + n] = s * c22;
0419 }
0420 }
0421 };
0422
0423 template <typename T, idx_t D, idx_t N>
0424 void invertCramer(MPlex<T, D, D, N>& A, double* determ = nullptr) {
0425 CramerInverter<T, D, N>::invert(A, determ);
0426 }
0427
0428
0429
0430
0431
0432 template <typename T, idx_t D, idx_t N>
0433 struct CholeskyInverter {
0434 static void invert(MPlex<T, D, D, N>& A) { throw std::runtime_error("general cholesky inversion not supported"); }
0435 };
0436
0437 template <typename T, idx_t N>
0438 struct CholeskyInverter<T, 3, N> {
0439
0440
0441 static void invert(MPlex<T, 3, 3, N>& A) {
0442 typedef T TT;
0443
0444 T* a = A.fArray;
0445 ASSUME_ALIGNED(a, 64);
0446
0447 #pragma omp simd
0448 for (idx_t n = 0; n < N; ++n) {
0449 TT l0 = std::sqrt(T(1) / a[0 * N + n]);
0450 TT l1 = a[3 * N + n] * l0;
0451 TT l2 = a[4 * N + n] - l1 * l1;
0452 l2 = std::sqrt(T(1) / l2);
0453 TT l3 = a[6 * N + n] * l0;
0454 TT l4 = (a[7 * N + n] - l1 * l3) * l2;
0455 TT l5 = a[8 * N + n] - (l3 * l3 + l4 * l4);
0456 l5 = std::sqrt(T(1) / l5);
0457
0458
0459
0460 l3 = (l1 * l4 * l2 - l3) * l0 * l5;
0461 l1 = -l1 * l0 * l2;
0462 l4 = -l4 * l2 * l5;
0463
0464 a[0 * N + n] = l3 * l3 + l1 * l1 + l0 * l0;
0465 a[1 * N + n] = a[3 * N + n] = l3 * l4 + l1 * l2;
0466 a[4 * N + n] = l4 * l4 + l2 * l2;
0467 a[2 * N + n] = a[6 * N + n] = l3 * l5;
0468 a[5 * N + n] = a[7 * N + n] = l4 * l5;
0469 a[8 * N + n] = l5 * l5;
0470
0471
0472
0473 }
0474 }
0475 };
0476
0477 template <typename T, idx_t D, idx_t N>
0478 void invertCholesky(MPlex<T, D, D, N>& A) {
0479 CholeskyInverter<T, D, N>::invert(A);
0480 }
0481
0482 }
0483
0484 #endif