File indexing completed on 2023-03-17 11:22:19
0001 #include "RecoTracker/FinalTrackSelectors/interface/TrackMVAClassifier.h"
0002
0003 #include "FWCore/Framework/interface/EventSetup.h"
0004 #include "FWCore/Framework/interface/global/EDProducer.h"
0005 #include "DataFormats/TrackReco/interface/Track.h"
0006 #include "DataFormats/VertexReco/interface/Vertex.h"
0007 #include "FWCore/Framework/interface/ConsumesCollector.h"
0008 #include "RecoTracker/FinalTrackSelectors/interface/getBestVertex.h"
0009
0010 #include "PhysicsTools/TensorFlow/interface/TfGraphRecord.h"
0011 #include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
0012 #include "PhysicsTools/TensorFlow/interface/TfGraphDefWrapper.h"
0013
0014 namespace {
0015 class TfDnn {
0016 public:
0017 TfDnn(const edm::ParameterSet& cfg, edm::ConsumesCollector iC)
0018 : tfDnnLabel_(cfg.getParameter<std::string>("tfDnnLabel")),
0019 tfDnnToken_(iC.esConsumes(edm::ESInputTag("", tfDnnLabel_))),
0020 session_(nullptr),
0021 bsize_(cfg.getParameter<int>("batchSize"))
0022
0023 {}
0024
0025 static const char* name() { return "trackTfClassifierDefault"; }
0026
0027 static void fillDescriptions(edm::ParameterSetDescription& desc) {
0028 desc.add<std::string>("tfDnnLabel", "trackSelectionTf");
0029 desc.add<int>("batchSize", 16);
0030 }
0031 void beginStream() {}
0032
0033 void initEvent(const edm::EventSetup& es) {
0034 if (session_ == nullptr) {
0035 session_ = es.getData(tfDnnToken_).getSession();
0036 }
0037 }
0038
0039 std::vector<float> operator()(reco::TrackCollection const& tracks,
0040 reco::BeamSpot const& beamSpot,
0041 reco::VertexCollection const& vertices) const {
0042 int size_in = (int)tracks.size();
0043 int nbatches = size_in / bsize_;
0044
0045 std::vector<float> output;
0046 output.resize(size_in);
0047
0048 tensorflow::Tensor input1(tensorflow::DT_FLOAT, {bsize_, 29});
0049 tensorflow::Tensor input2(tensorflow::DT_FLOAT, {bsize_, 1});
0050
0051 for (auto nb = 0; nb < nbatches + 1; nb++) {
0052 for (auto nt = 0; nt < bsize_; nt++) {
0053 int itrack = nt + bsize_ * nb;
0054 if (itrack >= size_in)
0055 continue;
0056 const auto& trk = tracks[itrack];
0057
0058 const auto& bestVertex = getBestVertex(trk, vertices);
0059
0060 input1.matrix<float>()(nt, 0) = trk.pt();
0061 input1.matrix<float>()(nt, 1) = trk.innerMomentum().x();
0062 input1.matrix<float>()(nt, 2) = trk.innerMomentum().y();
0063 input1.matrix<float>()(nt, 3) = trk.innerMomentum().z();
0064 input1.matrix<float>()(nt, 4) = trk.innerMomentum().rho();
0065 input1.matrix<float>()(nt, 5) = trk.outerMomentum().x();
0066 input1.matrix<float>()(nt, 6) = trk.outerMomentum().y();
0067 input1.matrix<float>()(nt, 7) = trk.outerMomentum().z();
0068 input1.matrix<float>()(nt, 8) = trk.outerMomentum().rho();
0069 input1.matrix<float>()(nt, 9) = trk.ptError();
0070 input1.matrix<float>()(nt, 10) = trk.dxy(bestVertex);
0071 input1.matrix<float>()(nt, 11) = trk.dz(bestVertex);
0072 input1.matrix<float>()(nt, 12) = trk.dxy(beamSpot.position());
0073 input1.matrix<float>()(nt, 13) = trk.dz(beamSpot.position());
0074 input1.matrix<float>()(nt, 14) = trk.dxyError();
0075 input1.matrix<float>()(nt, 15) = trk.dzError();
0076 input1.matrix<float>()(nt, 16) = trk.normalizedChi2();
0077 input1.matrix<float>()(nt, 17) = trk.eta();
0078 input1.matrix<float>()(nt, 18) = trk.phi();
0079 input1.matrix<float>()(nt, 19) = trk.etaError();
0080 input1.matrix<float>()(nt, 20) = trk.phiError();
0081 input1.matrix<float>()(nt, 21) = trk.hitPattern().numberOfValidPixelHits();
0082 input1.matrix<float>()(nt, 22) = trk.hitPattern().numberOfValidStripHits();
0083 input1.matrix<float>()(nt, 23) = trk.ndof();
0084 input1.matrix<float>()(nt, 24) =
0085 trk.hitPattern().numberOfLostTrackerHits(reco::HitPattern::MISSING_INNER_HITS);
0086 input1.matrix<float>()(nt, 25) =
0087 trk.hitPattern().numberOfLostTrackerHits(reco::HitPattern::MISSING_OUTER_HITS);
0088 input1.matrix<float>()(nt, 26) =
0089 trk.hitPattern().trackerLayersTotallyOffOrBad(reco::HitPattern::MISSING_INNER_HITS);
0090 input1.matrix<float>()(nt, 27) =
0091 trk.hitPattern().trackerLayersTotallyOffOrBad(reco::HitPattern::MISSING_OUTER_HITS);
0092 input1.matrix<float>()(nt, 28) =
0093 trk.hitPattern().trackerLayersWithoutMeasurement(reco::HitPattern::TRACK_HITS);
0094
0095
0096
0097 input2.matrix<float>()(nt, 0) = trk.originalAlgo();
0098 }
0099
0100
0101
0102 tensorflow::NamedTensorList inputs;
0103 inputs.resize(2);
0104 inputs[0] = tensorflow::NamedTensor("x", input1);
0105 inputs[1] = tensorflow::NamedTensor("y", input2);
0106 std::vector<tensorflow::Tensor> outputs;
0107
0108
0109 tensorflow::run(session_, inputs, {"Identity"}, &outputs);
0110
0111 for (auto nt = 0; nt < bsize_; nt++) {
0112 int itrack = nt + bsize_ * nb;
0113 if (itrack >= size_in)
0114 continue;
0115 float out0 = 2.0 * outputs[0].matrix<float>()(nt, 0) - 1.0;
0116 output[itrack] = out0;
0117 }
0118 }
0119 return output;
0120 }
0121
0122 const std::string tfDnnLabel_;
0123 const edm::ESGetToken<TfGraphDefWrapper, TfGraphRecord> tfDnnToken_;
0124 const tensorflow::Session* session_;
0125 const int bsize_;
0126 };
0127 }
0128
0129 template <>
0130 void trackMVAClassifierImpl::ComputeMVA<void>::operator()(::TfDnn const& mva,
0131 reco::TrackCollection const& tracks,
0132 reco::BeamSpot const& beamSpot,
0133 reco::VertexCollection const& vertices,
0134 TrackMVAClassifierBase::MVAPairCollection& mvas) {
0135 const auto& scores = mva(tracks, beamSpot, vertices);
0136 size_t current = 0;
0137
0138 for (auto score : scores) {
0139 std::pair<float, bool> output(score, true);
0140 mvas[current++] = output;
0141 }
0142 }
0143
// File-local alias: the TrackMVAClassifier template instantiated with the
// TfDnn evaluator defined in this file.
0144 namespace {
0145 using TrackTfClassifier = TrackMVAClassifier<TfDnn>;
0146 }
0147
0148 #include "FWCore/PluginManager/interface/ModuleDef.h"
0149 #include "FWCore/Framework/interface/MakerMacros.h"
0150
0151 DEFINE_FWK_MODULE(TrackTfClassifier);