Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:23:58

#include "PhysicsTools/PatAlgos/interface/MuonMvaIDEstimator.h"
#include "FWCore/Framework/interface/Frameworkfwd.h"
#include "FWCore/Framework/interface/MakerMacros.h"
#include "FWCore/ParameterSet/interface/ParameterSet.h"
#include "FWCore/ParameterSet/interface/FileInPath.h"
#include "DataFormats/MuonReco/interface/Muon.h"
#include "DataFormats/MuonReco/interface/MuonSelectors.h"
#include "DataFormats/PatCandidates/interface/Muon.h"

#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>
0009 
0010 using namespace pat;
0011 using namespace cms::Ort;
0012 
0013 MuonMvaIDEstimator::MuonMvaIDEstimator(const edm::FileInPath &weightsfile) {
0014   randomForest_ = std::make_unique<ONNXRuntime>(weightsfile.fullPath());
0015   LogDebug("MuonMvaIDEstimator") << randomForest_.get();
0016 }
0017 
0018 void MuonMvaIDEstimator::fillDescriptions(edm::ConfigurationDescriptions &descriptions) {
0019   edm::ParameterSetDescription desc;
0020   desc.add<edm::FileInPath>("mvaIDTrainingFile", edm::FileInPath("RecoMuon/MuonIdentification/data/mvaID.onnx"));
0021   desc.add<std::vector<std::string>>("flav_names",
0022                                      std::vector<std::string>{
0023                                          "probBAD",
0024                                          "probGOOD",
0025                                      });
0026 
0027   descriptions.addWithDefaultLabel(desc);
0028 }
0029 
0030 void MuonMvaIDEstimator::globalEndJob(const cms::Ort::ONNXRuntime *cache) {}
0031 const reco::Muon::ArbitrationType arbitrationType = reco::Muon::SegmentAndTrackArbitration;
0032 std::vector<float> MuonMvaIDEstimator::computeMVAID(const pat::Muon &muon) const {
0033   const float local_chi2 = muon.combinedQuality().chi2LocalPosition;
0034   const float kink = muon.combinedQuality().trkKink;
0035   const float segment_comp = muon.segmentCompatibility(arbitrationType);
0036   const float n_MatchedStations = muon.numberOfMatchedStations();
0037   const float pt = muon.pt();
0038   const float eta = muon.eta();
0039   const float global_muon = muon.isGlobalMuon();
0040   float Valid_pixel;
0041   float tracker_layers;
0042   float validFraction;
0043   if (muon.innerTrack().isNonnull()) {
0044     Valid_pixel = muon.innerTrack()->hitPattern().numberOfValidPixelHits();
0045     tracker_layers = muon.innerTrack()->hitPattern().trackerLayersWithMeasurement();
0046     validFraction = muon.innerTrack()->validFraction();
0047   } else {
0048     Valid_pixel = -99.;
0049     tracker_layers = -99.0;
0050     validFraction = -99.0;
0051   }
0052   float norm_chi2;
0053   float n_Valid_hits;
0054   if (muon.globalTrack().isNonnull()) {
0055     norm_chi2 = muon.globalTrack()->normalizedChi2();
0056     n_Valid_hits = muon.globalTrack()->hitPattern().numberOfValidMuonHits();
0057   } else if (muon.innerTrack().isNonnull()) {
0058     norm_chi2 = muon.innerTrack()->normalizedChi2();
0059     n_Valid_hits = muon.innerTrack()->hitPattern().numberOfValidMuonHits();
0060   } else {
0061     norm_chi2 = -99;
0062     n_Valid_hits = -99;
0063   }
0064   const std::vector<std::string> input_names_{"float_input"};
0065   std::vector<float> vars = {global_muon,
0066                              validFraction,
0067                              norm_chi2,
0068                              local_chi2,
0069                              kink,
0070                              segment_comp,
0071                              n_Valid_hits,
0072                              n_MatchedStations,
0073                              Valid_pixel,
0074                              tracker_layers,
0075                              pt,
0076                              eta};
0077   const std::vector<std::string> flav_names_{"probBAD", "probGOOD"};
0078   cms::Ort::FloatArrays input_values_;
0079   input_values_.emplace_back(vars);
0080   std::vector<float> outputs;
0081   LogDebug("MuonMvaIDEstimator") << randomForest_.get();
0082   outputs = randomForest_->run(input_names_, input_values_, {}, {"probabilities"})[0];
0083   assert(outputs.size() == flav_names_.size());
0084   return outputs;
0085 }