Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-07-16 22:52:37

0001 #ifndef RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
0002 #define RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
0003 
0004 #include "DataFormats/TrackReco/interface/Track.h"
0005 #include "DataFormats/TrackReco/interface/TrackFwd.h"
0006 #include "DataFormats/VertexReco/interface/VertexFwd.h"
0007 #include "DataFormats/BeamSpot/interface/BeamSpot.h"
0008 
0009 #include "FWCore/Framework/interface/stream/EDProducer.h"
0010 #include "FWCore/Framework/interface/Event.h"
0011 #include "FWCore/Framework/interface/ConsumesCollector.h"
0012 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0013 #include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
0014 #include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
0015 
0016 #include "FWCore/Utilities/interface/InputTag.h"
0017 
0018 #include "CondFormats/GBRForest/interface/GBRForest.h"
0019 
0020 #include <vector>
0021 #include <memory>
0022 
0023 class TrackMVAClassifierBase : public edm::stream::EDProducer<> {
0024 public:
0025   explicit TrackMVAClassifierBase(const edm::ParameterSet& cfg);
0026   ~TrackMVAClassifierBase() override;
0027 
0028   using MVACollection = std::vector<float>;
0029   using QualityMaskCollection = std::vector<unsigned char>;
0030 
0031   //Collection with pairs <MVAOutput, isReliable>
0032   using MVAPairCollection = std::vector<std::pair<float, bool>>;
0033 
0034 protected:
0035   static void fill(edm::ParameterSetDescription& desc);
0036 
0037   virtual void initEvent(const edm::EventSetup& es) = 0;
0038 
0039   virtual void computeMVA(reco::TrackCollection const& tracks,
0040                           reco::BeamSpot const& beamSpot,
0041                           reco::VertexCollection const& vertices,
0042                           MVAPairCollection& mvas) const = 0;
0043 
0044 private:
0045   void produce(edm::Event& evt, const edm::EventSetup& es) final;
0046 
0047   /// source collection label
0048   edm::EDGetTokenT<reco::TrackCollection> src_;
0049   edm::EDGetTokenT<reco::BeamSpot> beamspot_;
0050   edm::EDGetTokenT<reco::VertexCollection> vertices_;
0051 
0052   bool ignoreVertices_;
0053 
0054   // MVA
0055 
0056   // qualitycuts (loose, tight, hp)
0057   float qualityCuts[3];
0058 };
0059 
0060 namespace trackMVAClassifierImpl {
0061   template <typename EventCache>
0062   struct ComputeMVA {
0063     template <typename MVA>
0064     void operator()(MVA const& mva,
0065                     reco::TrackCollection const& tracks,
0066                     reco::BeamSpot const& beamSpot,
0067                     reco::VertexCollection const& vertices,
0068                     TrackMVAClassifierBase::MVAPairCollection& mvas) {
0069       EventCache cache;
0070 
0071       size_t current = 0;
0072       for (auto const& trk : tracks) {
0073         mvas[current++] = mva(trk, beamSpot, vertices, cache);
0074       }
0075     }
0076   };
0077 
0078   template <>
0079   struct ComputeMVA<void> {
0080     template <typename MVA>
0081     void operator()(MVA const& mva,
0082                     reco::TrackCollection const& tracks,
0083                     reco::BeamSpot const& beamSpot,
0084                     reco::VertexCollection const& vertices,
0085                     TrackMVAClassifierBase::MVAPairCollection& mvas) {
0086       size_t current = 0;
0087       for (auto const& trk : tracks) {
0088         //BDT outputs are considered always reliable. Hence "true"
0089         std::pair<float, bool> output(mva(trk, beamSpot, vertices), true);
0090         mvas[current++] = output;
0091       }
0092     }
0093   };
0094 }  // namespace trackMVAClassifierImpl
0095 
0096 template <typename MVA, typename EventCache = void>
0097 class TrackMVAClassifier : public TrackMVAClassifierBase {
0098 public:
0099   explicit TrackMVAClassifier(const edm::ParameterSet& cfg)
0100       : TrackMVAClassifierBase(cfg), mva(cfg.getParameter<edm::ParameterSet>("mva"), consumesCollector()) {}
0101 
0102   static void fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0103     edm::ParameterSetDescription desc;
0104     fill(desc);
0105     edm::ParameterSetDescription mvaDesc;
0106     MVA::fillDescriptions(mvaDesc);
0107     desc.add<edm::ParameterSetDescription>("mva", mvaDesc);
0108     descriptions.add(MVA::name(), desc);
0109   }
0110 
0111 private:
0112   void beginStream(edm::StreamID) final { mva.beginStream(); }
0113 
0114   void initEvent(const edm::EventSetup& es) final { mva.initEvent(es); }
0115 
0116   void computeMVA(reco::TrackCollection const& tracks,
0117                   reco::BeamSpot const& beamSpot,
0118                   reco::VertexCollection const& vertices,
0119                   MVAPairCollection& mvas) const final {
0120     trackMVAClassifierImpl::ComputeMVA<EventCache> computer;
0121     computer(mva, tracks, beamSpot, vertices, mvas);
0122   }
0123 
0124   MVA mva;
0125 };
0126 
0127 #endif  //  RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h