File indexing completed on 2023-03-17 11:22:15
0001 #ifndef RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
0002 #define RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
0003
0004 #include "DataFormats/TrackReco/interface/TrackFwd.h"
0005 #include "DataFormats/VertexReco/interface/VertexFwd.h"
0006 #include "DataFormats/BeamSpot/interface/BeamSpot.h"
0007
0008 #include "FWCore/Framework/interface/stream/EDProducer.h"
0009 #include "FWCore/Framework/interface/Event.h"
0010 #include "FWCore/Framework/interface/ConsumesCollector.h"
0011 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0012 #include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
0013 #include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
0014
0015 #include "FWCore/Utilities/interface/InputTag.h"
0016
0017 #include "CondFormats/GBRForest/interface/GBRForest.h"
0018
0019 #include <vector>
0020 #include <memory>
0021
0022 class TrackMVAClassifierBase : public edm::stream::EDProducer<> {
0023 public:
0024 explicit TrackMVAClassifierBase(const edm::ParameterSet& cfg);
0025 ~TrackMVAClassifierBase() override;
0026
0027 using MVACollection = std::vector<float>;
0028 using QualityMaskCollection = std::vector<unsigned char>;
0029
0030
0031 using MVAPairCollection = std::vector<std::pair<float, bool>>;
0032
0033 protected:
0034 static void fill(edm::ParameterSetDescription& desc);
0035
0036 virtual void initEvent(const edm::EventSetup& es) = 0;
0037
0038 virtual void computeMVA(reco::TrackCollection const& tracks,
0039 reco::BeamSpot const& beamSpot,
0040 reco::VertexCollection const& vertices,
0041 MVAPairCollection& mvas) const = 0;
0042
0043 private:
0044 void produce(edm::Event& evt, const edm::EventSetup& es) final;
0045
0046
0047 edm::EDGetTokenT<reco::TrackCollection> src_;
0048 edm::EDGetTokenT<reco::BeamSpot> beamspot_;
0049 edm::EDGetTokenT<reco::VertexCollection> vertices_;
0050
0051 bool ignoreVertices_;
0052
0053
0054
0055
0056 float qualityCuts[3];
0057 };
0058
0059 namespace trackMVAClassifierImpl {
0060 template <typename EventCache>
0061 struct ComputeMVA {
0062 template <typename MVA>
0063 void operator()(MVA const& mva,
0064 reco::TrackCollection const& tracks,
0065 reco::BeamSpot const& beamSpot,
0066 reco::VertexCollection const& vertices,
0067 TrackMVAClassifierBase::MVAPairCollection& mvas) {
0068 EventCache cache;
0069
0070 size_t current = 0;
0071 for (auto const& trk : tracks) {
0072 mvas[current++] = mva(trk, beamSpot, vertices, cache);
0073 }
0074 }
0075 };
0076
0077 template <>
0078 struct ComputeMVA<void> {
0079 template <typename MVA>
0080 void operator()(MVA const& mva,
0081 reco::TrackCollection const& tracks,
0082 reco::BeamSpot const& beamSpot,
0083 reco::VertexCollection const& vertices,
0084 TrackMVAClassifierBase::MVAPairCollection& mvas) {
0085 size_t current = 0;
0086 for (auto const& trk : tracks) {
0087
0088 std::pair<float, bool> output(mva(trk, beamSpot, vertices), true);
0089 mvas[current++] = output;
0090 }
0091 }
0092 };
0093 }
0094
0095 template <typename MVA, typename EventCache = void>
0096 class TrackMVAClassifier : public TrackMVAClassifierBase {
0097 public:
0098 explicit TrackMVAClassifier(const edm::ParameterSet& cfg)
0099 : TrackMVAClassifierBase(cfg), mva(cfg.getParameter<edm::ParameterSet>("mva"), consumesCollector()) {}
0100
0101 static void fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0102 edm::ParameterSetDescription desc;
0103 fill(desc);
0104 edm::ParameterSetDescription mvaDesc;
0105 MVA::fillDescriptions(mvaDesc);
0106 desc.add<edm::ParameterSetDescription>("mva", mvaDesc);
0107 descriptions.add(MVA::name(), desc);
0108 }
0109
0110 private:
0111 void beginStream(edm::StreamID) final { mva.beginStream(); }
0112
0113 void initEvent(const edm::EventSetup& es) final { mva.initEvent(es); }
0114
0115 void computeMVA(reco::TrackCollection const& tracks,
0116 reco::BeamSpot const& beamSpot,
0117 reco::VertexCollection const& vertices,
0118 MVAPairCollection& mvas) const final {
0119 trackMVAClassifierImpl::ComputeMVA<EventCache> computer;
0120 computer(mva, tracks, beamSpot, vertices, mvas);
0121 }
0122
0123 MVA mva;
0124 };
0125
0126 #endif