File indexing completed on 2024-04-06 12:28:06
#include "RecoTracker/FinalTrackSelectors/interface/TrackMVAClassifier.h"

#include "FWCore/Framework/interface/EventSetup.h"
#include "FWCore/Framework/interface/ESHandle.h"
#include "CondFormats/DataRecord/interface/GBRWrapperRcd.h"

#include "DataFormats/TrackReco/interface/Track.h"
#include "DataFormats/VertexReco/interface/Vertex.h"
#include "RecoTracker/FinalTrackSelectors/interface/getBestVertex.h"

#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>

#include "TFile.h"
0014
0015 namespace {
0016
0017 template <bool PROMPT>
0018 struct mva {
0019 mva(const edm::ParameterSet &cfg, edm::ConsumesCollector iC)
0020 : forestLabel_(cfg.getParameter<std::string>("GBRForestLabel")),
0021 dbFileName_(cfg.getParameter<std::string>("GBRForestFileName")),
0022 useForestFromDB_((!forestLabel_.empty()) && dbFileName_.empty()) {
0023 if (useForestFromDB_) {
0024 forestToken_ = iC.esConsumes(edm::ESInputTag("", forestLabel_));
0025 }
0026 }
0027
0028 void beginStream() {
0029 if (!dbFileName_.empty()) {
0030 TFile gbrfile(dbFileName_.c_str());
0031 forestFromFile_.reset((GBRForest *)gbrfile.Get(forestLabel_.c_str()));
0032 }
0033 }
0034
0035 void initEvent(const edm::EventSetup &es) {
0036 forest_ = forestFromFile_.get();
0037 if (useForestFromDB_) {
0038 forest_ = &es.getData(forestToken_);
0039 }
0040 }
0041
0042 float operator()(reco::Track const &trk,
0043 reco::BeamSpot const &beamSpot,
0044 reco::VertexCollection const &vertices) const {
0045 auto tmva_pt_ = trk.pt();
0046 auto tmva_ndof_ = trk.ndof();
0047 auto tmva_nlayers_ = trk.hitPattern().trackerLayersWithMeasurement();
0048 auto tmva_nlayers3D_ =
0049 trk.hitPattern().pixelLayersWithMeasurement() + trk.hitPattern().numberOfValidStripLayersWithMonoAndStereo();
0050 auto tmva_nlayerslost_ = trk.hitPattern().trackerLayersWithoutMeasurement(reco::HitPattern::TRACK_HITS);
0051 float chi2n = trk.normalizedChi2();
0052 float chi2n_no1Dmod = chi2n;
0053
0054 int count1dhits = 0;
0055 for (auto ith = trk.recHitsBegin(); ith != trk.recHitsEnd(); ++ith) {
0056 const auto &hit = *(*ith);
0057 if (hit.dimension() == 1)
0058 ++count1dhits;
0059 }
0060
0061 if (count1dhits > 0) {
0062 float chi2 = trk.chi2();
0063 float ndof = trk.ndof();
0064 chi2n = (chi2 + count1dhits) / float(ndof + count1dhits);
0065 }
0066 auto tmva_chi2n_ = chi2n;
0067 auto tmva_chi2n_no1dmod_ = chi2n_no1Dmod;
0068 auto tmva_eta_ = trk.eta();
0069 auto tmva_relpterr_ = float(trk.ptError()) / std::max(float(trk.pt()), 0.000001f);
0070 auto tmva_nhits_ = trk.numberOfValidHits();
0071 int lostIn = trk.hitPattern().numberOfLostHits(reco::HitPattern::MISSING_INNER_HITS);
0072 int lostOut = trk.hitPattern().numberOfLostHits(reco::HitPattern::MISSING_OUTER_HITS);
0073 auto tmva_minlost_ = std::min(lostIn, lostOut);
0074 auto tmva_lostmidfrac_ = static_cast<float>(trk.numberOfLostHits()) /
0075 static_cast<float>(trk.numberOfValidHits() + trk.numberOfLostHits());
0076
0077 float gbrVals_[PROMPT ? 16 : 12];
0078 gbrVals_[0] = tmva_pt_;
0079 gbrVals_[1] = tmva_lostmidfrac_;
0080 gbrVals_[2] = tmva_minlost_;
0081 gbrVals_[3] = tmva_nhits_;
0082 gbrVals_[4] = tmva_relpterr_;
0083 gbrVals_[5] = tmva_eta_;
0084 gbrVals_[6] = tmva_chi2n_no1dmod_;
0085 gbrVals_[7] = tmva_chi2n_;
0086 gbrVals_[8] = tmva_nlayerslost_;
0087 gbrVals_[9] = tmva_nlayers3D_;
0088 gbrVals_[10] = tmva_nlayers_;
0089 gbrVals_[11] = tmva_ndof_;
0090
0091 if (PROMPT) {
0092 auto tmva_absd0_ = std::abs(trk.dxy(beamSpot.position()));
0093 auto tmva_absdz_ = std::abs(trk.dz(beamSpot.position()));
0094 Point bestVertex = getBestVertex(trk, vertices);
0095 auto tmva_absd0PV_ = std::abs(trk.dxy(bestVertex));
0096 auto tmva_absdzPV_ = std::abs(trk.dz(bestVertex));
0097
0098 gbrVals_[12] = tmva_absd0PV_;
0099 gbrVals_[13] = tmva_absdzPV_;
0100 gbrVals_[14] = tmva_absdz_;
0101 gbrVals_[15] = tmva_absd0_;
0102 }
0103
0104 return forest_->GetClassifier(gbrVals_);
0105 }
0106
0107 static const char *name();
0108
0109 static void fillDescriptions(edm::ParameterSetDescription &desc) {
0110 desc.add<std::string>("GBRForestLabel", std::string());
0111 desc.add<std::string>("GBRForestFileName", std::string());
0112 }
0113
0114 std::unique_ptr<GBRForest> forestFromFile_;
0115 const GBRForest *forest_ = nullptr;
0116 const std::string forestLabel_;
0117 const std::string dbFileName_;
0118 const bool useForestFromDB_;
0119 edm::ESGetToken<GBRForest, GBRWrapperRcd> forestToken_;
0120 };
0121
0122 using TrackMVAClassifierDetached = TrackMVAClassifier<mva<false>>;
0123 using TrackMVAClassifierPrompt = TrackMVAClassifier<mva<true>>;
0124 template <>
0125 const char *mva<false>::name() {
0126 return "TrackMVAClassifierDetached";
0127 }
0128 template <>
0129 const char *mva<true>::name() {
0130 return "TrackMVAClassifierPrompt";
0131 }
0132
0133 }
0134
0135 #include "FWCore/PluginManager/interface/ModuleDef.h"
0136 #include "FWCore/Framework/interface/MakerMacros.h"
0137
// Expose both classifier flavours (aliases of TrackMVAClassifier<mva<false>>
// and TrackMVAClassifier<mva<true>>) as framework modules via the
// DEFINE_FWK_MODULE macro from MakerMacros.h.
DEFINE_FWK_MODULE(TrackMVAClassifierDetached);
DEFINE_FWK_MODULE(TrackMVAClassifierPrompt);