// File indexing completed on 2024-09-13 22:52:46
0001 #ifndef RecoVertex_PixelVertexFinding_plugins_alpaka_sortByPt2_h
0002 #define RecoVertex_PixelVertexFinding_plugins_alpaka_sortByPt2_h
0003
0004 #include <algorithm>
0005 #include <array>
0006 #include <cmath>
0007 #include <cstdint>
0008
0009 #include <alpaka/alpaka.hpp>
0010
0011 #include "DataFormats/VertexSoA/interface/ZVertexSoA.h"
0012 #include "HeterogeneousCore/AlpakaInterface/interface/HistoContainer.h"
0013 #include "HeterogeneousCore/AlpakaInterface/interface/config.h"
0014 #include "HeterogeneousCore/AlpakaInterface/interface/radixSort.h"
0015 #include "HeterogeneousCore/AlpakaInterface/interface/workdivision.h"
0016 #include "RecoVertex/PixelVertexFinding/interface/PixelVertexWorkSpaceLayout.h"
0017
0018 #include "vertexFinder.h"
0019
0020 namespace ALPAKA_ACCELERATOR_NAMESPACE::vertexFinder {
0021
0022 using VtxSoAView = ::reco::ZVertexSoAView;
0023 using TrkSoAView = ::reco::ZVertexTracksSoAView;
0024 using WsSoAView = ::vertexFinder::PixelVertexWorkSpaceSoAView;
0025
0026 ALPAKA_FN_ACC ALPAKA_FN_INLINE void sortByPt2(Acc1D const& acc, VtxSoAView& data, TrkSoAView& trkdata, WsSoAView& ws) {
0027 auto nt = ws.ntrks();
0028 float const* __restrict__ ptt2 = ws.ptt2();
0029 uint32_t const& nvFinal = data.nvFinal();
0030
0031 int32_t const* __restrict__ iv = ws.iv();
0032 float* __restrict__ ptv2 = data.ptv2();
0033 uint16_t* __restrict__ sortInd = data.sortInd();
0034
0035 if (nvFinal < 1)
0036 return;
0037
0038
0039 for (auto i : cms::alpakatools::uniform_elements(acc, nt)) {
0040 trkdata.idv()[ws.itrk()[i]] = iv[i];
0041 };
0042
0043
0044 for (auto i : cms::alpakatools::uniform_elements(acc, nvFinal)) {
0045 ptv2[i] = 0;
0046 };
0047 alpaka::syncBlockThreads(acc);
0048
0049 for (auto i : cms::alpakatools::uniform_elements(acc, nt)) {
0050 if (iv[i] <= 9990) {
0051 alpaka::atomicAdd(acc, &ptv2[iv[i]], ptt2[i], alpaka::hierarchy::Blocks{});
0052 }
0053 };
0054 alpaka::syncBlockThreads(acc);
0055
0056 const uint32_t threadIdxLocal(alpaka::getIdx<alpaka::Block, alpaka::Threads>(acc)[0u]);
0057 if (1 == nvFinal) {
0058 if (threadIdxLocal == 0)
0059 sortInd[0] = 0;
0060 return;
0061 }
0062
0063 if constexpr (not cms::alpakatools::requires_single_thread_per_block_v<Acc1D>) {
0064 auto& sws = alpaka::declareSharedVar<uint16_t[1024], __COUNTER__>(acc);
0065
0066 cms::alpakatools::radixSort<Acc1D, float, 2>(acc, ptv2, sortInd, sws, nvFinal);
0067 } else {
0068 for (uint16_t i = 0; i < nvFinal; ++i)
0069 sortInd[i] = i;
0070 std::sort(sortInd, sortInd + nvFinal, [&](auto i, auto j) { return ptv2[i] < ptv2[j]; });
0071 }
0072 }
0073
0074 class SortByPt2Kernel {
0075 public:
0076 ALPAKA_FN_ACC void operator()(Acc1D const& acc, VtxSoAView pdata, TrkSoAView ptrkdata, WsSoAView pws) const {
0077 sortByPt2(acc, pdata, ptrkdata, pws);
0078 }
0079 };
0080
0081 }
0082
0083 #endif