Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:28:40

0001 #ifndef RecoTracker_PixelVertexFinding_plugins_alpaka_splitVertices_h
0002 #define RecoTracker_PixelVertexFinding_plugins_alpaka_splitVertices_h
0003 
0004 #include <algorithm>
0005 #include <cmath>
0006 #include <cstdint>
0007 
0008 #include <alpaka/alpaka.hpp>
0009 
0010 #include "HeterogeneousCore/AlpakaInterface/interface/HistoContainer.h"
0011 #include "HeterogeneousCore/AlpakaInterface/interface/config.h"
0012 #include "HeterogeneousCore/AlpakaInterface/interface/workdivision.h"
0013 
0014 #include "vertexFinder.h"
0015 
0016 namespace ALPAKA_ACCELERATOR_NAMESPACE::vertexFinder {
0017 
0018   using VtxSoAView = ::reco::ZVertexSoAView;
0019   using WsSoAView = ::vertexFinder::PixelVertexWorkSpaceSoAView;
0020   template <typename TAcc>
0021   ALPAKA_FN_ACC ALPAKA_FN_INLINE __attribute__((always_inline)) void splitVertices(const TAcc& acc,
0022                                                                                    VtxSoAView& pdata,
0023                                                                                    WsSoAView& pws,
0024                                                                                    float maxChi2) {
0025     constexpr bool verbose = false;  // in principle the compiler should optmize out if false
0026     const uint32_t threadIdxLocal(alpaka::getIdx<alpaka::Block, alpaka::Threads>(acc)[0u]);
0027 
0028     auto& __restrict__ data = pdata;
0029     auto& __restrict__ ws = pws;
0030     auto nt = ws.ntrks();
0031     float const* __restrict__ zt = ws.zt();
0032     float const* __restrict__ ezt2 = ws.ezt2();
0033     float* __restrict__ zv = data.zv();
0034     float* __restrict__ wv = data.wv();
0035     float const* __restrict__ chi2 = data.chi2();
0036     uint32_t& nvFinal = data.nvFinal();
0037 
0038     int32_t const* __restrict__ nn = data.ndof();
0039     int32_t* __restrict__ iv = ws.iv();
0040 
0041     ALPAKA_ASSERT_ACC(zt);
0042     ALPAKA_ASSERT_ACC(wv);
0043     ALPAKA_ASSERT_ACC(chi2);
0044     ALPAKA_ASSERT_ACC(nn);
0045 
0046     constexpr uint32_t MAXTK = 512;
0047 
0048     auto& it = alpaka::declareSharedVar<uint32_t[MAXTK], __COUNTER__>(acc);   // track index
0049     auto& zz = alpaka::declareSharedVar<float[MAXTK], __COUNTER__>(acc);      // z pos
0050     auto& newV = alpaka::declareSharedVar<uint8_t[MAXTK], __COUNTER__>(acc);  // 0 or 1
0051     auto& ww = alpaka::declareSharedVar<float[MAXTK], __COUNTER__>(acc);      // z weight
0052     auto& nq = alpaka::declareSharedVar<uint32_t, __COUNTER__>(acc);          // number of track for this vertex
0053 
0054     const uint32_t blockIdx(alpaka::getIdx<alpaka::Grid, alpaka::Blocks>(acc)[0u]);
0055     const uint32_t gridDimension(alpaka::getWorkDiv<alpaka::Grid, alpaka::Blocks>(acc)[0u]);
0056 
0057     // one vertex per block
0058     for (auto kv = blockIdx; kv < nvFinal; kv += gridDimension) {
0059       if (nn[kv] < 4)
0060         continue;
0061       if (chi2[kv] < maxChi2 * float(nn[kv]))
0062         continue;
0063 
0064       ALPAKA_ASSERT_ACC(nn[kv] < int32_t(MAXTK));
0065 
0066       if ((uint32_t)nn[kv] >= MAXTK)
0067         continue;  // too bad FIXME
0068 
0069       nq = 0u;
0070       alpaka::syncBlockThreads(acc);
0071 
0072       // copy to local
0073       for (auto k : cms::alpakatools::independent_group_elements(acc, nt)) {
0074         if (iv[k] == int(kv)) {
0075           auto old = alpaka::atomicInc(acc, &nq, MAXTK, alpaka::hierarchy::Threads{});
0076           zz[old] = zt[k] - zv[kv];
0077           newV[old] = zz[old] < 0 ? 0 : 1;
0078           ww[old] = 1.f / ezt2[k];
0079           it[old] = k;
0080         }
0081       }
0082 
0083       // the new vertices
0084       auto& znew = alpaka::declareSharedVar<float[2], __COUNTER__>(acc);
0085       auto& wnew = alpaka::declareSharedVar<float[2], __COUNTER__>(acc);
0086       alpaka::syncBlockThreads(acc);
0087 
0088       ALPAKA_ASSERT_ACC(int(nq) == nn[kv] + 1);
0089 
0090       int maxiter = 20;
0091       // kt-min....
0092       bool more = true;
0093       while (alpaka::syncBlockThreadsPredicate<alpaka::BlockOr>(acc, more)) {
0094         more = false;
0095         if (0 == threadIdxLocal) {
0096           znew[0] = 0;
0097           znew[1] = 0;
0098           wnew[0] = 0;
0099           wnew[1] = 0;
0100         }
0101         alpaka::syncBlockThreads(acc);
0102 
0103         for (auto k : cms::alpakatools::uniform_elements(acc, nq)) {
0104           auto i = newV[k];
0105           alpaka::atomicAdd(acc, &znew[i], zz[k] * ww[k], alpaka::hierarchy::Threads{});
0106           alpaka::atomicAdd(acc, &wnew[i], ww[k], alpaka::hierarchy::Threads{});
0107         }
0108         alpaka::syncBlockThreads(acc);
0109 
0110         if (0 == threadIdxLocal) {
0111           znew[0] /= wnew[0];
0112           znew[1] /= wnew[1];
0113         }
0114         alpaka::syncBlockThreads(acc);
0115 
0116         for (auto k : cms::alpakatools::uniform_elements(acc, nq)) {
0117           auto d0 = fabs(zz[k] - znew[0]);
0118           auto d1 = fabs(zz[k] - znew[1]);
0119           auto newer = d0 < d1 ? 0 : 1;
0120           more |= newer != newV[k];
0121           newV[k] = newer;
0122         }
0123         --maxiter;
0124         if (maxiter <= 0)
0125           more = false;
0126       }
0127 
0128       // avoid empty vertices
0129       if (0 == wnew[0] || 0 == wnew[1])
0130         continue;
0131 
0132       // quality cut
0133       auto dist2 = (znew[0] - znew[1]) * (znew[0] - znew[1]);
0134 
0135       auto chi2Dist = dist2 / (1.f / wnew[0] + 1.f / wnew[1]);
0136 
0137       if (verbose && 0 == threadIdxLocal)
0138         printf("inter %d %f %f\n", 20 - maxiter, chi2Dist, dist2 * wv[kv]);
0139 
0140       if (chi2Dist < 4)
0141         continue;
0142 
0143       // get a new global vertex
0144       auto& igv = alpaka::declareSharedVar<uint32_t, __COUNTER__>(acc);
0145       if (0 == threadIdxLocal)
0146         igv = alpaka::atomicAdd(acc, &ws.nvIntermediate(), 1u, alpaka::hierarchy::Blocks{});
0147       alpaka::syncBlockThreads(acc);
0148       for (auto k : cms::alpakatools::uniform_elements(acc, nq)) {
0149         if (1 == newV[k])
0150           iv[it[k]] = igv;
0151       }
0152 
0153     }  // loop on vertices
0154   }
0155 
0156   class SplitVerticesKernel {
0157   public:
0158     template <typename TAcc>
0159     ALPAKA_FN_ACC void operator()(const TAcc& acc, VtxSoAView pdata, WsSoAView pws, float maxChi2) const {
0160       splitVertices(acc, pdata, pws, maxChi2);
0161     }
0162   };
0163 
0164 }  // namespace ALPAKA_ACCELERATOR_NAMESPACE::vertexFinder
0165 
0166 #endif  // RecoTracker_PixelVertexFinding_plugins_alpaka_splitVertices_h