File indexing completed on 2023-03-17 11:22:39
#include "Math/SMatrix.h"

#include "MatriplexSym.h"

#include <cmath>
#include <cstdio>
#include <random>
0006
0007
0008
0009
0010
0011
0012
0013
// Integer type of the iteration counter printed in mismatch reports (%lld).
using long64 = long long;

// Number of matrices held side by side in each Matriplex (one slot per m < N).
constexpr int N = 16;

// Matrix dimensions: the symmetric operand is DIM x DIM, the general one DIM x DOM.
constexpr int DIM = 3;
constexpr int DOM = 6;

// Absolute tolerance used when comparing Matriplex results with the SMatrix
// reference. The value depends on whether MPLEX_INTRINSICS is defined and on
// the available instruction set; each branch announces its choice at compile
// time through #warning.
#ifdef MPLEX_INTRINSICS
#  if defined(__AVX512F__)
#    warning "MPLEX_INTRINSICS CMP_EPS = 2e-7 --> 3e-7"
constexpr float CMP_EPS = 3e-7;
#  elif defined(__AVX__)
#    warning "MPLEX_INTRINSICS CMP_EPS = 2e-7 --> 5e-7"
constexpr float CMP_EPS = 5e-7;
#  else
#    warning "MPLEX_INTRINSICS CMP_EPS = 2e-7"
constexpr float CMP_EPS = 2e-7;
#  endif
#else
#  if defined(__AVX512F__)
#    warning "NO MPLEX_INTRINSICS CMP_EPS = 4e-7"
constexpr float CMP_EPS = 4e-7;
#  else
#    warning "NO MPLEX_INTRINSICS CMP_EPS = 4e-7 --> 5e-7"
constexpr float CMP_EPS = 5e-7;
#  endif
#endif
0041
0042 typedef ROOT::Math::SMatrix<float, DIM, DOM> SMatX;
0043 typedef ROOT::Math::SMatrix<float, DOM, DIM> SMatXT;
0044 typedef ROOT::Math::SMatrix<float, DIM, DIM, ROOT::Math::MatRepSym<float, DIM> > SMatS;
0045
0046 typedef Matriplex::Matriplex <float, DIM, DOM, N> MPlexX;
0047 typedef Matriplex::Matriplex <float, DOM, DIM, N> MPlexXT;
0048 typedef Matriplex::MatriplexSym<float, DIM, N> MPlexS;
0049
0050 void Multify(const MPlexS& A, const MPlexX& B, MPlexX& C)
0051 {
0052
0053
0054 typedef float T;
0055
0056 const T *a = A.fArray; __assume_aligned(a, 64);
0057 const T *b = B.fArray; __assume_aligned(b, 64);
0058 T *c = C.fArray; __assume_aligned(c, 64);
0059
0060 #include "multify.ah"
0061 }
0062
0063 void MultifyTranspose(const MPlexS& A, const MPlexX& B, MPlexXT& C)
0064 {
0065
0066
0067 typedef float T;
0068
0069 const T *a = A.fArray; __assume_aligned(a, 64);
0070 const T *b = B.fArray; __assume_aligned(b, 64);
0071 T *c = C.fArray; __assume_aligned(c, 64);
0072
0073 #include "multify-transpose.ah"
0074 }
0075
0076 int main()
0077 {
0078 SMatS a[N];
0079 SMatX b[N], c[N];
0080 SMatXT bt[N], ct[N];
0081
0082 MPlexS A;
0083 MPlexX B, C;
0084 MPlexXT CT;
0085
0086 std::default_random_engine gen(0xbeef0133);
0087 std::normal_distribution<float> dis(1.0, 0.05);
0088
0089 long64 count = 1;
0090
0091 init:
0092
0093 for (int m = 0; m < N; ++m)
0094 {
0095 for (int i = 0; i < 3; ++i)
0096 {
0097 for (int j = i; j < 6; ++j)
0098 {
0099 if (j < DIM) a[m](i,j) = dis(gen);
0100
0101 b[m](i,j) = dis(gen);
0102 }
0103 }
0104
0105
0106 a[m](1, 1) = 1;
0107 b[m](0, 4) = 0;
0108 b[m](1, 1) = 1;
0109 b[m](1, 3) = 1;
0110 b[m](1, 4) = 0;
0111 b[m](2, 4) = 0;
0112
0113 A.CopyIn(m, a[m].Array());
0114 B.CopyIn(m, b[m].Array());
0115
0116 c[m] = a[m] * b[m];
0117
0118 bt[m] = ROOT::Math::Transpose(b[m]);
0119 ct[m] = bt[m] * a[m];
0120 }
0121
0122 Multify(A, B, C);
0123 MultifyTranspose(A, B, CT);
0124
0125 for (int m = 0; m < N; ++m)
0126 {
0127 bool dump = false;
0128
0129 for (int j = 0; j < DIM; ++j)
0130 {
0131 for (int k = 0; k < DOM; ++k)
0132 {
0133
0134
0135
0136
0137
0138
0139 if (std::abs(c[m](j,k) - C.At(m, j, k)) > CMP_EPS)
0140 {
0141 dump = true;
0142 printf("MULTIFY M=%d %d,%d d=%e (count = %lld)\n", m, j, k, c[m](j,k) - C.At(m, j, k), count);
0143 }
0144 }
0145 }
0146
0147 if (dump && false)
0148 {
0149 printf("\n");
0150 for (int i = 0; i < DIM; ++i)
0151 {
0152 for (int j = 0; j < DOM; ++j)
0153 printf("%8f ", c[m](i,j));
0154 printf("\n");
0155 }
0156 printf("\n");
0157
0158 for (int i = 0; i < DIM; ++i)
0159 {
0160 for (int j = 0; j < DOM; ++j)
0161 printf("%8f ", C.At(m, i, j));
0162 printf("\n");
0163 }
0164 printf("\n");
0165 }
0166 if (dump)
0167 {
0168 printf("\n");
0169 }
0170 }
0171
0172
0173
0174
0175 for (int m = 0; m < N; ++m)
0176 {
0177 bool dump = false;
0178
0179 for (int j = 0; j < DOM; ++j)
0180 {
0181 for (int k = 0; k < DIM; ++k)
0182 {
0183
0184
0185
0186
0187
0188
0189 if (std::abs(ct[m](j,k) - CT.At(m, j, k)) > CMP_EPS)
0190 {
0191 dump = true;
0192 printf("TRANSPOSE M=%d %d,%d d=%e (count = %lld)\n", m, j, k, ct[m](j,k) - CT.At(m, j, k), count);
0193 }
0194 }
0195 }
0196
0197 if (dump && false)
0198 {
0199 printf("\n");
0200 for (int i = 0; i < DOM; ++i)
0201 {
0202 for (int j = 0; j < DIM; ++j)
0203 printf("%8f ", ct[m](i,j));
0204 printf("\n");
0205 }
0206 printf("\n");
0207
0208 for (int i = 0; i < DIM; ++i)
0209 {
0210 for (int j = 0; j < DOM; ++j)
0211 printf("%8f ", CT.At(m, i, j));
0212 printf("\n");
0213 }
0214 printf("\n");
0215 }
0216 if (dump)
0217 {
0218 printf("\n");
0219 }
0220 }
0221
0222
0223 ++count;
0224 goto init;
0225
0226 return 0;
0227 }