Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2023-03-17 11:22:15

0001 #ifndef RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
0002 #define RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
0003 
0004 #include "DataFormats/TrackReco/interface/TrackFwd.h"
0005 #include "DataFormats/VertexReco/interface/VertexFwd.h"
0006 #include "DataFormats/BeamSpot/interface/BeamSpot.h"
0007 
0008 #include "FWCore/Framework/interface/stream/EDProducer.h"
0009 #include "FWCore/Framework/interface/Event.h"
0010 #include "FWCore/Framework/interface/ConsumesCollector.h"
0011 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0012 #include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
0013 #include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
0014 
0015 #include "FWCore/Utilities/interface/InputTag.h"
0016 
0017 #include "CondFormats/GBRForest/interface/GBRForest.h"
0018 
0019 #include <vector>
0020 #include <memory>
0021 
0022 class TrackMVAClassifierBase : public edm::stream::EDProducer<> {
0023 public:
0024   explicit TrackMVAClassifierBase(const edm::ParameterSet& cfg);
0025   ~TrackMVAClassifierBase() override;
0026 
0027   using MVACollection = std::vector<float>;
0028   using QualityMaskCollection = std::vector<unsigned char>;
0029 
0030   //Collection with pairs <MVAOutput, isReliable>
0031   using MVAPairCollection = std::vector<std::pair<float, bool>>;
0032 
0033 protected:
0034   static void fill(edm::ParameterSetDescription& desc);
0035 
0036   virtual void initEvent(const edm::EventSetup& es) = 0;
0037 
0038   virtual void computeMVA(reco::TrackCollection const& tracks,
0039                           reco::BeamSpot const& beamSpot,
0040                           reco::VertexCollection const& vertices,
0041                           MVAPairCollection& mvas) const = 0;
0042 
0043 private:
0044   void produce(edm::Event& evt, const edm::EventSetup& es) final;
0045 
0046   /// source collection label
0047   edm::EDGetTokenT<reco::TrackCollection> src_;
0048   edm::EDGetTokenT<reco::BeamSpot> beamspot_;
0049   edm::EDGetTokenT<reco::VertexCollection> vertices_;
0050 
0051   bool ignoreVertices_;
0052 
0053   // MVA
0054 
0055   // qualitycuts (loose, tight, hp)
0056   float qualityCuts[3];
0057 };
0058 
0059 namespace trackMVAClassifierImpl {
0060   template <typename EventCache>
0061   struct ComputeMVA {
0062     template <typename MVA>
0063     void operator()(MVA const& mva,
0064                     reco::TrackCollection const& tracks,
0065                     reco::BeamSpot const& beamSpot,
0066                     reco::VertexCollection const& vertices,
0067                     TrackMVAClassifierBase::MVAPairCollection& mvas) {
0068       EventCache cache;
0069 
0070       size_t current = 0;
0071       for (auto const& trk : tracks) {
0072         mvas[current++] = mva(trk, beamSpot, vertices, cache);
0073       }
0074     }
0075   };
0076 
0077   template <>
0078   struct ComputeMVA<void> {
0079     template <typename MVA>
0080     void operator()(MVA const& mva,
0081                     reco::TrackCollection const& tracks,
0082                     reco::BeamSpot const& beamSpot,
0083                     reco::VertexCollection const& vertices,
0084                     TrackMVAClassifierBase::MVAPairCollection& mvas) {
0085       size_t current = 0;
0086       for (auto const& trk : tracks) {
0087         //BDT outputs are considered always reliable. Hence "true"
0088         std::pair<float, bool> output(mva(trk, beamSpot, vertices), true);
0089         mvas[current++] = output;
0090       }
0091     }
0092   };
0093 }  // namespace trackMVAClassifierImpl
0094 
0095 template <typename MVA, typename EventCache = void>
0096 class TrackMVAClassifier : public TrackMVAClassifierBase {
0097 public:
0098   explicit TrackMVAClassifier(const edm::ParameterSet& cfg)
0099       : TrackMVAClassifierBase(cfg), mva(cfg.getParameter<edm::ParameterSet>("mva"), consumesCollector()) {}
0100 
0101   static void fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0102     edm::ParameterSetDescription desc;
0103     fill(desc);
0104     edm::ParameterSetDescription mvaDesc;
0105     MVA::fillDescriptions(mvaDesc);
0106     desc.add<edm::ParameterSetDescription>("mva", mvaDesc);
0107     descriptions.add(MVA::name(), desc);
0108   }
0109 
0110 private:
0111   void beginStream(edm::StreamID) final { mva.beginStream(); }
0112 
0113   void initEvent(const edm::EventSetup& es) final { mva.initEvent(es); }
0114 
0115   void computeMVA(reco::TrackCollection const& tracks,
0116                   reco::BeamSpot const& beamSpot,
0117                   reco::VertexCollection const& vertices,
0118                   MVAPairCollection& mvas) const final {
0119     trackMVAClassifierImpl::ComputeMVA<EventCache> computer;
0120     computer(mva, tracks, beamSpot, vertices, mvas);
0121   }
0122 
0123   MVA mva;
0124 };
0125 
0126 #endif  //  RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h