Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-09-13 22:52:46

0001 #ifndef RecoVertex_PixelVertexFinding_plugins_alpaka_sortByPt2_h
0002 #define RecoVertex_PixelVertexFinding_plugins_alpaka_sortByPt2_h
0003 
0004 #include <algorithm>
0005 #include <array>
0006 #include <cmath>
0007 #include <cstdint>
0008 
0009 #include <alpaka/alpaka.hpp>
0010 
0011 #include "DataFormats/VertexSoA/interface/ZVertexSoA.h"
0012 #include "HeterogeneousCore/AlpakaInterface/interface/HistoContainer.h"
0013 #include "HeterogeneousCore/AlpakaInterface/interface/config.h"
0014 #include "HeterogeneousCore/AlpakaInterface/interface/radixSort.h"
0015 #include "HeterogeneousCore/AlpakaInterface/interface/workdivision.h"
0016 #include "RecoVertex/PixelVertexFinding/interface/PixelVertexWorkSpaceLayout.h"
0017 
0018 #include "vertexFinder.h"
0019 
0020 namespace ALPAKA_ACCELERATOR_NAMESPACE::vertexFinder {
0021 
0022   using VtxSoAView = ::reco::ZVertexSoAView;
0023   using TrkSoAView = ::reco::ZVertexTracksSoAView;
0024   using WsSoAView = ::vertexFinder::PixelVertexWorkSpaceSoAView;
0025 
0026   ALPAKA_FN_ACC ALPAKA_FN_INLINE void sortByPt2(Acc1D const& acc, VtxSoAView& data, TrkSoAView& trkdata, WsSoAView& ws) {
0027     auto nt = ws.ntrks();
0028     float const* __restrict__ ptt2 = ws.ptt2();
0029     uint32_t const& nvFinal = data.nvFinal();
0030 
0031     int32_t const* __restrict__ iv = ws.iv();
0032     float* __restrict__ ptv2 = data.ptv2();
0033     uint16_t* __restrict__ sortInd = data.sortInd();
0034 
0035     if (nvFinal < 1)
0036       return;
0037 
0038     // fill indexing
0039     for (auto i : cms::alpakatools::uniform_elements(acc, nt)) {
0040       trkdata.idv()[ws.itrk()[i]] = iv[i];
0041     };
0042 
0043     // can be done asynchronously at the end of previous event
0044     for (auto i : cms::alpakatools::uniform_elements(acc, nvFinal)) {
0045       ptv2[i] = 0;
0046     };
0047     alpaka::syncBlockThreads(acc);
0048 
0049     for (auto i : cms::alpakatools::uniform_elements(acc, nt)) {
0050       if (iv[i] <= 9990) {
0051         alpaka::atomicAdd(acc, &ptv2[iv[i]], ptt2[i], alpaka::hierarchy::Blocks{});
0052       }
0053     };
0054     alpaka::syncBlockThreads(acc);
0055 
0056     const uint32_t threadIdxLocal(alpaka::getIdx<alpaka::Block, alpaka::Threads>(acc)[0u]);
0057     if (1 == nvFinal) {
0058       if (threadIdxLocal == 0)
0059         sortInd[0] = 0;
0060       return;
0061     }
0062 
0063     if constexpr (not cms::alpakatools::requires_single_thread_per_block_v<Acc1D>) {
0064       auto& sws = alpaka::declareSharedVar<uint16_t[1024], __COUNTER__>(acc);
0065       // sort using only 16 bits
0066       cms::alpakatools::radixSort<Acc1D, float, 2>(acc, ptv2, sortInd, sws, nvFinal);
0067     } else {
0068       for (uint16_t i = 0; i < nvFinal; ++i)
0069         sortInd[i] = i;
0070       std::sort(sortInd, sortInd + nvFinal, [&](auto i, auto j) { return ptv2[i] < ptv2[j]; });
0071     }
0072   }
0073 
0074   class SortByPt2Kernel {
0075   public:
0076     ALPAKA_FN_ACC void operator()(Acc1D const& acc, VtxSoAView pdata, TrkSoAView ptrkdata, WsSoAView pws) const {
0077       sortByPt2(acc, pdata, ptrkdata, pws);
0078     }
0079   };
0080 
0081 }  // namespace ALPAKA_ACCELERATOR_NAMESPACE::vertexFinder
0082 
0083 #endif  // RecoVertex_PixelVertexFinding_plugins_alpaka_sortByPt2_h