Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:28:24

#include "Math/SMatrix.h"

#include "MatriplexSym.h"

#include <cmath>
#include <cstdio>
#include <random>
0006 
0007 /*
0008 # Generate .ah files (make sure DIM, DOM and pattern match):
0009   ./GMtest.pl
0010 # Compile:
0011   icc -std=gnu++11 -openmp -mavx -O3 -I.. -I../.. GMtest.cxx -o GMtest
0012 */
0013 
// Scalar/size configuration shared by the SMatrix reference computation and
// the Matriplex kernels. DIM and DOM must agree with the values the .ah files
// were generated with (see the GMtest.pl note above).
using long64 = long long;

constexpr int N   = 16;  // number of matrices packed in each Matriplex

constexpr int DIM =  3;
constexpr int DOM =  6;

// Absolute tolerance when comparing SMatrix and Matriplex results. The two
// paths round differently depending on vectorization, so the bound is chosen
// per build configuration; each choice is announced with a #warning.
#if defined(MPLEX_INTRINSICS) && defined(__AVX512F__)
#  warning "MPLEX_INTRINSICS CMP_EPS = 2e-7 --> 3e-7"
constexpr float CMP_EPS = 3e-7;
#elif defined(MPLEX_INTRINSICS) && defined(__AVX__)
#  warning "MPLEX_INTRINSICS CMP_EPS = 2e-7 --> 5e-7"
constexpr float CMP_EPS = 5e-7;
#elif defined(MPLEX_INTRINSICS)
#  warning "MPLEX_INTRINSICS CMP_EPS = 2e-7"
constexpr float CMP_EPS = 2e-7;
#elif defined(__AVX512F__)
#  warning "NO MPLEX_INTRINSICS CMP_EPS = 4e-7"
constexpr float CMP_EPS = 4e-7;
#else
#  warning "NO MPLEX_INTRINSICS CMP_EPS = 4e-7 --> 5e-7"
constexpr float CMP_EPS = 5e-7;
#endif
0041 
0042 typedef ROOT::Math::SMatrix<float, DIM, DOM>                                     SMatX;
0043 typedef ROOT::Math::SMatrix<float, DOM, DIM>                                     SMatXT;
0044 typedef ROOT::Math::SMatrix<float, DIM, DIM, ROOT::Math::MatRepSym<float, DIM> > SMatS;
0045 
0046 typedef Matriplex::Matriplex   <float, DIM, DOM, N>   MPlexX;
0047 typedef Matriplex::Matriplex   <float, DOM, DIM, N>   MPlexXT;
0048 typedef Matriplex::MatriplexSym<float, DIM,      N>   MPlexS;
0049 
0050 void Multify(const MPlexS& A, const MPlexX& B, MPlexX& C)
0051 {
0052    // C = A * B
0053 
0054    typedef float T;
0055 
0056    const T *a = A.fArray; __assume_aligned(a, 64);
0057    const T *b = B.fArray; __assume_aligned(b, 64);
0058          T *c = C.fArray; __assume_aligned(c, 64);
0059 
0060 #include "multify.ah"
0061 }
0062 
0063 void MultifyTranspose(const MPlexS& A, const MPlexX& B, MPlexXT& C)
0064 {
0065    // C = BT * A;
0066 
0067    typedef float T;
0068 
0069    const T *a = A.fArray; __assume_aligned(a, 64);
0070    const T *b = B.fArray; __assume_aligned(b, 64);
0071          T *c = C.fArray; __assume_aligned(c, 64);
0072 
0073 #include "multify-transpose.ah"
0074 }
0075 
0076 int main()
0077 {
0078   SMatS   a[N];
0079   SMatX   b[N],  c[N];
0080   SMatXT  bt[N], ct[N];
0081 
0082   MPlexS  A;
0083   MPlexX  B, C;
0084   MPlexXT CT;
0085 
0086   std::default_random_engine      gen(0xbeef0133);
0087   std::normal_distribution<float> dis(1.0, 0.05);
0088 
0089   long64 count = 1;
0090 
0091 init:
0092 
0093   for (int m = 0; m < N; ++m)
0094   {
0095     for (int i = 0; i < 3; ++i)
0096     {
0097       for (int j = i; j < 6; ++j)
0098       {
0099         if (j < DIM)  a[m](i,j) = dis(gen);
0100 
0101         b[m](i,j) = dis(gen);
0102       }
0103     }
0104 
0105     // Enforce pattern from GMtest.pl
0106     a[m](1, 1) = 1;
0107     b[m](0, 4) = 0;
0108     b[m](1, 1) = 1;
0109     b[m](1, 3) = 1;
0110     b[m](1, 4) = 0;
0111     b[m](2, 4) = 0;
0112 
0113     A.CopyIn(m, a[m].Array());
0114     B.CopyIn(m, b[m].Array());
0115 
0116     c[m]  = a[m] * b[m];
0117 
0118     bt[m] = ROOT::Math::Transpose(b[m]);
0119     ct[m] = bt[m] * a[m];
0120   }
0121 
0122   Multify(A, B, C);
0123   MultifyTranspose(A, B, CT);
0124 
0125   for (int m = 0; m < N; ++m)
0126   {
0127     bool dump = false;
0128 
0129     for (int j = 0; j < DIM; ++j)
0130     {
0131       for (int k = 0; k < DOM; ++k)
0132       {
0133         // There are occasional diffs up to 4.768372e-07 on host, very very
0134         // rarely on MIC. Apparently this is a rounding difference between AVX
0135         // and normal maths. On MIC it might be usage of FMA?
0136         // The above was for 3x3.
0137         // For 6x6 practically all elements differ by 4.768372e-07, some
0138         // by 9.536743e-07.
0139         if (std::abs(c[m](j,k) - C.At(m, j, k)) > CMP_EPS)
0140         {
0141           dump = true;
0142           printf("MULTIFY   M=%d  %d,%d d=%e (count = %lld)\n", m, j, k, c[m](j,k) - C.At(m, j, k), count);
0143         }
0144       }
0145     }
0146 
0147     if (dump && false)
0148     {
0149       printf("\n");
0150       for (int i = 0; i < DIM; ++i)
0151       {
0152         for (int j = 0; j < DOM; ++j)
0153           printf("%8f ", c[m](i,j));
0154         printf("\n");
0155       }
0156       printf("\n");
0157 
0158       for (int i = 0; i < DIM; ++i)
0159       {
0160         for (int j = 0; j < DOM; ++j)
0161           printf("%8f ", C.At(m, i, j));
0162         printf("\n");
0163       }
0164       printf("\n");
0165     }
0166     if (dump)
0167     {
0168       printf("\n");
0169     }
0170   }
0171 
0172   // Shameless cut-n-paste of above dump for transpose check with minor changes.
0173   // Should make a function, I know ... but ... no time to lose.
0174 
0175   for (int m = 0; m < N; ++m)
0176   {
0177     bool dump = false;
0178 
0179     for (int j = 0; j < DOM; ++j)
0180     {
0181       for (int k = 0; k < DIM; ++k)
0182       {
0183         // There are occasional diffs up to 4.768372e-07 on host, very very
0184         // rarely on MIC. Apparently this is a rounding difference between AVX
0185         // and normal maths. On MIC it might be usage of FMA?
0186         // The above was for 3x3.
0187         // For 6x6 practically all elements differ by 4.768372e-07, some
0188         // by 9.536743e-07.
0189         if (std::abs(ct[m](j,k) - CT.At(m, j, k)) > CMP_EPS)
0190         {
0191           dump = true;
0192           printf("TRANSPOSE M=%d  %d,%d d=%e (count = %lld)\n", m, j, k, ct[m](j,k) - CT.At(m, j, k), count);
0193         }
0194       }
0195     }
0196 
0197     if (dump && false)
0198     {
0199       printf("\n");
0200       for (int i = 0; i < DOM; ++i)
0201       {
0202         for (int j = 0; j < DIM; ++j)
0203           printf("%8f ", ct[m](i,j));
0204         printf("\n");
0205       }
0206       printf("\n");
0207 
0208       for (int i = 0; i < DIM; ++i)
0209       {
0210         for (int j = 0; j < DOM; ++j)
0211           printf("%8f ", CT.At(m, i, j));
0212         printf("\n");
0213       }
0214       printf("\n");
0215     }
0216     if (dump)
0217     {
0218       printf("\n");
0219     }
0220   }
0221 
0222 
0223   ++count;
0224   goto init;
0225 
0226   return 0;
0227 }