File indexing completed on 2024-07-16 22:52:37
0001 #ifndef RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
0002 #define RecoTracker_FinalTrackSelectors_TrackMVAClassifierBase_h
0003
#include "DataFormats/TrackReco/interface/Track.h"
#include "DataFormats/TrackReco/interface/TrackFwd.h"
#include "DataFormats/VertexReco/interface/VertexFwd.h"
#include "DataFormats/BeamSpot/interface/BeamSpot.h"

#include "FWCore/Framework/interface/stream/EDProducer.h"
#include "FWCore/Framework/interface/Event.h"
#include "FWCore/Framework/interface/ConsumesCollector.h"
#include "FWCore/ParameterSet/interface/ParameterSet.h"
#include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
#include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"

#include "FWCore/Utilities/interface/InputTag.h"

#include "CondFormats/GBRForest/interface/GBRForest.h"

#include <memory>
#include <utility>
#include <vector>
0022
0023 class TrackMVAClassifierBase : public edm::stream::EDProducer<> {
0024 public:
0025 explicit TrackMVAClassifierBase(const edm::ParameterSet& cfg);
0026 ~TrackMVAClassifierBase() override;
0027
0028 using MVACollection = std::vector<float>;
0029 using QualityMaskCollection = std::vector<unsigned char>;
0030
0031
0032 using MVAPairCollection = std::vector<std::pair<float, bool>>;
0033
0034 protected:
0035 static void fill(edm::ParameterSetDescription& desc);
0036
0037 virtual void initEvent(const edm::EventSetup& es) = 0;
0038
0039 virtual void computeMVA(reco::TrackCollection const& tracks,
0040 reco::BeamSpot const& beamSpot,
0041 reco::VertexCollection const& vertices,
0042 MVAPairCollection& mvas) const = 0;
0043
0044 private:
0045 void produce(edm::Event& evt, const edm::EventSetup& es) final;
0046
0047
0048 edm::EDGetTokenT<reco::TrackCollection> src_;
0049 edm::EDGetTokenT<reco::BeamSpot> beamspot_;
0050 edm::EDGetTokenT<reco::VertexCollection> vertices_;
0051
0052 bool ignoreVertices_;
0053
0054
0055
0056
0057 float qualityCuts[3];
0058 };
0059
0060 namespace trackMVAClassifierImpl {
0061 template <typename EventCache>
0062 struct ComputeMVA {
0063 template <typename MVA>
0064 void operator()(MVA const& mva,
0065 reco::TrackCollection const& tracks,
0066 reco::BeamSpot const& beamSpot,
0067 reco::VertexCollection const& vertices,
0068 TrackMVAClassifierBase::MVAPairCollection& mvas) {
0069 EventCache cache;
0070
0071 size_t current = 0;
0072 for (auto const& trk : tracks) {
0073 mvas[current++] = mva(trk, beamSpot, vertices, cache);
0074 }
0075 }
0076 };
0077
0078 template <>
0079 struct ComputeMVA<void> {
0080 template <typename MVA>
0081 void operator()(MVA const& mva,
0082 reco::TrackCollection const& tracks,
0083 reco::BeamSpot const& beamSpot,
0084 reco::VertexCollection const& vertices,
0085 TrackMVAClassifierBase::MVAPairCollection& mvas) {
0086 size_t current = 0;
0087 for (auto const& trk : tracks) {
0088
0089 std::pair<float, bool> output(mva(trk, beamSpot, vertices), true);
0090 mvas[current++] = output;
0091 }
0092 }
0093 };
0094 }
0095
0096 template <typename MVA, typename EventCache = void>
0097 class TrackMVAClassifier : public TrackMVAClassifierBase {
0098 public:
0099 explicit TrackMVAClassifier(const edm::ParameterSet& cfg)
0100 : TrackMVAClassifierBase(cfg), mva(cfg.getParameter<edm::ParameterSet>("mva"), consumesCollector()) {}
0101
0102 static void fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0103 edm::ParameterSetDescription desc;
0104 fill(desc);
0105 edm::ParameterSetDescription mvaDesc;
0106 MVA::fillDescriptions(mvaDesc);
0107 desc.add<edm::ParameterSetDescription>("mva", mvaDesc);
0108 descriptions.add(MVA::name(), desc);
0109 }
0110
0111 private:
0112 void beginStream(edm::StreamID) final { mva.beginStream(); }
0113
0114 void initEvent(const edm::EventSetup& es) final { mva.initEvent(es); }
0115
0116 void computeMVA(reco::TrackCollection const& tracks,
0117 reco::BeamSpot const& beamSpot,
0118 reco::VertexCollection const& vertices,
0119 MVAPairCollection& mvas) const final {
0120 trackMVAClassifierImpl::ComputeMVA<EventCache> computer;
0121 computer(mva, tracks, beamSpot, vertices, mvas);
0122 }
0123
0124 MVA mva;
0125 };
0126
0127 #endif