File indexing completed on 2024-04-20 02:32:09
0001 #ifndef RecoParticleFlow_PFClusterProducer_plugins_alpaka_PFClusterECLCC_h
0002 #define RecoParticleFlow_PFClusterProducer_plugins_alpaka_PFClusterECLCC_h
0003
0004 #include "HeterogeneousCore/AlpakaInterface/interface/config.h"
0005 #include "HeterogeneousCore/AlpakaInterface/interface/workdivision.h"
0006 #include "RecoParticleFlow/PFClusterProducer/interface/alpaka/PFClusteringVarsDeviceCollection.h"
0007 #include "RecoParticleFlow/PFClusterProducer/interface/alpaka/PFClusteringEdgeVarsDeviceCollection.h"
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057
0058
0059
0060
0061 namespace ALPAKA_ACCELERATOR_NAMESPACE {
0062
0063
0064
0065 ALPAKA_FN_ACC inline int representative(const int idx,
0066 reco::PFClusteringVarsDeviceCollection::View pfClusteringVars) {
0067 int curr = pfClusteringVars[idx].pfrh_topoId();
0068 if (curr != idx) {
0069 int next, prev = idx;
0070 while (curr > (next = pfClusteringVars[curr].pfrh_topoId())) {
0071 pfClusteringVars[prev].pfrh_topoId() = next;
0072 prev = curr;
0073 curr = next;
0074 }
0075 }
0076 return curr;
0077 }
0078
0079
0080 class ECLCCInit {
0081 public:
0082 template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc>>>
0083 ALPAKA_FN_ACC void operator()(const TAcc& acc,
0084 reco::PFRecHitHostCollection::ConstView pfRecHits,
0085 reco::PFClusteringVarsDeviceCollection::View pfClusteringVars,
0086 reco::PFClusteringEdgeVarsDeviceCollection::View pfClusteringEdgeVars) const {
0087 const int nRH = pfRecHits.size();
0088 for (int v : cms::alpakatools::uniform_elements(acc, nRH)) {
0089 const int beg = pfClusteringEdgeVars[v].pfrh_edgeIdx();
0090 const int end = pfClusteringEdgeVars[v + 1].pfrh_edgeIdx();
0091 int m = v;
0092 int i = beg;
0093 while ((m == v) && (i < end)) {
0094 m = std::min(m, pfClusteringEdgeVars[i].pfrh_edgeList());
0095 i++;
0096 }
0097 pfClusteringVars[v].pfrh_topoId() = m;
0098 }
0099 }
0100 };
0101
0102
0103
0104 class ECLCCCompute1 {
0105 public:
0106 template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc>>>
0107 ALPAKA_FN_ACC void operator()(const TAcc& acc,
0108 reco::PFRecHitHostCollection::ConstView pfRecHits,
0109 reco::PFClusteringVarsDeviceCollection::View pfClusteringVars,
0110 reco::PFClusteringEdgeVarsDeviceCollection::View pfClusteringEdgeVars) const {
0111 const int nRH = pfRecHits.size();
0112
0113 for (int v : cms::alpakatools::uniform_elements(acc, nRH)) {
0114 const int vstat = pfClusteringVars[v].pfrh_topoId();
0115 if (v != vstat) {
0116 const int beg = pfClusteringEdgeVars[v].pfrh_edgeIdx();
0117 const int end = pfClusteringEdgeVars[v + 1].pfrh_edgeIdx();
0118 int vstat = representative(v, pfClusteringVars);
0119 for (int i = beg; i < end; i++) {
0120 const int nli = pfClusteringEdgeVars[i].pfrh_edgeList();
0121 if (v > nli) {
0122 int ostat = representative(nli, pfClusteringVars);
0123 bool repeat;
0124 do {
0125 repeat = false;
0126 if (vstat != ostat) {
0127 int ret;
0128 if (vstat < ostat) {
0129 if ((ret = alpaka::atomicCas(acc, &pfClusteringVars[ostat].pfrh_topoId(), ostat, vstat)) != ostat) {
0130 ostat = ret;
0131 repeat = true;
0132 }
0133 } else {
0134 if ((ret = alpaka::atomicCas(acc, &pfClusteringVars[vstat].pfrh_topoId(), vstat, ostat)) != vstat) {
0135 vstat = ret;
0136 repeat = true;
0137 }
0138 }
0139 }
0140 } while (repeat);
0141 }
0142 }
0143 }
0144 }
0145 }
0146 };
0147
0148
0149 class ECLCCFlatten {
0150 public:
0151 template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc>>>
0152 ALPAKA_FN_ACC void operator()(const TAcc& acc,
0153 reco::PFRecHitHostCollection::ConstView pfRecHits,
0154 reco::PFClusteringVarsDeviceCollection::View pfClusteringVars,
0155 reco::PFClusteringEdgeVarsDeviceCollection::View pfClusteringEdgeVars) const {
0156 const int nRH = pfRecHits.size();
0157
0158 for (int v : cms::alpakatools::uniform_elements(acc, nRH)) {
0159 int next, vstat = pfClusteringVars[v].pfrh_topoId();
0160 const int old = vstat;
0161 while (vstat > (next = pfClusteringVars[vstat].pfrh_topoId())) {
0162 vstat = next;
0163 }
0164 if (old != vstat)
0165 pfClusteringVars[v].pfrh_topoId() = vstat;
0166 }
0167 }
0168 };
0169
0170
0171
0172 }
0173
0174 #endif