File indexing completed on 2025-02-14 03:16:57
0001 #ifndef RecoVertex_PixelVertexFinding_plugins_alpaka_sortByPt2_h
0002 #define RecoVertex_PixelVertexFinding_plugins_alpaka_sortByPt2_h
0003
0004 #include <algorithm>
0005 #include <array>
0006 #include <cmath>
0007 #include <cstdint>
0008
0009 #include <alpaka/alpaka.hpp>
0010
0011 #include "DataFormats/VertexSoA/interface/ZVertexSoA.h"
0012 #include "HeterogeneousCore/AlpakaInterface/interface/HistoContainer.h"
0013 #include "HeterogeneousCore/AlpakaInterface/interface/config.h"
0014 #include "HeterogeneousCore/AlpakaInterface/interface/radixSort.h"
0015 #include "HeterogeneousCore/AlpakaInterface/interface/warpsize.h"
0016 #include "HeterogeneousCore/AlpakaInterface/interface/workdivision.h"
0017 #include "RecoVertex/PixelVertexFinding/interface/PixelVertexWorkSpaceLayout.h"
0018
0019 #include "vertexFinder.h"
0020
0021 namespace ALPAKA_ACCELERATOR_NAMESPACE::vertexFinder {
0022
0023 using VtxSoAView = ::reco::ZVertexSoAView;
0024 using TrkSoAView = ::reco::ZVertexTracksSoAView;
0025 using WsSoAView = ::vertexFinder::PixelVertexWorkSpaceSoAView;
0026
0027 ALPAKA_FN_ACC ALPAKA_FN_INLINE void sortByPt2(Acc1D const& acc, VtxSoAView& data, TrkSoAView& trkdata, WsSoAView& ws) {
0028 auto nt = ws.ntrks();
0029 float const* __restrict__ ptt2 = ws.ptt2();
0030 uint32_t const& nvFinal = data.nvFinal();
0031
0032 int32_t const* __restrict__ iv = ws.iv();
0033 float* __restrict__ ptv2 = data.ptv2();
0034 uint16_t* __restrict__ sortInd = data.sortInd();
0035
0036 if (nvFinal < 1)
0037 return;
0038
0039
0040 for (auto i : cms::alpakatools::uniform_elements(acc, nt)) {
0041 trkdata.idv()[ws.itrk()[i]] = iv[i];
0042 };
0043
0044
0045 for (auto i : cms::alpakatools::uniform_elements(acc, nvFinal)) {
0046 ptv2[i] = 0;
0047 };
0048 alpaka::syncBlockThreads(acc);
0049
0050 for (auto i : cms::alpakatools::uniform_elements(acc, nt)) {
0051 if (iv[i] <= 9990) {
0052 alpaka::atomicAdd(acc, &ptv2[iv[i]], ptt2[i], alpaka::hierarchy::Blocks{});
0053 }
0054 };
0055 alpaka::syncBlockThreads(acc);
0056
0057 const uint32_t threadIdxLocal(alpaka::getIdx<alpaka::Block, alpaka::Threads>(acc)[0u]);
0058 if (1 == nvFinal) {
0059 if (threadIdxLocal == 0)
0060 sortInd[0] = 0;
0061 return;
0062 }
0063
0064 if constexpr (not cms::alpakatools::requires_single_thread_per_block_v<Acc1D>) {
0065 constexpr int warpSize = cms::alpakatools::warpSize;
0066 auto& sws = alpaka::declareSharedVar<uint16_t[warpSize * warpSize], __COUNTER__>(acc);
0067
0068 cms::alpakatools::radixSort<Acc1D, float, 2>(acc, ptv2, sortInd, sws, nvFinal);
0069 } else {
0070 for (uint16_t i = 0; i < nvFinal; ++i)
0071 sortInd[i] = i;
0072 std::sort(sortInd, sortInd + nvFinal, [&](auto i, auto j) { return ptv2[i] < ptv2[j]; });
0073 }
0074 }
0075
0076 class SortByPt2Kernel {
0077 public:
0078 ALPAKA_FN_ACC void operator()(Acc1D const& acc, VtxSoAView pdata, TrkSoAView ptrkdata, WsSoAView pws) const {
0079 sortByPt2(acc, pdata, ptrkdata, pws);
0080 }
0081 };
0082
0083 }
0084
0085 #endif