Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:28:06

#include "RecoTracker/FinalTrackSelectors/interface/TrackMVAClassifier.h"

#include "FWCore/Framework/interface/EventSetup.h"
#include "FWCore/Framework/interface/ESHandle.h"
#include "FWCore/Utilities/interface/Exception.h"
#include "CondFormats/DataRecord/interface/GBRWrapperRcd.h"

#include "DataFormats/TrackReco/interface/Track.h"
#include "DataFormats/VertexReco/interface/Vertex.h"
#include "RecoTracker/FinalTrackSelectors/interface/getBestVertex.h"

#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <string>

#include "TFile.h"
0014 
0015 namespace {
0016 
0017   template <bool PROMPT>
0018   struct mva {
0019     mva(const edm::ParameterSet &cfg, edm::ConsumesCollector iC)
0020         : forestLabel_(cfg.getParameter<std::string>("GBRForestLabel")),
0021           dbFileName_(cfg.getParameter<std::string>("GBRForestFileName")),
0022           useForestFromDB_((!forestLabel_.empty()) && dbFileName_.empty()) {
0023       if (useForestFromDB_) {
0024         forestToken_ = iC.esConsumes(edm::ESInputTag("", forestLabel_));
0025       }
0026     }
0027 
0028     void beginStream() {
0029       if (!dbFileName_.empty()) {
0030         TFile gbrfile(dbFileName_.c_str());
0031         forestFromFile_.reset((GBRForest *)gbrfile.Get(forestLabel_.c_str()));
0032       }
0033     }
0034 
0035     void initEvent(const edm::EventSetup &es) {
0036       forest_ = forestFromFile_.get();
0037       if (useForestFromDB_) {
0038         forest_ = &es.getData(forestToken_);
0039       }
0040     }
0041 
0042     float operator()(reco::Track const &trk,
0043                      reco::BeamSpot const &beamSpot,
0044                      reco::VertexCollection const &vertices) const {
0045       auto tmva_pt_ = trk.pt();
0046       auto tmva_ndof_ = trk.ndof();
0047       auto tmva_nlayers_ = trk.hitPattern().trackerLayersWithMeasurement();
0048       auto tmva_nlayers3D_ =
0049           trk.hitPattern().pixelLayersWithMeasurement() + trk.hitPattern().numberOfValidStripLayersWithMonoAndStereo();
0050       auto tmva_nlayerslost_ = trk.hitPattern().trackerLayersWithoutMeasurement(reco::HitPattern::TRACK_HITS);
0051       float chi2n = trk.normalizedChi2();
0052       float chi2n_no1Dmod = chi2n;
0053 
0054       int count1dhits = 0;
0055       for (auto ith = trk.recHitsBegin(); ith != trk.recHitsEnd(); ++ith) {
0056         const auto &hit = *(*ith);
0057         if (hit.dimension() == 1)
0058           ++count1dhits;
0059       }
0060 
0061       if (count1dhits > 0) {
0062         float chi2 = trk.chi2();
0063         float ndof = trk.ndof();
0064         chi2n = (chi2 + count1dhits) / float(ndof + count1dhits);
0065       }
0066       auto tmva_chi2n_ = chi2n;
0067       auto tmva_chi2n_no1dmod_ = chi2n_no1Dmod;
0068       auto tmva_eta_ = trk.eta();
0069       auto tmva_relpterr_ = float(trk.ptError()) / std::max(float(trk.pt()), 0.000001f);
0070       auto tmva_nhits_ = trk.numberOfValidHits();
0071       int lostIn = trk.hitPattern().numberOfLostHits(reco::HitPattern::MISSING_INNER_HITS);
0072       int lostOut = trk.hitPattern().numberOfLostHits(reco::HitPattern::MISSING_OUTER_HITS);
0073       auto tmva_minlost_ = std::min(lostIn, lostOut);
0074       auto tmva_lostmidfrac_ = static_cast<float>(trk.numberOfLostHits()) /
0075                                static_cast<float>(trk.numberOfValidHits() + trk.numberOfLostHits());
0076 
0077       float gbrVals_[PROMPT ? 16 : 12];
0078       gbrVals_[0] = tmva_pt_;
0079       gbrVals_[1] = tmva_lostmidfrac_;
0080       gbrVals_[2] = tmva_minlost_;
0081       gbrVals_[3] = tmva_nhits_;
0082       gbrVals_[4] = tmva_relpterr_;
0083       gbrVals_[5] = tmva_eta_;
0084       gbrVals_[6] = tmva_chi2n_no1dmod_;
0085       gbrVals_[7] = tmva_chi2n_;
0086       gbrVals_[8] = tmva_nlayerslost_;
0087       gbrVals_[9] = tmva_nlayers3D_;
0088       gbrVals_[10] = tmva_nlayers_;
0089       gbrVals_[11] = tmva_ndof_;
0090 
0091       if (PROMPT) {
0092         auto tmva_absd0_ = std::abs(trk.dxy(beamSpot.position()));
0093         auto tmva_absdz_ = std::abs(trk.dz(beamSpot.position()));
0094         Point bestVertex = getBestVertex(trk, vertices);
0095         auto tmva_absd0PV_ = std::abs(trk.dxy(bestVertex));
0096         auto tmva_absdzPV_ = std::abs(trk.dz(bestVertex));
0097 
0098         gbrVals_[12] = tmva_absd0PV_;
0099         gbrVals_[13] = tmva_absdzPV_;
0100         gbrVals_[14] = tmva_absdz_;
0101         gbrVals_[15] = tmva_absd0_;
0102       }
0103 
0104       return forest_->GetClassifier(gbrVals_);
0105     }
0106 
0107     static const char *name();
0108 
0109     static void fillDescriptions(edm::ParameterSetDescription &desc) {
0110       desc.add<std::string>("GBRForestLabel", std::string());
0111       desc.add<std::string>("GBRForestFileName", std::string());
0112     }
0113 
0114     std::unique_ptr<GBRForest> forestFromFile_;
0115     const GBRForest *forest_ = nullptr;  // owned by somebody else
0116     const std::string forestLabel_;
0117     const std::string dbFileName_;
0118     const bool useForestFromDB_;
0119     edm::ESGetToken<GBRForest, GBRWrapperRcd> forestToken_;
0120   };
0121 
0122   using TrackMVAClassifierDetached = TrackMVAClassifier<mva<false>>;
0123   using TrackMVAClassifierPrompt = TrackMVAClassifier<mva<true>>;
0124   template <>
0125   const char *mva<false>::name() {
0126     return "TrackMVAClassifierDetached";
0127   }
0128   template <>
0129   const char *mva<true>::name() {
0130     return "TrackMVAClassifierPrompt";
0131   }
0132 
0133 }  // namespace
0134 
0135 #include "FWCore/PluginManager/interface/ModuleDef.h"
0136 #include "FWCore/Framework/interface/MakerMacros.h"
0137 
0138 DEFINE_FWK_MODULE(TrackMVAClassifierDetached);
0139 DEFINE_FWK_MODULE(TrackMVAClassifierPrompt);