Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:28:09

#include "RecoTracker/FinalTrackSelectors/interface/TrackMVAClassifier.h"

#include <string>
#include <vector>

#include "FWCore/Framework/interface/EventSetup.h"
#include "FWCore/Framework/interface/global/EDProducer.h"
#include "DataFormats/TrackReco/interface/Track.h"
#include "DataFormats/VertexReco/interface/Vertex.h"
#include "FWCore/Framework/interface/ConsumesCollector.h"
#include "FWCore/Utilities/interface/Exception.h"
#include "RecoTracker/FinalTrackSelectors/interface/getBestVertex.h"

#include "PhysicsTools/TensorFlow/interface/TfGraphRecord.h"
#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
#include "PhysicsTools/TensorFlow/interface/TfGraphDefWrapper.h"
0013 
0014 namespace {
0015   class TfDnn {
0016   public:
0017     TfDnn(const edm::ParameterSet& cfg, edm::ConsumesCollector iC)
0018         : tfDnnLabel_(cfg.getParameter<std::string>("tfDnnLabel")),
0019           tfDnnToken_(iC.esConsumes(edm::ESInputTag("", tfDnnLabel_))),
0020           session_(nullptr),
0021           bsize_(cfg.getParameter<int>("batchSize"))
0022 
0023     {}
0024 
0025     static const char* name() { return "trackTfClassifierDefault"; }
0026 
0027     static void fillDescriptions(edm::ParameterSetDescription& desc) {
0028       desc.add<std::string>("tfDnnLabel", "trackSelectionTf");
0029       desc.add<int>("batchSize", 16);
0030     }
0031     void beginStream() {}
0032 
0033     void initEvent(const edm::EventSetup& es) {
0034       if (session_ == nullptr) {
0035         session_ = es.getData(tfDnnToken_).getSession();
0036       }
0037     }
0038 
0039     std::vector<float> operator()(reco::TrackCollection const& tracks,
0040                                   reco::BeamSpot const& beamSpot,
0041                                   reco::VertexCollection const& vertices) const {
0042       int size_in = (int)tracks.size();
0043       int nbatches = size_in / bsize_;
0044 
0045       std::vector<float> output;
0046       output.resize(size_in);
0047 
0048       tensorflow::Tensor input1(tensorflow::DT_FLOAT, {bsize_, 29});
0049       tensorflow::Tensor input2(tensorflow::DT_FLOAT, {bsize_, 1});
0050 
0051       for (auto nb = 0; nb < nbatches + 1; nb++) {
0052         for (auto nt = 0; nt < bsize_; nt++) {
0053           int itrack = nt + bsize_ * nb;
0054           if (itrack >= size_in)
0055             continue;
0056           const auto& trk = tracks[itrack];
0057 
0058           const auto& bestVertex = getBestVertex(trk, vertices);
0059 
0060           input1.matrix<float>()(nt, 0) = trk.pt();
0061           input1.matrix<float>()(nt, 1) = trk.innerMomentum().x();
0062           input1.matrix<float>()(nt, 2) = trk.innerMomentum().y();
0063           input1.matrix<float>()(nt, 3) = trk.innerMomentum().z();
0064           input1.matrix<float>()(nt, 4) = trk.innerMomentum().rho();
0065           input1.matrix<float>()(nt, 5) = trk.outerMomentum().x();
0066           input1.matrix<float>()(nt, 6) = trk.outerMomentum().y();
0067           input1.matrix<float>()(nt, 7) = trk.outerMomentum().z();
0068           input1.matrix<float>()(nt, 8) = trk.outerMomentum().rho();
0069           input1.matrix<float>()(nt, 9) = trk.ptError();
0070           input1.matrix<float>()(nt, 10) = trk.dxy(bestVertex);
0071           input1.matrix<float>()(nt, 11) = trk.dz(bestVertex);
0072           input1.matrix<float>()(nt, 12) = trk.dxy(beamSpot.position());
0073           input1.matrix<float>()(nt, 13) = trk.dz(beamSpot.position());
0074           input1.matrix<float>()(nt, 14) = trk.dxyError();
0075           input1.matrix<float>()(nt, 15) = trk.dzError();
0076           input1.matrix<float>()(nt, 16) = trk.normalizedChi2();
0077           input1.matrix<float>()(nt, 17) = trk.eta();
0078           input1.matrix<float>()(nt, 18) = trk.phi();
0079           input1.matrix<float>()(nt, 19) = trk.etaError();
0080           input1.matrix<float>()(nt, 20) = trk.phiError();
0081           input1.matrix<float>()(nt, 21) = trk.hitPattern().numberOfValidPixelHits();
0082           input1.matrix<float>()(nt, 22) = trk.hitPattern().numberOfValidStripHits();
0083           input1.matrix<float>()(nt, 23) = trk.ndof();
0084           input1.matrix<float>()(nt, 24) =
0085               trk.hitPattern().numberOfLostTrackerHits(reco::HitPattern::MISSING_INNER_HITS);
0086           input1.matrix<float>()(nt, 25) =
0087               trk.hitPattern().numberOfLostTrackerHits(reco::HitPattern::MISSING_OUTER_HITS);
0088           input1.matrix<float>()(nt, 26) =
0089               trk.hitPattern().trackerLayersTotallyOffOrBad(reco::HitPattern::MISSING_INNER_HITS);
0090           input1.matrix<float>()(nt, 27) =
0091               trk.hitPattern().trackerLayersTotallyOffOrBad(reco::HitPattern::MISSING_OUTER_HITS);
0092           input1.matrix<float>()(nt, 28) =
0093               trk.hitPattern().trackerLayersWithoutMeasurement(reco::HitPattern::TRACK_HITS);
0094 
0095           //Original algo as its own input, it will enter the graph so that it gets one-hot encoded, as is the preferred
0096           //format for categorical inputs, where the labels do not have any metric amongst them
0097           input2.matrix<float>()(nt, 0) = trk.originalAlgo();
0098         }
0099 
0100         //The names for the input tensors get locked when freezing the trained tensorflow model. The NamedTensors must
0101         //match those names
0102         tensorflow::NamedTensorList inputs;
0103         inputs.resize(2);
0104         inputs[0] = tensorflow::NamedTensor("x", input1);
0105         inputs[1] = tensorflow::NamedTensor("y", input2);
0106         std::vector<tensorflow::Tensor> outputs;
0107 
0108         //evaluate the input
0109         tensorflow::run(session_, inputs, {"Identity"}, &outputs);
0110 
0111         for (auto nt = 0; nt < bsize_; nt++) {
0112           int itrack = nt + bsize_ * nb;
0113           if (itrack >= size_in)
0114             continue;
0115           float out0 = 2.0 * outputs[0].matrix<float>()(nt, 0) - 1.0;
0116           output[itrack] = out0;
0117         }
0118       }
0119       return output;
0120     }
0121 
0122     const std::string tfDnnLabel_;
0123     const edm::ESGetToken<TfGraphDefWrapper, TfGraphRecord> tfDnnToken_;
0124     const tensorflow::Session* session_;
0125     const int bsize_;
0126   };
0127 }  // namespace
0128 
0129 template <>
0130 void trackMVAClassifierImpl::ComputeMVA<void>::operator()(::TfDnn const& mva,
0131                                                           reco::TrackCollection const& tracks,
0132                                                           reco::BeamSpot const& beamSpot,
0133                                                           reco::VertexCollection const& vertices,
0134                                                           TrackMVAClassifierBase::MVAPairCollection& mvas) {
0135   const auto& scores = mva(tracks, beamSpot, vertices);
0136   size_t current = 0;
0137 
0138   for (auto score : scores) {
0139     std::pair<float, bool> output(score, true);
0140     mvas[current++] = output;
0141   }
0142 }
0143 
0144 namespace {
0145   using TrackTfClassifier = TrackMVAClassifier<TfDnn>;
0146 }  // namespace
0147 
0148 #include "FWCore/PluginManager/interface/ModuleDef.h"
0149 #include "FWCore/Framework/interface/MakerMacros.h"
0150 
0151 DEFINE_FWK_MODULE(TrackTfClassifier);