Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-12-05 02:48:05

0001 #include "RecoTracker/LSTCore/interface/alpaka/LST.h"
0002 
0003 #include "LSTEvent.h"
0004 
0005 using namespace ALPAKA_ACCELERATOR_NAMESPACE::lst;
0006 
0007 #include "Math/Vector3D.h"
0008 #include "Math/VectorUtil.h"
0009 using XYZVector = ROOT::Math::XYZVector;
0010 
0011 namespace {
0012   XYZVector calculateR3FromPCA(const XYZVector& p3, float dxy, float dz) {
0013     const float pt = p3.rho();
0014     const float p = p3.r();
0015     const float vz = dz * pt * pt / p / p;
0016 
0017     const float vx = -dxy * p3.y() / pt - p3.x() / p * p3.z() / p * dz;
0018     const float vy = dxy * p3.x() / pt - p3.y() / p * p3.z() / p * dz;
0019     return {vx, vy, vz};
0020   }
0021 
0022   using namespace ALPAKA_ACCELERATOR_NAMESPACE::lst;
0023   std::vector<unsigned int> getHitIdxs(short trackCandidateType,
0024                                        Params_pT5::ArrayUxHits const& tcHitIndices,
0025                                        unsigned int const* hitIndices) {
0026     std::vector<unsigned int> hits;
0027 
0028     unsigned int maxNHits = 0;
0029     if (trackCandidateType == LSTObjType::pT5)
0030       maxNHits = Params_pT5::kHits;
0031     else if (trackCandidateType == LSTObjType::pT3)
0032       maxNHits = Params_pT3::kHits;
0033     else if (trackCandidateType == LSTObjType::T5)
0034       maxNHits = Params_T5::kHits;
0035     else if (trackCandidateType == LSTObjType::pLS)
0036       maxNHits = Params_pLS::kHits;
0037 
0038     for (unsigned int i = 0; i < maxNHits; i++) {
0039       unsigned int hitIdxDev = tcHitIndices[i];
0040       unsigned int hitIdx =
0041           (trackCandidateType == LSTObjType::pLS)
0042               ? hitIdxDev
0043               : hitIndices[hitIdxDev];  // Hit indices are stored differently in the standalone for pLS.
0044 
0045       // For p objects, the 3rd and 4th hit maybe the same,
0046       // due to the way pLS hits are stored in the standalone.
0047       // This is because pixel seeds can be either triplets or quadruplets.
0048       if (trackCandidateType != LSTObjType::T5 && hits.size() == 3 &&
0049           hits.back() == hitIdx)  // Remove duplicate 4th hits.
0050         continue;
0051 
0052       hits.push_back(hitIdx);
0053     }
0054 
0055     return hits;
0056   }
0057 
0058 }  // namespace
0059 
0060 void LST::prepareInput(std::vector<float> const& see_px,
0061                        std::vector<float> const& see_py,
0062                        std::vector<float> const& see_pz,
0063                        std::vector<float> const& see_dxy,
0064                        std::vector<float> const& see_dz,
0065                        std::vector<float> const& see_ptErr,
0066                        std::vector<float> const& see_etaErr,
0067                        std::vector<float> const& see_stateTrajGlbX,
0068                        std::vector<float> const& see_stateTrajGlbY,
0069                        std::vector<float> const& see_stateTrajGlbZ,
0070                        std::vector<float> const& see_stateTrajGlbPx,
0071                        std::vector<float> const& see_stateTrajGlbPy,
0072                        std::vector<float> const& see_stateTrajGlbPz,
0073                        std::vector<int> const& see_q,
0074                        std::vector<std::vector<int>> const& see_hitIdx,
0075                        std::vector<unsigned int> const& ph2_detId,
0076                        std::vector<float> const& ph2_x,
0077                        std::vector<float> const& ph2_y,
0078                        std::vector<float> const& ph2_z,
0079                        float const ptCut) {
0080   in_trkX_.clear();
0081   in_trkY_.clear();
0082   in_trkZ_.clear();
0083   in_hitId_.clear();
0084   in_hitIdxs_.clear();
0085   in_hitIndices_vec0_.clear();
0086   in_hitIndices_vec1_.clear();
0087   in_hitIndices_vec2_.clear();
0088   in_hitIndices_vec3_.clear();
0089   in_deltaPhi_vec_.clear();
0090   in_ptIn_vec_.clear();
0091   in_ptErr_vec_.clear();
0092   in_px_vec_.clear();
0093   in_py_vec_.clear();
0094   in_pz_vec_.clear();
0095   in_eta_vec_.clear();
0096   in_etaErr_vec_.clear();
0097   in_phi_vec_.clear();
0098   in_charge_vec_.clear();
0099   in_seedIdx_vec_.clear();
0100   in_superbin_vec_.clear();
0101   in_pixelType_vec_.clear();
0102   in_isQuad_vec_.clear();
0103 
0104   unsigned int count = 0;
0105   auto n_see = see_stateTrajGlbPx.size();
0106   in_px_vec_.reserve(n_see);
0107   in_py_vec_.reserve(n_see);
0108   in_pz_vec_.reserve(n_see);
0109   in_hitIndices_vec0_.reserve(n_see);
0110   in_hitIndices_vec1_.reserve(n_see);
0111   in_hitIndices_vec2_.reserve(n_see);
0112   in_hitIndices_vec3_.reserve(n_see);
0113   in_ptIn_vec_.reserve(n_see);
0114   in_ptErr_vec_.reserve(n_see);
0115   in_etaErr_vec_.reserve(n_see);
0116   in_eta_vec_.reserve(n_see);
0117   in_phi_vec_.reserve(n_see);
0118   in_charge_vec_.reserve(n_see);
0119   in_seedIdx_vec_.reserve(n_see);
0120   in_deltaPhi_vec_.reserve(n_see);
0121   in_trkX_ = ph2_x;
0122   in_trkY_ = ph2_y;
0123   in_trkZ_ = ph2_z;
0124   in_hitId_ = ph2_detId;
0125   in_hitIdxs_.resize(ph2_detId.size());
0126 
0127   std::iota(in_hitIdxs_.begin(), in_hitIdxs_.end(), 0);
0128   const int hit_size = in_trkX_.size();
0129 
0130   for (size_t iSeed = 0; iSeed < n_see; iSeed++) {
0131     XYZVector p3LH(see_stateTrajGlbPx[iSeed], see_stateTrajGlbPy[iSeed], see_stateTrajGlbPz[iSeed]);
0132     float ptIn = p3LH.rho();
0133     float eta = p3LH.eta();
0134     float ptErr = see_ptErr[iSeed];
0135 
0136     if ((ptIn > ptCut - 2 * ptErr)) {
0137       XYZVector r3LH(see_stateTrajGlbX[iSeed], see_stateTrajGlbY[iSeed], see_stateTrajGlbZ[iSeed]);
0138       XYZVector p3PCA(see_px[iSeed], see_py[iSeed], see_pz[iSeed]);
0139       XYZVector r3PCA(calculateR3FromPCA(p3PCA, see_dxy[iSeed], see_dz[iSeed]));
0140 
0141       // The charge could be used directly in the line below
0142       float pixelSegmentDeltaPhiChange = ROOT::Math::VectorUtil::DeltaPhi(p3LH, r3LH);
0143       float etaErr = see_etaErr[iSeed];
0144       float px = p3LH.x();
0145       float py = p3LH.y();
0146       float pz = p3LH.z();
0147 
0148       int charge = see_q[iSeed];
0149       PixelType pixtype = PixelType::kInvalid;
0150 
0151       if (ptIn >= 2.0)
0152         pixtype = PixelType::kHighPt;
0153       else if (ptIn >= (ptCut - 2 * ptErr) and ptIn < 2.0) {
0154         if (pixelSegmentDeltaPhiChange >= 0)
0155           pixtype = PixelType::kLowPtPosCurv;
0156         else
0157           pixtype = PixelType::kLowPtNegCurv;
0158       } else
0159         continue;
0160 
0161       unsigned int hitIdx0 = hit_size + count;
0162       count++;
0163       unsigned int hitIdx1 = hit_size + count;
0164       count++;
0165       unsigned int hitIdx2 = hit_size + count;
0166       count++;
0167       unsigned int hitIdx3;
0168       if (see_hitIdx[iSeed].size() <= 3)
0169         hitIdx3 = hitIdx2;
0170       else {
0171         hitIdx3 = hit_size + count;
0172         count++;
0173       }
0174 
0175       in_trkX_.push_back(r3PCA.x());
0176       in_trkY_.push_back(r3PCA.y());
0177       in_trkZ_.push_back(r3PCA.z());
0178       in_trkX_.push_back(p3PCA.rho());
0179       float p3PCA_Eta = p3PCA.eta();
0180       in_trkY_.push_back(p3PCA_Eta);
0181       float p3PCA_Phi = p3PCA.phi();
0182       in_trkZ_.push_back(p3PCA_Phi);
0183       in_trkX_.push_back(r3LH.x());
0184       in_trkY_.push_back(r3LH.y());
0185       in_trkZ_.push_back(r3LH.z());
0186       in_hitId_.push_back(1);
0187       in_hitId_.push_back(1);
0188       in_hitId_.push_back(1);
0189       if (see_hitIdx[iSeed].size() > 3) {
0190         in_trkX_.push_back(r3LH.x());
0191         in_trkY_.push_back(see_dxy[iSeed]);
0192         in_trkZ_.push_back(see_dz[iSeed]);
0193         in_hitId_.push_back(1);
0194       }
0195       in_px_vec_.push_back(px);
0196       in_py_vec_.push_back(py);
0197       in_pz_vec_.push_back(pz);
0198 
0199       in_hitIndices_vec0_.push_back(hitIdx0);
0200       in_hitIndices_vec1_.push_back(hitIdx1);
0201       in_hitIndices_vec2_.push_back(hitIdx2);
0202       in_hitIndices_vec3_.push_back(hitIdx3);
0203       in_ptIn_vec_.push_back(ptIn);
0204       in_ptErr_vec_.push_back(ptErr);
0205       in_etaErr_vec_.push_back(etaErr);
0206       in_eta_vec_.push_back(eta);
0207       float phi = p3LH.phi();
0208       in_phi_vec_.push_back(phi);
0209       in_charge_vec_.push_back(charge);
0210       in_seedIdx_vec_.push_back(iSeed);
0211       in_deltaPhi_vec_.push_back(pixelSegmentDeltaPhiChange);
0212 
0213       in_hitIdxs_.push_back(see_hitIdx[iSeed][0]);
0214       in_hitIdxs_.push_back(see_hitIdx[iSeed][1]);
0215       in_hitIdxs_.push_back(see_hitIdx[iSeed][2]);
0216       char isQuad = false;
0217       if (see_hitIdx[iSeed].size() > 3) {
0218         isQuad = true;
0219         in_hitIdxs_.push_back(see_hitIdx[iSeed][3]);
0220       }
0221       float neta = 25.;
0222       float nphi = 72.;
0223       float nz = 25.;
0224       int etabin = (p3PCA_Eta + 2.6) / ((2 * 2.6) / neta);
0225       int phibin = (p3PCA_Phi + kPi) / ((2. * kPi) / nphi);
0226       int dzbin = (see_dz[iSeed] + 30) / (2 * 30 / nz);
0227       int isuperbin = (nz * nphi) * etabin + (nz)*phibin + dzbin;
0228       in_superbin_vec_.push_back(isuperbin);
0229       in_pixelType_vec_.push_back(pixtype);
0230       in_isQuad_vec_.push_back(isQuad);
0231     }
0232   }
0233 }
0234 
0235 void LST::getOutput(LSTEvent& event) {
0236   out_tc_hitIdxs_.clear();
0237   out_tc_len_.clear();
0238   out_tc_seedIdx_.clear();
0239   out_tc_trackCandidateType_.clear();
0240 
0241   auto const hits = event.getHits<HitsSoA>(/*inCMSSW*/ true, /*sync*/ false);  // sync on next line
0242   auto const& trackCandidates = event.getTrackCandidates(/*inCMSSW*/ true, /*sync*/ true);
0243 
0244   unsigned int nTrackCandidates = trackCandidates.nTrackCandidates();
0245 
0246   for (unsigned int idx = 0; idx < nTrackCandidates; idx++) {
0247     short trackCandidateType = trackCandidates.trackCandidateType()[idx];
0248     std::vector<unsigned int> hit_idx = getHitIdxs(trackCandidateType, trackCandidates.hitIndices()[idx], hits.idxs());
0249 
0250     out_tc_hitIdxs_.push_back(hit_idx);
0251     out_tc_len_.push_back(hit_idx.size());
0252     out_tc_seedIdx_.push_back(trackCandidates.pixelSeedIndex()[idx]);
0253     out_tc_trackCandidateType_.push_back(trackCandidateType);
0254   }
0255 }
0256 
0257 void LST::run(Queue& queue,
0258               bool verbose,
0259               float const ptCut,
0260               LSTESData<Device> const* deviceESData,
0261               std::vector<float> const& see_px,
0262               std::vector<float> const& see_py,
0263               std::vector<float> const& see_pz,
0264               std::vector<float> const& see_dxy,
0265               std::vector<float> const& see_dz,
0266               std::vector<float> const& see_ptErr,
0267               std::vector<float> const& see_etaErr,
0268               std::vector<float> const& see_stateTrajGlbX,
0269               std::vector<float> const& see_stateTrajGlbY,
0270               std::vector<float> const& see_stateTrajGlbZ,
0271               std::vector<float> const& see_stateTrajGlbPx,
0272               std::vector<float> const& see_stateTrajGlbPy,
0273               std::vector<float> const& see_stateTrajGlbPz,
0274               std::vector<int> const& see_q,
0275               std::vector<std::vector<int>> const& see_hitIdx,
0276               std::vector<unsigned int> const& ph2_detId,
0277               std::vector<float> const& ph2_x,
0278               std::vector<float> const& ph2_y,
0279               std::vector<float> const& ph2_z,
0280               bool no_pls_dupclean,
0281               bool tc_pls_triplets) {
0282   auto event = LSTEvent(verbose, ptCut, queue, deviceESData);
0283   prepareInput(see_px,
0284                see_py,
0285                see_pz,
0286                see_dxy,
0287                see_dz,
0288                see_ptErr,
0289                see_etaErr,
0290                see_stateTrajGlbX,
0291                see_stateTrajGlbY,
0292                see_stateTrajGlbZ,
0293                see_stateTrajGlbPx,
0294                see_stateTrajGlbPy,
0295                see_stateTrajGlbPz,
0296                see_q,
0297                see_hitIdx,
0298                ph2_detId,
0299                ph2_x,
0300                ph2_y,
0301                ph2_z,
0302                ptCut);
0303 
0304   event.addHitToEvent(in_trkX_, in_trkY_, in_trkZ_, in_hitId_, in_hitIdxs_);
0305   event.addPixelSegmentToEvent(in_hitIndices_vec0_,
0306                                in_hitIndices_vec1_,
0307                                in_hitIndices_vec2_,
0308                                in_hitIndices_vec3_,
0309                                in_deltaPhi_vec_,
0310                                in_ptIn_vec_,
0311                                in_ptErr_vec_,
0312                                in_px_vec_,
0313                                in_py_vec_,
0314                                in_pz_vec_,
0315                                in_eta_vec_,
0316                                in_etaErr_vec_,
0317                                in_phi_vec_,
0318                                in_charge_vec_,
0319                                in_seedIdx_vec_,
0320                                in_superbin_vec_,
0321                                in_pixelType_vec_,
0322                                in_isQuad_vec_);
0323   event.createMiniDoublets();
0324   if (verbose) {
0325     alpaka::wait(queue);  // event calls are asynchronous: wait before printing
0326     printf("# of Mini-doublets produced: %d\n", event.getNumberOfMiniDoublets());
0327     printf("# of Mini-doublets produced barrel layer 1: %d\n", event.getNumberOfMiniDoubletsByLayerBarrel(0));
0328     printf("# of Mini-doublets produced barrel layer 2: %d\n", event.getNumberOfMiniDoubletsByLayerBarrel(1));
0329     printf("# of Mini-doublets produced barrel layer 3: %d\n", event.getNumberOfMiniDoubletsByLayerBarrel(2));
0330     printf("# of Mini-doublets produced barrel layer 4: %d\n", event.getNumberOfMiniDoubletsByLayerBarrel(3));
0331     printf("# of Mini-doublets produced barrel layer 5: %d\n", event.getNumberOfMiniDoubletsByLayerBarrel(4));
0332     printf("# of Mini-doublets produced barrel layer 6: %d\n", event.getNumberOfMiniDoubletsByLayerBarrel(5));
0333     printf("# of Mini-doublets produced endcap layer 1: %d\n", event.getNumberOfMiniDoubletsByLayerEndcap(0));
0334     printf("# of Mini-doublets produced endcap layer 2: %d\n", event.getNumberOfMiniDoubletsByLayerEndcap(1));
0335     printf("# of Mini-doublets produced endcap layer 3: %d\n", event.getNumberOfMiniDoubletsByLayerEndcap(2));
0336     printf("# of Mini-doublets produced endcap layer 4: %d\n", event.getNumberOfMiniDoubletsByLayerEndcap(3));
0337     printf("# of Mini-doublets produced endcap layer 5: %d\n", event.getNumberOfMiniDoubletsByLayerEndcap(4));
0338   }
0339 
0340   event.createSegmentsWithModuleMap();
0341   if (verbose) {
0342     alpaka::wait(queue);  // event calls are asynchronous: wait before printing
0343     printf("# of Segments produced: %d\n", event.getNumberOfSegments());
0344     printf("# of Segments produced layer 1-2:  %d\n", event.getNumberOfSegmentsByLayerBarrel(0));
0345     printf("# of Segments produced layer 2-3:  %d\n", event.getNumberOfSegmentsByLayerBarrel(1));
0346     printf("# of Segments produced layer 3-4:  %d\n", event.getNumberOfSegmentsByLayerBarrel(2));
0347     printf("# of Segments produced layer 4-5:  %d\n", event.getNumberOfSegmentsByLayerBarrel(3));
0348     printf("# of Segments produced layer 5-6:  %d\n", event.getNumberOfSegmentsByLayerBarrel(4));
0349     printf("# of Segments produced endcap layer 1:  %d\n", event.getNumberOfSegmentsByLayerEndcap(0));
0350     printf("# of Segments produced endcap layer 2:  %d\n", event.getNumberOfSegmentsByLayerEndcap(1));
0351     printf("# of Segments produced endcap layer 3:  %d\n", event.getNumberOfSegmentsByLayerEndcap(2));
0352     printf("# of Segments produced endcap layer 4:  %d\n", event.getNumberOfSegmentsByLayerEndcap(3));
0353     printf("# of Segments produced endcap layer 5:  %d\n", event.getNumberOfSegmentsByLayerEndcap(4));
0354   }
0355 
0356   event.createTriplets();
0357   if (verbose) {
0358     alpaka::wait(queue);  // event calls are asynchronous: wait before printing
0359     printf("# of T3s produced: %d\n", event.getNumberOfTriplets());
0360     printf("# of T3s produced layer 1-2-3: %d\n", event.getNumberOfTripletsByLayerBarrel(0));
0361     printf("# of T3s produced layer 2-3-4: %d\n", event.getNumberOfTripletsByLayerBarrel(1));
0362     printf("# of T3s produced layer 3-4-5: %d\n", event.getNumberOfTripletsByLayerBarrel(2));
0363     printf("# of T3s produced layer 4-5-6: %d\n", event.getNumberOfTripletsByLayerBarrel(3));
0364     printf("# of T3s produced endcap layer 1-2-3: %d\n", event.getNumberOfTripletsByLayerEndcap(0));
0365     printf("# of T3s produced endcap layer 2-3-4: %d\n", event.getNumberOfTripletsByLayerEndcap(1));
0366     printf("# of T3s produced endcap layer 3-4-5: %d\n", event.getNumberOfTripletsByLayerEndcap(2));
0367     printf("# of T3s produced endcap layer 1: %d\n", event.getNumberOfTripletsByLayerEndcap(0));
0368     printf("# of T3s produced endcap layer 2: %d\n", event.getNumberOfTripletsByLayerEndcap(1));
0369     printf("# of T3s produced endcap layer 3: %d\n", event.getNumberOfTripletsByLayerEndcap(2));
0370     printf("# of T3s produced endcap layer 4: %d\n", event.getNumberOfTripletsByLayerEndcap(3));
0371     printf("# of T3s produced endcap layer 5: %d\n", event.getNumberOfTripletsByLayerEndcap(4));
0372   }
0373 
0374   event.createQuintuplets();
0375   if (verbose) {
0376     alpaka::wait(queue);  // event calls are asynchronous: wait before printing
0377     printf("# of Quintuplets produced: %d\n", event.getNumberOfQuintuplets());
0378     printf("# of Quintuplets produced layer 1-2-3-4-5-6: %d\n", event.getNumberOfQuintupletsByLayerBarrel(0));
0379     printf("# of Quintuplets produced layer 2: %d\n", event.getNumberOfQuintupletsByLayerBarrel(1));
0380     printf("# of Quintuplets produced layer 3: %d\n", event.getNumberOfQuintupletsByLayerBarrel(2));
0381     printf("# of Quintuplets produced layer 4: %d\n", event.getNumberOfQuintupletsByLayerBarrel(3));
0382     printf("# of Quintuplets produced layer 5: %d\n", event.getNumberOfQuintupletsByLayerBarrel(4));
0383     printf("# of Quintuplets produced layer 6: %d\n", event.getNumberOfQuintupletsByLayerBarrel(5));
0384     printf("# of Quintuplets produced endcap layer 1: %d\n", event.getNumberOfQuintupletsByLayerEndcap(0));
0385     printf("# of Quintuplets produced endcap layer 2: %d\n", event.getNumberOfQuintupletsByLayerEndcap(1));
0386     printf("# of Quintuplets produced endcap layer 3: %d\n", event.getNumberOfQuintupletsByLayerEndcap(2));
0387     printf("# of Quintuplets produced endcap layer 4: %d\n", event.getNumberOfQuintupletsByLayerEndcap(3));
0388     printf("# of Quintuplets produced endcap layer 5: %d\n", event.getNumberOfQuintupletsByLayerEndcap(4));
0389   }
0390 
0391   event.pixelLineSegmentCleaning(no_pls_dupclean);
0392 
0393   event.createPixelQuintuplets();
0394   if (verbose) {
0395     alpaka::wait(queue);  // event calls are asynchronous: wait before printing
0396     printf("# of Pixel Quintuplets produced: %d\n", event.getNumberOfPixelQuintuplets());
0397   }
0398 
0399   event.createPixelTriplets();
0400   if (verbose) {
0401     alpaka::wait(queue);  // event calls are asynchronous: wait before printing
0402     printf("# of Pixel T3s produced: %d\n", event.getNumberOfPixelTriplets());
0403   }
0404 
0405   event.createTrackCandidates(no_pls_dupclean, tc_pls_triplets);
0406   if (verbose) {
0407     alpaka::wait(queue);  // event calls are asynchronous: wait before printing
0408     printf("# of TrackCandidates produced: %d\n", event.getNumberOfTrackCandidates());
0409     printf("        # of Pixel TrackCandidates produced: %d\n", event.getNumberOfPixelTrackCandidates());
0410     printf("        # of pT5 TrackCandidates produced: %d\n", event.getNumberOfPT5TrackCandidates());
0411     printf("        # of pT3 TrackCandidates produced: %d\n", event.getNumberOfPT3TrackCandidates());
0412     printf("        # of pLS TrackCandidates produced: %d\n", event.getNumberOfPLSTrackCandidates());
0413     printf("        # of T5 TrackCandidates produced: %d\n", event.getNumberOfT5TrackCandidates());
0414   }
0415 
0416   getOutput(event);
0417 }