File indexing completed on 2024-04-06 12:28:08
0001 #include "RecoTracker/FinalTrackSelectors/interface/TrackMVAClassifier.h"
0002
0003 #include "FWCore/Framework/interface/EventSetup.h"
0004 #include "FWCore/Framework/interface/ESHandle.h"
0005
0006 #include "DataFormats/TrackReco/interface/Track.h"
0007 #include "DataFormats/VertexReco/interface/Vertex.h"
0008
0009 #include "RecoTracker/FinalTrackSelectors/interface/getBestVertex.h"
0010
0011
0012 #include "TrackingTools/Records/interface/TrackingComponentsRecord.h"
0013 #include "lwtnn/LightweightNeuralNetwork.hh"
0014
0015 namespace {
0016 struct lwtnn {
0017 lwtnn(const edm::ParameterSet& cfg, edm::ConsumesCollector iC)
0018 : lwtnnLabel_(cfg.getParameter<std::string>("lwtnnLabel")),
0019 lwtnnToken_(iC.esConsumes(edm::ESInputTag("", lwtnnLabel_))) {}
0020
0021 static const char* name() { return "TrackLwtnnClassifier"; }
0022
0023 static void fillDescriptions(edm::ParameterSetDescription& desc) {
0024 desc.add<std::string>("lwtnnLabel", "trackSelectionLwtnn");
0025 }
0026
0027 void beginStream() {}
0028 void initEvent(const edm::EventSetup& es) { neuralNetwork_ = &es.getData(lwtnnToken_); }
0029
0030 std::pair<float, bool> operator()(reco::Track const& trk,
0031 reco::BeamSpot const& beamSpot,
0032 reco::VertexCollection const& vertices,
0033 lwt::ValueMap& inputs) const {
0034
0035
0036
0037
0038
0039 Point bestVertex = getBestVertex(trk, vertices);
0040
0041 inputs["trk_pt"] = trk.pt();
0042 inputs["trk_eta"] = trk.eta();
0043 inputs["trk_lambda"] = trk.lambda();
0044 inputs["trk_dxy"] = trk.dxy(beamSpot.position());
0045 inputs["trk_dz"] = trk.dz(beamSpot.position());
0046 inputs["trk_dxyClosestPV"] = trk.dxy(bestVertex);
0047
0048 inputs["trk_dzClosestPVNorm"] = std::max(-0.2, std::min(trk.dz(bestVertex), 0.2));
0049 inputs["trk_ptErr"] = trk.ptError();
0050 inputs["trk_etaErr"] = trk.etaError();
0051 inputs["trk_lambdaErr"] = trk.lambdaError();
0052 inputs["trk_dxyErr"] = trk.dxyError();
0053 inputs["trk_dzErr"] = trk.dzError();
0054 inputs["trk_nChi2"] = trk.normalizedChi2();
0055 inputs["trk_ndof"] = trk.ndof();
0056 inputs["trk_nInvalid"] = trk.hitPattern().numberOfLostHits(reco::HitPattern::TRACK_HITS);
0057 inputs["trk_nPixel"] = trk.hitPattern().numberOfValidPixelHits();
0058 inputs["trk_nStrip"] = trk.hitPattern().numberOfValidStripHits();
0059 inputs["trk_nPixelLay"] = trk.hitPattern().pixelLayersWithMeasurement();
0060 inputs["trk_nStripLay"] = trk.hitPattern().stripLayersWithMeasurement();
0061 inputs["trk_n3DLay"] = (trk.hitPattern().numberOfValidStripLayersWithMonoAndStereo() +
0062 trk.hitPattern().pixelLayersWithMeasurement());
0063 inputs["trk_nLostLay"] = trk.hitPattern().trackerLayersWithoutMeasurement(reco::HitPattern::TRACK_HITS);
0064 inputs["trk_algo"] = trk.algo();
0065
0066 auto out = neuralNetwork_->compute(inputs);
0067
0068 if (out.size() != 1)
0069 throw cms::Exception("LogicError") << "Expecting exactly one output from NN, got " << out.size();
0070
0071 float output = 2.0 * out.begin()->second - 1.0;
0072
0073
0074
0075
0076 bool isReliable = true;
0077
0078 if (std::abs(inputs["trk_dxy"]) >= 0.1 && inputs["trk_etaErr"] < 0.003 && inputs["trk_dxyErr"] < 0.03 &&
0079 inputs["trk_ndof"] > 3) {
0080 isReliable = false;
0081 }
0082
0083 if ((inputs["trk_pt"] > 100.0) && (inputs["trk_nChi2"] < 4.0) && (inputs["trk_etaErr"] < 0.001)) {
0084 isReliable = false;
0085 }
0086
0087 std::pair<float, bool> return_(output, isReliable);
0088 return return_;
0089 }
0090
0091 std::string lwtnnLabel_;
0092 edm::ESGetToken<lwt::LightweightNeuralNetwork, TrackingComponentsRecord> lwtnnToken_;
0093 const lwt::LightweightNeuralNetwork* neuralNetwork_;
0094 };
0095
0096 using TrackLwtnnClassifier = TrackMVAClassifier<lwtnn, lwt::ValueMap>;
0097 }
0098
0099 #include "FWCore/PluginManager/interface/ModuleDef.h"
0100 #include "FWCore/Framework/interface/MakerMacros.h"
0101
0102 DEFINE_FWK_MODULE(TrackLwtnnClassifier);