Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:28:08

0001 #include "RecoTracker/FinalTrackSelectors/interface/TrackMVAClassifier.h"
0002 
0003 #include "FWCore/Framework/interface/EventSetup.h"
0004 #include "FWCore/Framework/interface/ESHandle.h"
0005 
0006 #include "DataFormats/TrackReco/interface/Track.h"
0007 #include "DataFormats/VertexReco/interface/Vertex.h"
0008 
0009 #include "RecoTracker/FinalTrackSelectors/interface/getBestVertex.h"
0010 
0011 //from lwtnn
0012 #include "TrackingTools/Records/interface/TrackingComponentsRecord.h"
0013 #include "lwtnn/LightweightNeuralNetwork.hh"
0014 
0015 namespace {
0016   struct lwtnn {
0017     lwtnn(const edm::ParameterSet& cfg, edm::ConsumesCollector iC)
0018         : lwtnnLabel_(cfg.getParameter<std::string>("lwtnnLabel")),
0019           lwtnnToken_(iC.esConsumes(edm::ESInputTag("", lwtnnLabel_))) {}
0020 
0021     static const char* name() { return "TrackLwtnnClassifier"; }
0022 
0023     static void fillDescriptions(edm::ParameterSetDescription& desc) {
0024       desc.add<std::string>("lwtnnLabel", "trackSelectionLwtnn");
0025     }
0026 
0027     void beginStream() {}
0028     void initEvent(const edm::EventSetup& es) { neuralNetwork_ = &es.getData(lwtnnToken_); }
0029 
0030     std::pair<float, bool> operator()(reco::Track const& trk,
0031                                       reco::BeamSpot const& beamSpot,
0032                                       reco::VertexCollection const& vertices,
0033                                       lwt::ValueMap& inputs) const {
0034       // lwt::ValueMap is typedef for std::map<std::string, double>
0035       //
0036       // It is cached per event to avoid constructing the map for each
0037       // track while keeping the operator() interface const.
0038 
0039       Point bestVertex = getBestVertex(trk, vertices);
0040 
0041       inputs["trk_pt"] = trk.pt();
0042       inputs["trk_eta"] = trk.eta();
0043       inputs["trk_lambda"] = trk.lambda();
0044       inputs["trk_dxy"] = trk.dxy(beamSpot.position());  // Training done without taking absolute value
0045       inputs["trk_dz"] = trk.dz(beamSpot.position());    // Training done without taking absolute value
0046       inputs["trk_dxyClosestPV"] = trk.dxy(bestVertex);  // Training done without taking absolute value
0047       // Training done without taking absolute value
0048       inputs["trk_dzClosestPVNorm"] = std::max(-0.2, std::min(trk.dz(bestVertex), 0.2));
0049       inputs["trk_ptErr"] = trk.ptError();
0050       inputs["trk_etaErr"] = trk.etaError();
0051       inputs["trk_lambdaErr"] = trk.lambdaError();
0052       inputs["trk_dxyErr"] = trk.dxyError();
0053       inputs["trk_dzErr"] = trk.dzError();
0054       inputs["trk_nChi2"] = trk.normalizedChi2();
0055       inputs["trk_ndof"] = trk.ndof();
0056       inputs["trk_nInvalid"] = trk.hitPattern().numberOfLostHits(reco::HitPattern::TRACK_HITS);
0057       inputs["trk_nPixel"] = trk.hitPattern().numberOfValidPixelHits();
0058       inputs["trk_nStrip"] = trk.hitPattern().numberOfValidStripHits();
0059       inputs["trk_nPixelLay"] = trk.hitPattern().pixelLayersWithMeasurement();
0060       inputs["trk_nStripLay"] = trk.hitPattern().stripLayersWithMeasurement();
0061       inputs["trk_n3DLay"] = (trk.hitPattern().numberOfValidStripLayersWithMonoAndStereo() +
0062                               trk.hitPattern().pixelLayersWithMeasurement());
0063       inputs["trk_nLostLay"] = trk.hitPattern().trackerLayersWithoutMeasurement(reco::HitPattern::TRACK_HITS);
0064       inputs["trk_algo"] = trk.algo();
0065 
0066       auto out = neuralNetwork_->compute(inputs);
0067       // there should only one output
0068       if (out.size() != 1)
0069         throw cms::Exception("LogicError") << "Expecting exactly one output from NN, got " << out.size();
0070 
0071       float output = 2.0 * out.begin()->second - 1.0;
0072 
0073       //Check if the network is known to be unreliable in that part of phase space. Hard cut values
0074       //correspond to rare tracks known to be difficult for the Deep Neural Network classifier
0075 
0076       bool isReliable = true;
0077       //T1qqqq
0078       if (std::abs(inputs["trk_dxy"]) >= 0.1 && inputs["trk_etaErr"] < 0.003 && inputs["trk_dxyErr"] < 0.03 &&
0079           inputs["trk_ndof"] > 3) {
0080         isReliable = false;
0081       }
0082       //T5qqqqLL
0083       if ((inputs["trk_pt"] > 100.0) && (inputs["trk_nChi2"] < 4.0) && (inputs["trk_etaErr"] < 0.001)) {
0084         isReliable = false;
0085       }
0086 
0087       std::pair<float, bool> return_(output, isReliable);
0088       return return_;
0089     }
0090 
0091     std::string lwtnnLabel_;
0092     edm::ESGetToken<lwt::LightweightNeuralNetwork, TrackingComponentsRecord> lwtnnToken_;
0093     const lwt::LightweightNeuralNetwork* neuralNetwork_;
0094   };
0095 
0096   using TrackLwtnnClassifier = TrackMVAClassifier<lwtnn, lwt::ValueMap>;
0097 }  // namespace
0098 
0099 #include "FWCore/PluginManager/interface/ModuleDef.h"
0100 #include "FWCore/Framework/interface/MakerMacros.h"
0101 
// Makes TrackLwtnnClassifier available as a framework module plugin
// (macro provided by FWCore/Framework/interface/MakerMacros.h).
0102 DEFINE_FWK_MODULE(TrackLwtnnClassifier);