Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:28:09

#include "RecoTracker/FinalTrackSelectors/interface/TrackMVAClassifier.h"

#include "FWCore/MessageLogger/interface/MessageLogger.h"
#include "FWCore/Utilities/interface/Exception.h"

#include "DataFormats/TrackReco/interface/Track.h"
#include "DataFormats/VertexReco/interface/Vertex.h"

#include <algorithm>
#include <cassert>
#include <iterator>
#include <memory>
#include <utility>
#include <vector>
0009 
0010 void TrackMVAClassifierBase::fill(edm::ParameterSetDescription& desc) {
0011   desc.add<edm::InputTag>("src", edm::InputTag());
0012   desc.add<edm::InputTag>("beamspot", edm::InputTag("offlineBeamSpot"));
0013   desc.add<edm::InputTag>("vertices", edm::InputTag("firstStepPrimaryVertices"));
0014   desc.add<bool>("ignoreVertices", false);
0015   // default cuts for "cut based classification"
0016   std::vector<double> cuts = {-.7, 0.1, .7};
0017   desc.add<std::vector<double>>("qualityCuts", cuts);
0018 }
0019 
0020 TrackMVAClassifierBase::~TrackMVAClassifierBase() {}
0021 
0022 TrackMVAClassifierBase::TrackMVAClassifierBase(const edm::ParameterSet& cfg)
0023     : src_(consumes<reco::TrackCollection>(cfg.getParameter<edm::InputTag>("src"))),
0024       beamspot_(consumes<reco::BeamSpot>(cfg.getParameter<edm::InputTag>("beamspot"))),
0025       vertices_(mayConsume<reco::VertexCollection>(cfg.getParameter<edm::InputTag>("vertices"))),
0026       ignoreVertices_(cfg.getParameter<bool>("ignoreVertices")) {
0027   auto const& qv = cfg.getParameter<std::vector<double>>("qualityCuts");
0028   assert(qv.size() == 3);
0029   std::copy(std::begin(qv), std::end(qv), std::begin(qualityCuts));
0030 
0031   produces<MVACollection>("MVAValues");
0032   produces<QualityMaskCollection>("QualityMasks");
0033 }
0034 
0035 void TrackMVAClassifierBase::produce(edm::Event& evt, const edm::EventSetup& es) {
0036   // Get tracks
0037   edm::Handle<reco::TrackCollection> hSrcTrack;
0038   evt.getByToken(src_, hSrcTrack);
0039   auto const& tracks(*hSrcTrack);
0040 
0041   // looking for the beam spot
0042   edm::Handle<reco::BeamSpot> hBsp;
0043   evt.getByToken(beamspot_, hBsp);
0044 
0045   // Select good primary vertices for use in subsequent track selection
0046   edm::Handle<reco::VertexCollection> hVtx;
0047   evt.getByToken(vertices_, hVtx);
0048 
0049   initEvent(es);
0050 
0051   // products
0052   auto mvaPairs = std::make_unique<MVAPairCollection>(tracks.size(), std::make_pair(-99.f, true));
0053   auto mvas = std::make_unique<MVACollection>(tracks.size(), -99.f);
0054   auto quals = std::make_unique<QualityMaskCollection>(tracks.size(), 0);
0055 
0056   if (hVtx.isValid() && !ignoreVertices_) {
0057     computeMVA(tracks, *hBsp, *hVtx, *mvaPairs);
0058   } else {
0059     if (!ignoreVertices_)
0060       edm::LogWarning("TrackMVAClassifierBase")
0061           << "ignoreVertices is set to False in the configuration, but the vertex collection is not valid";
0062     std::vector<reco::Vertex> vertices;
0063     computeMVA(tracks, *hBsp, vertices, *mvaPairs);
0064   }
0065   assert((*mvaPairs).size() == tracks.size());
0066 
0067   unsigned int k = 0;
0068   for (auto const& output : *mvaPairs) {
0069     if (output.second) {
0070       (*mvas)[k] = output.first;
0071     } else {
0072       // If the MVA value is known to be unreliable, force into generalTracks collection
0073       (*mvas)[k] = std::max(output.first, float(qualityCuts[0] + 0.001));
0074     }
0075     float mva = (*mvas)[k];
0076     (*quals)[k++] = (mva > qualityCuts[0]) << reco::TrackBase::loose |
0077                     (mva > qualityCuts[1]) << reco::TrackBase::tight |
0078                     (mva > qualityCuts[2]) << reco::TrackBase::highPurity;
0079   }
0080 
0081   evt.put(std::move(mvas), "MVAValues");
0082   evt.put(std::move(quals), "QualityMasks");
0083 }