Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:05:23

0001 #include <alpaka/alpaka.hpp>
0002 
0003 #include "DataFormats/VertexSoA/interface/ZVertexDevice.h"
0004 #include "DataFormats/VertexSoA/interface/ZVertexHost.h"
0005 #include "DataFormats/VertexSoA/interface/alpaka/ZVertexSoACollection.h"
0006 #include "HeterogeneousCore/AlpakaInterface/interface/config.h"
0007 #include "HeterogeneousCore/AlpakaInterface/interface/workdivision.h"
0008 
0009 namespace ALPAKA_ACCELERATOR_NAMESPACE::testZVertexSoAT {
0010 
0011   class TestFillKernel {
0012   public:
0013     template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc>>>
0014     ALPAKA_FN_ACC void operator()(TAcc const& acc, reco::ZVertexSoAView zvertex_view) const {
0015       if (cms::alpakatools::once_per_grid(acc)) {
0016         zvertex_view.nvFinal() = 420;
0017       }
0018 
0019       for (int32_t j : cms::alpakatools::uniform_elements(acc, zvertex_view.metadata().size())) {
0020         zvertex_view[j].idv() = (int16_t)j;
0021         zvertex_view[j].zv() = (float)j;
0022         zvertex_view[j].wv() = (float)j;
0023         zvertex_view[j].chi2() = (float)j;
0024         zvertex_view[j].ptv2() = (float)j;
0025         zvertex_view[j].ndof() = (int32_t)j;
0026         zvertex_view[j].sortInd() = (uint16_t)j;
0027       }
0028     }
0029   };
0030 
0031   class TestVerifyKernel {
0032   public:
0033     template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc>>>
0034     ALPAKA_FN_ACC void operator()(TAcc const& acc, reco::ZVertexSoAView zvertex_view) const {
0035       if (cms::alpakatools::once_per_grid(acc)) {
0036         ALPAKA_ASSERT_ACC(zvertex_view.nvFinal() == 420);
0037       }
0038 
0039       for (int32_t j : cms::alpakatools::uniform_elements(acc, zvertex_view.nvFinal())) {
0040         assert(zvertex_view[j].idv() == j);
0041         assert(zvertex_view[j].zv() - (float)j < 0.0001);
0042         assert(zvertex_view[j].wv() - (float)j < 0.0001);
0043         assert(zvertex_view[j].chi2() - (float)j < 0.0001);
0044         assert(zvertex_view[j].ptv2() - (float)j < 0.0001);
0045         assert(zvertex_view[j].ndof() == j);
0046         assert(zvertex_view[j].sortInd() == uint32_t(j));
0047       }
0048     }
0049   };
0050 
0051   void runKernels(reco::ZVertexSoAView zvertex_view, Queue& queue) {
0052     uint32_t items = 64;
0053     uint32_t groups = cms::alpakatools::divide_up_by(zvertex_view.metadata().size(), items);
0054     auto workDiv = cms::alpakatools::make_workdiv<Acc1D>(groups, items);
0055     alpaka::exec<Acc1D>(queue, workDiv, TestFillKernel{}, zvertex_view);
0056     alpaka::exec<Acc1D>(queue, workDiv, TestVerifyKernel{}, zvertex_view);
0057   }
0058 
0059 }  // namespace ALPAKA_ACCELERATOR_NAMESPACE::testZVertexSoAT