File indexing completed on 2024-04-06 12:28:40
0001 #ifndef RecoTracker_PixelVertexFinding_clusterTracksIterativeAlpaka_h
0002 #define RecoTracker_PixelVertexFinding_clusterTracksIterativeAlpaka_h
0003
0004 #include <algorithm>
0005 #include <cmath>
0006 #include <cstdint>
0007
0008 #include <alpaka/alpaka.hpp>
0009
0010 #include "DataFormats/VertexSoA/interface/ZVertexDefinitions.h"
0011 #include "HeterogeneousCore/AlpakaInterface/interface/HistoContainer.h"
0012 #include "HeterogeneousCore/AlpakaInterface/interface/config.h"
0013 #include "HeterogeneousCore/AlpakaInterface/interface/workdivision.h"
0014 #include "RecoTracker/PixelVertexFinding/interface/PixelVertexWorkSpaceLayout.h"
0015
0016 #include "vertexFinder.h"
0017
0018 namespace ALPAKA_ACCELERATOR_NAMESPACE {
0019 namespace vertexFinder {
0020
0021
0022
0023 class ClusterTracksIterative {
0024 public:
0025 template <typename TAcc>
0026 ALPAKA_FN_ACC void operator()(const TAcc& acc,
0027 VtxSoAView pdata,
0028 WsSoAView pws,
0029 int minT,
0030 float eps,
0031 float errmax,
0032 float chi2max
0033 ) const {
0034 constexpr bool verbose = false;
0035 const uint32_t threadIdxLocal(alpaka::getIdx<alpaka::Block, alpaka::Threads>(acc)[0u]);
0036 if constexpr (verbose) {
0037 if (cms::alpakatools::once_per_block(acc))
0038 printf("params %d %f %f %f\n", minT, eps, errmax, chi2max);
0039 }
0040 auto er2mx = errmax * errmax;
0041
0042 auto& __restrict__ data = pdata;
0043 auto& __restrict__ ws = pws;
0044 auto nt = ws.ntrks();
0045 float const* __restrict__ zt = ws.zt();
0046 float const* __restrict__ ezt2 = ws.ezt2();
0047
0048 uint32_t& nvFinal = data.nvFinal();
0049 uint32_t& nvIntermediate = ws.nvIntermediate();
0050
0051 uint8_t* __restrict__ izt = ws.izt();
0052 int32_t* __restrict__ nn = data.ndof();
0053 int32_t* __restrict__ iv = ws.iv();
0054
0055 ALPAKA_ASSERT_ACC(zt);
0056 ALPAKA_ASSERT_ACC(nn);
0057 ALPAKA_ASSERT_ACC(iv);
0058 ALPAKA_ASSERT_ACC(ezt2);
0059
0060 using Hist = cms::alpakatools::HistoContainer<uint8_t, 256, 16000, 8, uint16_t>;
0061 auto& hist = alpaka::declareSharedVar<Hist, __COUNTER__>(acc);
0062 auto& hws = alpaka::declareSharedVar<Hist::Counter[32], __COUNTER__>(acc);
0063
0064 for (auto j : cms::alpakatools::uniform_elements(acc, Hist::totbins())) {
0065 hist.off[j] = 0;
0066 }
0067 alpaka::syncBlockThreads(acc);
0068
0069 if constexpr (verbose) {
0070 if (cms::alpakatools::once_per_block(acc))
0071 printf("booked hist with %d bins, size %d for %d tracks\n", hist.nbins(), hist.capacity(), nt);
0072 }
0073
0074 ALPAKA_ASSERT_ACC(static_cast<int>(nt) <= hist.capacity());
0075
0076
0077 for (auto i : cms::alpakatools::uniform_elements(acc, nt)) {
0078 ALPAKA_ASSERT_ACC(i < ::zVertex::MAXTRACKS);
0079 int iz = int(zt[i] * 10.);
0080 iz = std::clamp(iz, INT8_MIN, INT8_MAX);
0081 izt[i] = iz - INT8_MIN;
0082 ALPAKA_ASSERT_ACC(iz - INT8_MIN >= 0);
0083 ALPAKA_ASSERT_ACC(iz - INT8_MIN < 256);
0084 hist.count(acc, izt[i]);
0085 iv[i] = i;
0086 nn[i] = 0;
0087 }
0088 alpaka::syncBlockThreads(acc);
0089
0090 if (threadIdxLocal < 32)
0091 hws[threadIdxLocal] = 0;
0092 alpaka::syncBlockThreads(acc);
0093
0094 hist.finalize(acc, hws);
0095 alpaka::syncBlockThreads(acc);
0096
0097 ALPAKA_ASSERT_ACC(hist.size() == nt);
0098 for (auto i : cms::alpakatools::uniform_elements(acc, nt)) {
0099 hist.fill(acc, izt[i], uint16_t(i));
0100 }
0101 alpaka::syncBlockThreads(acc);
0102
0103
0104 for (auto i : cms::alpakatools::uniform_elements(acc, nt)) {
0105 if (ezt2[i] > er2mx)
0106 continue;
0107 auto loop = [&](uint32_t j) {
0108 if (i == j)
0109 return;
0110 auto dist = std::abs(zt[i] - zt[j]);
0111 if (dist > eps)
0112 return;
0113 if (dist * dist > chi2max * (ezt2[i] + ezt2[j]))
0114 return;
0115 nn[i]++;
0116 };
0117
0118 cms::alpakatools::forEachInBins(hist, izt[i], 1, loop);
0119 }
0120
0121 auto& nloops = alpaka::declareSharedVar<int, __COUNTER__>(acc);
0122 nloops = 0;
0123
0124 alpaka::syncBlockThreads(acc);
0125
0126
0127 bool more = true;
0128 while (alpaka::syncBlockThreadsPredicate<alpaka::BlockOr>(acc, more)) {
0129 if (1 == nloops % 2) {
0130 for (auto i : cms::alpakatools::uniform_elements(acc, nt)) {
0131 auto m = iv[i];
0132 while (m != iv[m])
0133 m = iv[m];
0134 iv[i] = m;
0135 }
0136 } else {
0137 more = false;
0138 for (auto k : cms::alpakatools::uniform_elements(acc, hist.size())) {
0139 auto p = hist.begin() + k;
0140 auto i = (*p);
0141 auto be = std::min(Hist::bin(izt[i]) + 1, int(hist.nbins() - 1));
0142 if (nn[i] < minT)
0143 continue;
0144 auto loop = [&](uint32_t j) {
0145 ALPAKA_ASSERT_ACC(i != j);
0146 if (nn[j] < minT)
0147 return;
0148 auto dist = std::abs(zt[i] - zt[j]);
0149 if (dist > eps)
0150 return;
0151 if (dist * dist > chi2max * (ezt2[i] + ezt2[j]))
0152 return;
0153 auto old = alpaka::atomicMin(acc, &iv[j], iv[i], alpaka::hierarchy::Blocks{});
0154 if (old != iv[i]) {
0155
0156 more = true;
0157 }
0158 alpaka::atomicMin(acc, &iv[i], old, alpaka::hierarchy::Blocks{});
0159 };
0160 ++p;
0161 for (; p < hist.end(be); ++p)
0162 loop(*p);
0163 }
0164 }
0165 if (threadIdxLocal == 0)
0166 ++nloops;
0167 }
0168
0169
0170 for (auto i : cms::alpakatools::uniform_elements(acc, nt)) {
0171
0172 if (nn[i] >= minT)
0173 continue;
0174 float mdist = eps;
0175 auto loop = [&](int j) {
0176 if (nn[j] < minT)
0177 return;
0178 auto dist = std::abs(zt[i] - zt[j]);
0179 if (dist > mdist)
0180 return;
0181 if (dist * dist > chi2max * (ezt2[i] + ezt2[j]))
0182 return;
0183 mdist = dist;
0184 iv[i] = iv[j];
0185 };
0186 cms::alpakatools::forEachInBins(hist, izt[i], 1, loop);
0187 }
0188
0189 auto& foundClusters = alpaka::declareSharedVar<unsigned int, __COUNTER__>(acc);
0190 foundClusters = 0;
0191 alpaka::syncBlockThreads(acc);
0192
0193
0194
0195 for (auto i : cms::alpakatools::uniform_elements(acc, nt)) {
0196 if (iv[i] == int(i)) {
0197 if (nn[i] >= minT) {
0198 auto old = alpaka::atomicInc(acc, &foundClusters, 0xffffffff, alpaka::hierarchy::Threads{});
0199 iv[i] = -(old + 1);
0200 } else {
0201 iv[i] = -9998;
0202 }
0203 }
0204 }
0205 alpaka::syncBlockThreads(acc);
0206
0207 ALPAKA_ASSERT_ACC(foundClusters < ::zVertex::MAXVTX);
0208
0209
0210 for (auto i : cms::alpakatools::uniform_elements(acc, nt)) {
0211 if (iv[i] >= 0) {
0212
0213 iv[i] = iv[iv[i]];
0214 }
0215 }
0216 alpaka::syncBlockThreads(acc);
0217
0218
0219 for (auto i : cms::alpakatools::uniform_elements(acc, nt)) {
0220 iv[i] = -iv[i] - 1;
0221 }
0222
0223 nvIntermediate = nvFinal = foundClusters;
0224
0225 if constexpr (verbose) {
0226 if (cms::alpakatools::once_per_block(acc))
0227 printf("found %d proto vertices\n", foundClusters);
0228 }
0229 }
0230 };
0231 }
0232 }
0233 #endif