Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:28:25

0001 #include "RecoTracker/MkFitCore/standalone/TrackExtra.h"
0002 #include "RecoTracker/MkFitCore/standalone/ConfigStandalone.h"
0003 
0004 //#define DEBUG
0005 #include "RecoTracker/MkFitCore/src/Debug.h"
0006 
0007 namespace mkfit {
0008 
0009   //==============================================================================
0010   // TrackExtra
0011   //==============================================================================
0012 
0013   void TrackExtra::findMatchingSeedHits(const Track& reco_trk,
0014                                         const Track& seed_trk,
0015                                         const std::vector<HitVec>& layerHits) {
0016     // outer loop over reco hits
0017     for (int reco_ihit = 0; reco_ihit < reco_trk.nTotalHits(); ++reco_ihit) {
0018       const int reco_lyr = reco_trk.getHitLyr(reco_ihit);
0019       const int reco_idx = reco_trk.getHitIdx(reco_ihit);
0020 
0021       // ensure layer exists
0022       if (reco_lyr < 0)
0023         continue;
0024 
0025       // make sure it is a real hit
0026       if ((reco_idx < 0) || (static_cast<size_t>(reco_idx) >= layerHits[reco_lyr].size()))
0027         continue;
0028 
0029       // inner loop over seed hits
0030       for (int seed_ihit = 0; seed_ihit < seed_trk.nTotalHits(); ++seed_ihit) {
0031         const int seed_lyr = seed_trk.getHitLyr(seed_ihit);
0032         const int seed_idx = seed_trk.getHitIdx(seed_ihit);
0033 
0034         // ensure layer exists
0035         if (seed_lyr < 0)
0036           continue;
0037 
0038         // check that lyrs are the same
0039         if (reco_lyr != seed_lyr)
0040           continue;
0041 
0042         // make sure it is a real hit
0043         if ((seed_idx < 0) || (static_cast<size_t>(seed_idx) >= layerHits[seed_lyr].size()))
0044           continue;
0045 
0046         // finally, emplace if idx is the same
0047         if (reco_idx == seed_idx)
0048           matchedSeedHits_.emplace_back(seed_idx, seed_lyr);
0049       }
0050     }
0051   }
0052 
0053   bool TrackExtra::isSeedHit(const int lyr, const int idx) const {
0054     return (std::find_if(matchedSeedHits_.begin(), matchedSeedHits_.end(), [=](const auto& matchedSeedHit) {
0055               return ((matchedSeedHit.layer == lyr) && (matchedSeedHit.index == idx));
0056             }) != matchedSeedHits_.end());
0057   }
0058 
0059   int TrackExtra::modifyRefTrackID(const int foundHits,
0060                                    const int minHits,
0061                                    const TrackVec& reftracks,
0062                                    const int trueID,
0063                                    const int duplicate,
0064                                    int refTrackID) {
0065     // Modify refTrackID based on nMinHits and findability
0066     if (duplicate) {
0067       refTrackID = -10;
0068     } else {
0069       if (refTrackID >= 0) {
0070         if (reftracks[refTrackID].isFindable()) {
0071           if (foundHits < minHits)
0072             refTrackID = -2;
0073           //else                     refTrackID = refTrackID;
0074         } else  // ref track is not findable
0075         {
0076           if (foundHits < minHits)
0077             refTrackID = -3;
0078           else
0079             refTrackID = -4;
0080         }
0081       } else if (refTrackID == -1) {
0082         if (trueID >= 0) {
0083           if (reftracks[trueID].isFindable()) {
0084             if (foundHits < minHits)
0085               refTrackID = -5;
0086             //else                     refTrackID = refTrackID;
0087           } else  // sim track is not findable
0088           {
0089             if (foundHits < minHits)
0090               refTrackID = -6;
0091             else
0092               refTrackID = -7;
0093           }
0094         } else {
0095           if (foundHits < minHits)
0096             refTrackID = -8;
0097           else
0098             refTrackID = -9;
0099         }
0100       }
0101     }
0102     return refTrackID;
0103   }
0104 
0105   // Generic 50% reco to sim matching after seed
0106   void TrackExtra::setMCTrackIDInfo(const Track& trk,
0107                                     const std::vector<HitVec>& layerHits,
0108                                     const MCHitInfoVec& globalHitInfo,
0109                                     const TrackVec& simtracks,
0110                                     const bool isSeed,
0111                                     const bool isPure) {
0112     dprintf("TrackExtra::setMCTrackIDInfo for track with label %d, total hits %d, found hits %d\n",
0113             trk.label(),
0114             trk.nTotalHits(),
0115             trk.nFoundHits());
0116 
0117     std::vector<int> mcTrackIDs;         // vector of found mcTrackIDs on reco track
0118     int nSeedHits = nMatchedSeedHits();  // count seed hits
0119 
0120     // loop over all hits stored in reco track, storing valid mcTrackIDs
0121     for (int ihit = 0; ihit < trk.nTotalHits(); ++ihit) {
0122       const int lyr = trk.getHitLyr(ihit);
0123       const int idx = trk.getHitIdx(ihit);
0124 
0125       // ensure layer exists
0126       if (lyr < 0)
0127         continue;
0128 
0129       // skip seed layers (unless, of course, we are validating the seed tracks themselves)
0130       if (!Config::mtvLikeValidation && !isSeed && isSeedHit(lyr, idx))
0131         continue;
0132 
0133       // make sure it is a real hit
0134       if ((idx >= 0) && (static_cast<size_t>(idx) < layerHits[lyr].size())) {
0135         // get mchitid and then get mcTrackID
0136         const int mchitid = layerHits[lyr][idx].mcHitID();
0137         mcTrackIDs.push_back(globalHitInfo[mchitid].mcTrackID());
0138 
0139         dprintf("  ihit=%3d   trk.hit_idx=%4d  trk.hit_lyr=%2d   mchitid=%4d  mctrkid=%3d\n",
0140                 ihit,
0141                 idx,
0142                 lyr,
0143                 mchitid,
0144                 globalHitInfo[mchitid].mcTrackID());
0145       } else {
0146         dprintf("  ihit=%3d   trk.hit_idx=%4d  trk.hit_lyr=%2d\n", ihit, idx, lyr);
0147       }
0148     }
0149 
0150     int mccount = 0;          // count up the mcTrackID with the largest count
0151     int mcTrackID = -1;       // initialize mcTrackID
0152     if (!mcTrackIDs.empty())  // protection against tracks which do not make it past the seed
0153     {
0154       // sorted list ensures that mcTrackIDs are counted properly
0155       std::sort(mcTrackIDs.begin(), mcTrackIDs.end());
0156 
0157       // don't count bad mcTrackIDs (id < 0)
0158       mcTrackIDs.erase(std::remove_if(mcTrackIDs.begin(), mcTrackIDs.end(), [](const int id) { return id < 0; }),
0159                        mcTrackIDs.end());
0160 
0161       int n_ids = mcTrackIDs.size();
0162       int i = 0;
0163       while (i < n_ids) {
0164         int j = i + 1;
0165         while (j < n_ids && mcTrackIDs[j] == mcTrackIDs[i])
0166           ++j;
0167 
0168         int n = j - i;
0169         if (mcTrackIDs[i] >= 0 && n > mccount) {
0170           mcTrackID = mcTrackIDs[i];
0171           mccount = n;
0172         }
0173         i = j;
0174       }
0175 
0176       // total found hits in hit index array, excluding seed if necessary
0177       const int nCandHits = ((Config::mtvLikeValidation || isSeed) ? trk.nFoundHits() : trk.nFoundHits() - nSeedHits);
0178 
0179       // 75% or 50% matching criterion
0180       if ((Config::mtvLikeValidation ? (4 * mccount > 3 * nCandHits) : (2 * mccount >= nCandHits))) {
0181         // require that most matched is the mcTrackID!
0182         if (isPure) {
0183           if (mcTrackID == seedID_)
0184             mcTrackID_ = mcTrackID;
0185           else
0186             mcTrackID_ = -1;  // somehow, this guy followed another simtrack!
0187         } else {
0188           mcTrackID_ = mcTrackID;
0189         }
0190       } else  // failed 50% matching criteria
0191       {
0192         mcTrackID_ = -1;
0193       }
0194 
0195       // recount matched hits for pure sim tracks if reco track is unmatched
0196       if (isPure && mcTrackID == -1) {
0197         mccount = 0;
0198         for (auto id : mcTrackIDs) {
0199           if (id == seedID_)
0200             mccount++;
0201         }
0202       }
0203 
0204       // store matched hit info
0205       nHitsMatched_ = mccount;
0206       fracHitsMatched_ = float(nHitsMatched_) / float(nCandHits);
0207 
0208       // compute dPhi
0209       dPhi_ =
0210           (mcTrackID >= 0 ? squashPhiGeneral(simtracks[mcTrackID].swimPhiToR(trk.x(), trk.y()) - trk.momPhi()) : -99.f);
0211     } else {
0212       mcTrackID_ = mcTrackID;  // defaults from -1!
0213       nHitsMatched_ = -99;
0214       fracHitsMatched_ = -99.f;
0215       dPhi_ = -99.f;
0216     }
0217 
0218     // Modify mcTrackID based on length of track (excluding seed tracks, of course) and findability
0219     if (!isSeed) {
0220       mcTrackID_ = modifyRefTrackID(trk.nFoundHits() - nSeedHits,
0221                                     Config::nMinFoundHits - nSeedHits,
0222                                     simtracks,
0223                                     (isPure ? seedID_ : -1),
0224                                     trk.getDuplicateValue(),
0225                                     mcTrackID_);
0226     }
0227 
0228     dprint("Track " << trk.label() << " best mc track " << mcTrackID_ << " count " << mccount << "/"
0229                     << trk.nFoundHits());
0230   }
0231 
0232   typedef std::pair<int, float> idchi2Pair;
0233   typedef std::vector<idchi2Pair> idchi2PairVec;
0234 
0235   inline bool sortIDsByChi2(const idchi2Pair& cand1, const idchi2Pair& cand2) { return cand1.second < cand2.second; }
0236 
0237   inline int getMatchBin(const float pt) {
0238     if (pt < 0.75f)
0239       return 0;
0240     else if (pt < 1.f)
0241       return 1;
0242     else if (pt < 2.f)
0243       return 2;
0244     else if (pt < 5.f)
0245       return 3;
0246     else if (pt < 10.f)
0247       return 4;
0248     else
0249       return 5;
0250   }
0251 
0252   void TrackExtra::setCMSSWTrackIDInfoByTrkParams(const Track& trk,
0253                                                   const std::vector<HitVec>& layerHits,
0254                                                   const TrackVec& cmsswtracks,
0255                                                   const RedTrackVec& redcmsswtracks,
0256                                                   const bool isBkFit) {
0257     // get temporary reco track params
0258     const SVector6& trkParams = trk.parameters();
0259     const SMatrixSym66& trkErrs = trk.errors();
0260 
0261     // get bin used for cuts in dphi, chi2 based on pt
0262     const int bin = getMatchBin(trk.pT());
0263 
0264     // temps needed for chi2
0265     SVector2 trkParamsR;
0266     trkParamsR[0] = trkParams[3];
0267     trkParamsR[1] = trkParams[5];
0268 
0269     SMatrixSym22 trkErrsR;
0270     trkErrsR[0][0] = trkErrs[3][3];
0271     trkErrsR[1][1] = trkErrs[5][5];
0272     trkErrsR[0][1] = trkErrs[3][5];
0273     trkErrsR[1][0] = trkErrs[5][3];
0274 
0275     // cands is vector of possible cmssw tracks we could match
0276     idchi2PairVec cands;
0277 
0278     // first check for cmmsw tracks we match by chi2
0279     for (const auto& redcmsswtrack : redcmsswtracks) {
0280       const float chi2 = std::abs(computeHelixChi2(redcmsswtrack.parameters(), trkParamsR, trkErrsR, false));
0281       if (chi2 < Config::minCMSSWMatchChi2[bin])
0282         cands.push_back(std::make_pair(redcmsswtrack.label(), chi2));
0283     }
0284 
0285     // get min chi2
0286     float minchi2 = -1e6;
0287     if (!cands.empty()) {
0288       std::sort(cands.begin(), cands.end(), sortIDsByChi2);  // in case we just want to stop at the first dPhi match
0289       minchi2 = cands.front().second;
0290     }
0291 
0292     // set up defaults
0293     int cmsswTrackID = -1;
0294     int nHitsMatched = 0;
0295     float bestdPhi = Config::minCMSSWMatchdPhi[bin];
0296     float bestchi2 = minchi2;
0297 
0298     // loop over possible cmssw tracks
0299     for (auto&& cand : cands) {
0300       // get cmssw track
0301       const auto label = cand.first;
0302       const auto& cmsswtrack = cmsswtracks[label];
0303 
0304       // get diff in track mom. phi: swim phi of cmssw track to reco track R if forward built tracks
0305       const float diffPhi =
0306           squashPhiGeneral((isBkFit ? cmsswtrack.momPhi() : cmsswtrack.swimPhiToR(trk.x(), trk.y())) - trk.momPhi());
0307 
0308       // check for best matched track by phi
0309       if (std::abs(diffPhi) < std::abs(bestdPhi)) {
0310         const HitLayerMap& hitLayerMap = redcmsswtracks[label].hitLayerMap();
0311         int matched = 0;
0312 
0313         // loop over hits on reco track
0314         for (int ihit = 0; ihit < trk.nTotalHits(); ihit++) {
0315           const int lyr = trk.getHitLyr(ihit);
0316           const int idx = trk.getHitIdx(ihit);
0317 
0318           // skip seed layers
0319           if (isSeedHit(lyr, idx))
0320             continue;
0321 
0322           // skip if bad index or cmssw track does not have that layer
0323           if (idx < 0 || !hitLayerMap.count(lyr))
0324             continue;
0325 
0326           // loop over hits in layer for the cmssw track
0327           for (auto cidx : hitLayerMap.at(lyr)) {
0328             // since we can only pick up on hit on a layer, break loop after finding hit
0329             if (cidx == idx) {
0330               matched++;
0331               break;
0332             }
0333           }
0334         }  // end loop over hits on reco track
0335 
0336         // now save the matched info
0337         bestdPhi = diffPhi;
0338         nHitsMatched = matched;
0339         cmsswTrackID = label;
0340         bestchi2 = cand.second;
0341       }  // end check over dPhi
0342     }    // end loop over cands
0343 
0344     // set cmsswTrackID
0345     cmsswTrackID_ = cmsswTrackID;  // defaults to -1!
0346     helixChi2_ = bestchi2;
0347     dPhi_ = bestdPhi;
0348 
0349     // get seed hits
0350     const int nSeedHits = nMatchedSeedHits();
0351 
0352     // Modify cmsswTrackID based on length and findability
0353     cmsswTrackID_ = modifyRefTrackID(trk.nFoundHits() - nSeedHits,
0354                                      Config::nMinFoundHits - nSeedHits,
0355                                      cmsswtracks,
0356                                      -1,
0357                                      trk.getDuplicateValue(),
0358                                      cmsswTrackID_);
0359 
0360     // other important info
0361     nHitsMatched_ = nHitsMatched;
0362     fracHitsMatched_ =
0363         float(nHitsMatched_) / float(trk.nFoundHits() - nSeedHits);  // seed hits may already be included!
0364   }
0365 
0366   void TrackExtra::setCMSSWTrackIDInfoByHits(const Track& trk,
0367                                              const LayIdxIDVecMapMap& cmsswHitIDMap,
0368                                              const TrackVec& cmsswtracks,
0369                                              const TrackExtraVec& cmsswextras,
0370                                              const RedTrackVec& redcmsswtracks,
0371                                              const int cmsswlabel) {
0372     // reminder: cmsswlabel >= 0 indicates we are using pure seeds and matching by cmsswlabel
0373 
0374     // map of cmssw labels, and hits matched to that label
0375     std::unordered_map<int, int> labelMatchMap;
0376 
0377     // loop over mkfit track hits
0378     for (int ihit = 0; ihit < trk.nTotalHits(); ihit++) {
0379       const int lyr = trk.getHitLyr(ihit);
0380       const int idx = trk.getHitIdx(ihit);
0381 
0382       if (lyr < 0 || idx < 0)
0383         continue;  // standard check
0384       if (isSeedHit(lyr, idx))
0385         continue;  // skip seed layers
0386       if (!cmsswHitIDMap.count(lyr))
0387         continue;  // make sure at least one cmssw track has this hit lyr!
0388       if (!cmsswHitIDMap.at(lyr).count(idx))
0389         continue;  // make sure at least one cmssw track has this hit id!
0390       {
0391         for (const auto label : cmsswHitIDMap.at(lyr).at(idx)) {
0392           labelMatchMap[label]++;
0393         }
0394       }
0395     }
0396 
0397     // make list of cmssw tracks that pass criteria --> could have multiple overlapping tracks!
0398     std::vector<int> labelMatchVec;
0399     for (const auto labelMatchPair : labelMatchMap) {
0400       const auto cmsswlabel = labelMatchPair.first;
0401       const auto nMatchedHits = labelMatchPair.second;
0402 
0403       // 50% matching criterion
0404       if ((2 * nMatchedHits) >= (cmsswtracks[cmsswlabel].nUniqueLayers() - cmsswextras[cmsswlabel].nMatchedSeedHits()))
0405         labelMatchVec.push_back(cmsswlabel);
0406     }
0407 
0408     // initialize tmpID for later use
0409     int cmsswTrackID = -1;
0410 
0411     // protect against no matches!
0412     if (!labelMatchVec.empty()) {
0413       // sort by best matched: most hits matched , then ratio of matches (i.e. which cmssw track is shorter)
0414       std::sort(labelMatchVec.begin(), labelMatchVec.end(), [&](const int label1, const int label2) {
0415         if (labelMatchMap[label1] == labelMatchMap[label2]) {
0416           const auto& track1 = cmsswtracks[label1];
0417           const auto& track2 = cmsswtracks[label2];
0418 
0419           const auto& extra1 = cmsswextras[label1];
0420           const auto& extra2 = cmsswextras[label2];
0421 
0422           return ((track1.nUniqueLayers() - extra1.nMatchedSeedHits()) <
0423                   (track2.nUniqueLayers() - extra2.nMatchedSeedHits()));
0424         }
0425         return labelMatchMap[label1] > labelMatchMap[label2];
0426       });
0427 
0428       // pick the longest track!
0429       cmsswTrackID = labelMatchVec.front();
0430 
0431       // set cmsswTrackID_ (if cmsswlabel >= 0, we are matching by label and label exists!)
0432       if (cmsswlabel >= 0) {
0433         if (cmsswTrackID == cmsswlabel) {
0434           cmsswTrackID_ = cmsswTrackID;
0435         } else {
0436           cmsswTrackID = cmsswlabel;  // use this for later
0437           cmsswTrackID_ = -1;
0438         }
0439       } else  // not matching by pure id
0440       {
0441         cmsswTrackID_ = cmsswTrackID;  // the longest track is matched
0442       }
0443 
0444       // set nHits matched to cmssw track
0445       nHitsMatched_ = labelMatchMap[cmsswTrackID];
0446     } else  // did not match a single cmssw track with 50% hits shared
0447     {
0448       // by default sets to -1
0449       cmsswTrackID_ = cmsswTrackID;
0450 
0451       // tmp variable
0452       int nHitsMatched = 0;
0453 
0454       // use truth info
0455       if (cmsswlabel >= 0) {
0456         cmsswTrackID = cmsswlabel;
0457         nHitsMatched = labelMatchMap[cmsswTrackID];
0458       } else {
0459         // just get the cmssw track with the most matches!
0460         for (const auto labelMatchPair : labelMatchMap) {
0461           if (labelMatchPair.second > nHitsMatched) {
0462             cmsswTrackID = labelMatchPair.first;
0463             nHitsMatched = labelMatchPair.second;
0464           }
0465         }
0466       }
0467 
0468       nHitsMatched_ = nHitsMatched;
0469     }
0470 
0471     // set chi2, dphi based on tmp cmsswTrackID
0472     if (cmsswTrackID >= 0) {
0473       // get tmps for chi2, dphi
0474       const SVector6& trkParams = trk.parameters();
0475       const SMatrixSym66& trkErrs = trk.errors();
0476 
0477       // temps needed for chi2
0478       SVector2 trkParamsR;
0479       trkParamsR[0] = trkParams[3];
0480       trkParamsR[1] = trkParams[5];
0481 
0482       SMatrixSym22 trkErrsR;
0483       trkErrsR[0][0] = trkErrs[3][3];
0484       trkErrsR[1][1] = trkErrs[5][5];
0485       trkErrsR[0][1] = trkErrs[3][5];
0486       trkErrsR[1][0] = trkErrs[5][3];
0487 
0488       // set chi2 and dphi
0489       helixChi2_ = std::abs(computeHelixChi2(redcmsswtracks[cmsswTrackID].parameters(), trkParamsR, trkErrsR, false));
0490       dPhi_ = squashPhiGeneral(cmsswtracks[cmsswTrackID].swimPhiToR(trk.x(), trk.y()) - trk.momPhi());
0491     } else {
0492       helixChi2_ = -99.f;
0493       dPhi_ = -99.f;
0494     }
0495 
0496     // get nSeedHits
0497     const int nSeedHits = nMatchedSeedHits();
0498 
0499     // Modify cmsswTrackID based on length and findability
0500     cmsswTrackID_ = modifyRefTrackID(trk.nFoundHits() - nSeedHits,
0501                                      Config::nMinFoundHits - nSeedHits,
0502                                      cmsswtracks,
0503                                      cmsswlabel,
0504                                      trk.getDuplicateValue(),
0505                                      cmsswTrackID_);
0506 
0507     // other important info
0508     fracHitsMatched_ = (cmsswTrackID >= 0 ? (float(nHitsMatched_) / float(cmsswtracks[cmsswTrackID].nUniqueLayers() -
0509                                                                           cmsswextras[cmsswTrackID].nMatchedSeedHits()))
0510                                           : 0.f);
0511   }
0512 
0513 }  // namespace mkfit