Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:04:43

0001 // nvcc -O3 CholeskyDecomp_t.cu --expt-relaxed-constexpr -gencode arch=compute_61,code=sm_61 --compiler-options="-Ofast -march=native"
0002 // add -DDOPROF to run  nvprof --metrics all
0003 
0004 #include <algorithm>
0005 #include <cassert>
0006 #include <chrono>
0007 #include <iomanip>
0008 #include <iostream>
0009 #include <limits>
0010 #include <memory>
0011 #include <random>
0012 
0013 #include <Eigen/Core>
0014 #include <Eigen/Eigenvalues>
0015 
0016 #include "DataFormats/Math/interface/choleskyInversion.h"
0017 
// Pitch, in doubles, between successive elements of one matrix in the SOA
// layout used by MapMX; go() also sizes its matrix storage with this value.
constexpr int stride() { return 1024 * 5; }
0019 template <int DIM>
0020 using MXN = Eigen::Matrix<double, DIM, DIM>;
0021 template <int DIM>
0022 using MapMX = Eigen::Map<MXN<DIM>, 0, Eigen::Stride<DIM * stride(), stride()> >;
0023 
// Fill the square matrix m with a deterministic symmetric test matrix.
//
// Diagonal: m(i, i) is drawn uniformly from [0, maxVal), where maxVal is an
// integer ramp from 1 (i = 0) to 10001 (i = n - 1). The diagonal therefore
// spans about four orders of magnitude.
//
// Off-diagonal: m(i, j) = m(j, i) is drawn from [0, 0.3 * sqrt(m(i, i) * m(j, j))).
// The scaled (correlation-like) matrix then has off-diagonal entries below 0.3.
// By Gershgorin this guarantees positive definiteness only while
// 0.3 * (n - 1) < 1, i.e. n <= 4. For larger n it is expected, not proven.
//
// The engine is default-seeded on every call, so every call with the same
// size n produces exactly the same matrix.
//
// M only needs operator()(i, j) returning a reference and cols(). Eigen
// fixed-size matrices, dynamic-size matrices and Maps all qualify.
template <class M>
void genMatrix(M& m) {
  using T = typename std::remove_reference<decltype(m(0, 0))>::type;
  // cols() works for dynamic-size matrices too, where the compile-time
  // column count would be Eigen::Dynamic (-1).
  const int n = static_cast<int>(m.cols());
  std::mt19937 eng;  // default seed on purpose: reproducible input
  std::uniform_real_distribution<T> rgen(0., 1.);

  // generate the diagonal elements first
  for (int i = 0; i < n; ++i) {
    // integer ramp 1 .. 10001; a 1x1 matrix has no ramp (avoids 0 / 0)
    const double maxVal = (n > 1 ? i * 10000 / (n - 1) : 0) + 1;
    m(i, i) = maxVal * rgen(eng);
  }
  // then the symmetric off-diagonal part, bounded relative to the diagonal
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < i; ++j) {
      const double v = 0.3 * std::sqrt(m(i, i) * m(j, j));
      m(i, j) = v * rgen(eng);
      m(j, i) = m(i, j);
    }
  }
}
0046 
0047 template <int N>
0048 void go(bool soa) {
0049   constexpr unsigned int DIM = N;
0050   using MX = MXN<DIM>;
0051   std::cout << "testing Matrix of dimension " << DIM << " size " << sizeof(MX) << " in " << (soa ? "SOA" : "AOS")
0052             << " mode" << std::endl;
0053 
0054   auto start = std::chrono::high_resolution_clock::now();
0055   auto delta = start - start;
0056 
0057   constexpr unsigned int SIZE = 4 * 1024;
0058 
0059   alignas(128) MX mm[stride()];  // just storage in case of SOA
0060   double* __restrict__ p = (double*)__builtin_assume_aligned(mm, 128);
0061 
0062   if (soa) {
0063     for (unsigned int i = 0; i < SIZE; ++i) {
0064       MapMX<N> m(p + i);
0065       genMatrix(m);
0066     }
0067   } else {
0068     for (auto& m : mm)
0069       genMatrix(m);
0070   }
0071 
0072   std::cout << mm[SIZE / 2](1, 1) << std::endl;
0073 
0074   if (soa)
0075     for (unsigned int i = 0; i < SIZE; ++i) {
0076       MapMX<N> m(p + i);
0077       math::cholesky::invert(m, m);
0078       math::cholesky::invert(m, m);
0079     }
0080   else
0081     for (auto& m : mm) {
0082       math::cholesky::invert(m, m);
0083       math::cholesky::invert(m, m);
0084     }
0085 
0086   std::cout << mm[SIZE / 2](1, 1) << std::endl;
0087 
0088   constexpr int NKK =
0089 #ifdef DOPROF
0090       2;
0091 #else
0092       1000;
0093 #endif
0094   for (int kk = 0; kk < NKK; ++kk) {
0095     delta -= (std::chrono::high_resolution_clock::now() - start);
0096     if (soa)
0097 #ifdef USE_VECTORIZATION_PRAGMA
0098 #pragma GCC ivdep
0099 #ifdef __clang__
0100 #pragma clang loop vectorize(enable) interleave(enable)
0101 #endif
0102 #endif
0103       for (unsigned int i = 0; i < SIZE; ++i) {
0104         MapMX<N> m(p + i);
0105         math::cholesky::invert(m, m);
0106       }
0107     else
0108 #ifdef USE_VECTORIZATION_PRAGMA
0109 #pragma GCC ivdep
0110 #ifdef __clang__
0111 #pragma clang loop vectorize(enable) interleave(enable)
0112 #endif
0113 #endif
0114       for (auto& m : mm) {
0115         math::cholesky::invert(m, m);
0116       }
0117 
0118     delta += (std::chrono::high_resolution_clock::now() - start);
0119   }
0120 
0121   std::cout << mm[SIZE / 2](1, 1) << std::endl;
0122 
0123   double DNNK = NKK;
0124   std::cout << "x86 computation took " << std::chrono::duration_cast<std::chrono::milliseconds>(delta).count() / DNNK
0125             << ' ' << " ms" << std::endl;
0126 }
0127 
0128 int main() {
0129   go<2>(false);
0130   go<3>(false);
0131   go<4>(false);
0132   go<5>(false);
0133   go<6>(false);
0134 
0135   go<2>(true);
0136   go<3>(true);
0137   go<4>(true);
0138   go<5>(true);
0139   go<6>(true);
0140 
0141   go<10>(false);
0142   go<10>(true);
0143   return 0;
0144 }