File indexing completed on 2024-09-13 22:52:46
0001 #ifndef RecoVertex_PixelVertexFinding_plugins_alpaka_fitVertices_h
0002 #define RecoVertex_PixelVertexFinding_plugins_alpaka_fitVertices_h
0003
0004 #include <algorithm>
0005 #include <cmath>
0006 #include <cstdint>
0007
0008 #include <alpaka/alpaka.hpp>
0009
0010 #include "HeterogeneousCore/AlpakaInterface/interface/HistoContainer.h"
0011 #include "HeterogeneousCore/AlpakaInterface/interface/config.h"
0012 #include "HeterogeneousCore/AlpakaInterface/interface/workdivision.h"
0013
0014 #include "vertexFinder.h"
0015
0016 namespace ALPAKA_ACCELERATOR_NAMESPACE::vertexFinder {
0017
0018 ALPAKA_FN_ACC ALPAKA_FN_INLINE __attribute__((always_inline)) void fitVertices(Acc1D const& acc,
0019 VtxSoAView& pdata,
0020 TrkSoAView& ptrkdata,
0021 WsSoAView& pws,
0022 float chi2Max
0023 ) {
0024 constexpr bool verbose = false;
0025
0026 auto& __restrict__ data = pdata;
0027 auto& __restrict__ trkdata = ptrkdata;
0028 auto& __restrict__ ws = pws;
0029 auto nt = ws.ntrks();
0030 float const* __restrict__ zt = ws.zt();
0031 float const* __restrict__ ezt2 = ws.ezt2();
0032 float* __restrict__ zv = data.zv();
0033 float* __restrict__ wv = data.wv();
0034 float* __restrict__ chi2 = data.chi2();
0035 uint32_t& nvFinal = data.nvFinal();
0036 uint32_t& nvIntermediate = ws.nvIntermediate();
0037
0038 int32_t* __restrict__ nn = trkdata.ndof();
0039 int32_t* __restrict__ iv = ws.iv();
0040
0041 ALPAKA_ASSERT_ACC(nvFinal <= nvIntermediate);
0042 nvFinal = nvIntermediate;
0043 auto foundClusters = nvFinal;
0044
0045
0046 for (auto i : cms::alpakatools::uniform_elements(acc, foundClusters)) {
0047 zv[i] = 0;
0048 wv[i] = 0;
0049 chi2[i] = 0;
0050 }
0051
0052
0053 auto& noise = alpaka::declareSharedVar<int, __COUNTER__>(acc);
0054
0055 if constexpr (verbose) {
0056 if (cms::alpakatools::once_per_block(acc))
0057 noise = 0;
0058 }
0059 alpaka::syncBlockThreads(acc);
0060
0061
0062 for (auto i : cms::alpakatools::uniform_elements(acc, nt)) {
0063 if (iv[i] > 9990) {
0064 if constexpr (verbose)
0065 alpaka::atomicAdd(acc, &noise, 1, alpaka::hierarchy::Threads{});
0066 continue;
0067 }
0068 ALPAKA_ASSERT_ACC(iv[i] >= 0);
0069 ALPAKA_ASSERT_ACC(iv[i] < int(foundClusters));
0070 auto w = 1.f / ezt2[i];
0071 alpaka::atomicAdd(acc, &zv[iv[i]], zt[i] * w, alpaka::hierarchy::Threads{});
0072 alpaka::atomicAdd(acc, &wv[iv[i]], w, alpaka::hierarchy::Threads{});
0073 }
0074
0075 alpaka::syncBlockThreads(acc);
0076
0077 for (auto i : cms::alpakatools::uniform_elements(acc, foundClusters)) {
0078 bool const wv_cond = (wv[i] > 0.f);
0079 if (not wv_cond) {
0080 printf("ERROR: wv[%d] (%f) > 0.f failed\n", i, wv[i]);
0081
0082 for (auto trk_i = 0u; trk_i < nt; ++trk_i) {
0083 if (iv[trk_i] != int(i)) {
0084 continue;
0085 }
0086 printf(" iv[%d]=%d zt[%d]=%f ezt2[%d]=%f\n", trk_i, iv[trk_i], trk_i, zt[trk_i], trk_i, ezt2[trk_i]);
0087 }
0088 ALPAKA_ASSERT_ACC(false);
0089 }
0090
0091 zv[i] /= wv[i];
0092 nn[i] = -1;
0093 }
0094 alpaka::syncBlockThreads(acc);
0095
0096
0097 for (auto i : cms::alpakatools::uniform_elements(acc, nt)) {
0098 if (iv[i] > 9990)
0099 continue;
0100
0101 auto c2 = zv[iv[i]] - zt[i];
0102 c2 *= c2 / ezt2[i];
0103 if (c2 > chi2Max) {
0104 iv[i] = 9999;
0105 continue;
0106 }
0107 alpaka::atomicAdd(acc, &chi2[iv[i]], c2, alpaka::hierarchy::Blocks{});
0108 alpaka::atomicAdd(acc, &nn[iv[i]], 1, alpaka::hierarchy::Blocks{});
0109 }
0110 alpaka::syncBlockThreads(acc);
0111
0112 for (auto i : cms::alpakatools::uniform_elements(acc, foundClusters)) {
0113 if (nn[i] > 0) {
0114 wv[i] *= float(nn[i]) / chi2[i];
0115 }
0116 }
0117 if constexpr (verbose) {
0118 if (cms::alpakatools::once_per_block(acc)) {
0119 printf("found %d proto clusters ", foundClusters);
0120 printf("and %d noise\n", noise);
0121 }
0122 }
0123 }
0124
0125 class FitVerticesKernel {
0126 public:
0127 ALPAKA_FN_ACC void operator()(Acc1D const& acc,
0128 VtxSoAView pdata,
0129 TrkSoAView ptrkdata,
0130 WsSoAView pws,
0131 float chi2Max
0132 ) const {
0133 fitVertices(acc, pdata, ptrkdata, pws, chi2Max);
0134 }
0135 };
0136
0137 }
0138
0139 #endif