Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:24:49

0001 #include "CommonTools/BaseParticlePropagator/interface/BaseParticlePropagator.h"
0002 #include "DataFormats/Common/interface/Handle.h"
0003 #include "DataFormats/Common/interface/Ptr.h"
0004 #include "DataFormats/Common/interface/ValueMap.h"
0005 #include "DataFormats/Common/interface/View.h"
0006 #include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
0007 #include "FWCore/Utilities/interface/InputTag.h"
0008 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0009 #include "FWCore/Framework/interface/ESHandle.h"
0010 #include "FWCore/Framework/interface/global/EDProducer.h"
0011 #include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
0012 #include "CommonTools/MVAUtils/interface/GBRForestTools.h"
0013 #include "DataFormats/EgammaCandidates/interface/GsfElectron.h"
0014 #include "DataFormats/PatCandidates/interface/Electron.h"
0015 #include "DataFormats/PatCandidates/interface/PackedCandidate.h"
0016 #include "DataFormats/EgammaReco/interface/SuperCluster.h"
0017 #include "DataFormats/GsfTrackReco/interface/GsfTrack.h"
0018 #include "DataFormats/TrackReco/interface/Track.h"
0019 #include "DataFormats/Math/interface/LorentzVector.h"
0020 #include "FWCore/Framework/interface/Event.h"
0021 #include "MagneticField/Engine/interface/MagneticField.h"
0022 #include "MagneticField/Records/interface/IdealMagneticFieldRecord.h"
0023 #include "RecoEgamma/EgammaElectronProducers/interface/LowPtGsfElectronFeatures.h"
0024 #include <string>
0025 #include <vector>
0026 
0027 ////////////////////////////////////////////////////////////////////////////////
0028 //
/// Global producer that evaluates one or more GBRForest BDTs on low-pT electrons
/// (either reco::GsfElectrons or pat::Electrons, selected by 'usePAT') and puts one
/// edm::ValueMap<float> per configured model into the event, labelled by 'ModelNames'.
class LowPtGsfElectronIDProducer final : public edm::global::EDProducer<> {
public:
  explicit LowPtGsfElectronIDProducer(const edm::ParameterSet&);

  void produce(edm::StreamID, edm::Event&, const edm::EventSetup&) const override;

  static void fillDescriptions(edm::ConfigurationDescriptions&);

private:
  // Builds the feature vector for one electron (V0 or V1 layout, chosen by version_)
  // and returns the response of `model`. `trk` may be null; it is only passed to the V1 features.
  double eval(const GBRForest& model,
              const reco::GsfElectron&,
              double rho,
              float unbiased,
              float field_z,
              const reco::Track* trk = nullptr) const;

  // Shared implementation for the reco and PAT paths. `idFunctor` yields the unbiased
  // seed score per electron (std::numeric_limits<float>::max() means "no score": the
  // electron gets -999 instead of a BDT value); `trkFunctor` yields a track pointer or nullptr.
  template <typename EL, typename FID, typename FTRK>
  void doWork(double rho, float bz, EL const& electrons, FID&& idFunctor, FTRK&& trkFunctor, edm::Event&) const;
  // NOTE: the constructor's initializer list follows this declaration order.
  // Track source: true -> gsf->track association (reco) or "ele2packed"/"ele2lost" userData (PAT);
  // false -> closestCtfTrackRef().
  const bool useGsfToTrack_;
  // true -> read pat::Electrons; false -> read reco::GsfElectrons plus the unbiased ValueMap.
  const bool usePAT_;
  edm::EDGetTokenT<reco::GsfElectronCollection> electrons_;  // consumed only when !usePAT_
  edm::EDGetTokenT<pat::ElectronCollection> patElectrons_;   // consumed only when usePAT_
  const edm::EDGetTokenT<double> rho_;
  edm::EDGetTokenT<edm::Association<reco::TrackCollection>> gsf2trk_;  // consumed only when useGsfToTrack_
  edm::EDGetTokenT<edm::ValueMap<float>> unbiased_;                    // consumed only when !usePAT_
  const edm::ESGetToken<MagneticField, IdealMagneticFieldRecord> fieldToken_;

  // One output ValueMap instance label per model; parallel to putTokens_, models_ and thresholds_.
  const std::vector<std::string> names_;
  std::vector<edm::EDPutTokenT<edm::ValueMap<float>>> putTokens_;
  // NOTE(review): the next three members are read from the configuration but are not used
  // anywhere in this file.
  const bool passThrough_;
  const double minPtThreshold_;
  const double maxPtThreshold_;
  // BDTs loaded from 'ModelWeights'; same size as names_ (checked in the constructor).
  std::vector<std::unique_ptr<const GBRForest>> models_;
  // NOTE(review): size-checked against models_ in the constructor but otherwise unused in this file.
  const std::vector<double> thresholds_;
  // Raw 'Version' string; parsed once into version_ in the constructor.
  const std::string versionName_;
  enum class Version { V0, V1 };
  Version version_;
};
0067 
0068 ////////////////////////////////////////////////////////////////////////////////
0069 //
0070 LowPtGsfElectronIDProducer::LowPtGsfElectronIDProducer(const edm::ParameterSet& conf)
0071     : useGsfToTrack_(conf.getParameter<bool>("useGsfToTrack")),
0072       usePAT_(conf.getParameter<bool>("usePAT")),
0073       electrons_(),
0074       patElectrons_(),
0075       rho_(consumes<double>(conf.getParameter<edm::InputTag>("rho"))),
0076       gsf2trk_(),
0077       unbiased_(),
0078       fieldToken_(esConsumes()),
0079       names_(conf.getParameter<std::vector<std::string>>("ModelNames")),
0080       passThrough_(conf.getParameter<bool>("PassThrough")),
0081       minPtThreshold_(conf.getParameter<double>("MinPtThreshold")),
0082       maxPtThreshold_(conf.getParameter<double>("MaxPtThreshold")),
0083       thresholds_(conf.getParameter<std::vector<double>>("ModelThresholds")),
0084       versionName_(conf.getParameter<std::string>("Version")) {
0085   if (useGsfToTrack_) {
0086     gsf2trk_ = consumes<edm::Association<reco::TrackCollection>>(conf.getParameter<edm::InputTag>("gsfToTrack"));
0087   }
0088   if (usePAT_) {
0089     patElectrons_ = consumes<pat::ElectronCollection>(conf.getParameter<edm::InputTag>("electrons"));
0090   } else {
0091     electrons_ = consumes<reco::GsfElectronCollection>(conf.getParameter<edm::InputTag>("electrons"));
0092     unbiased_ = consumes<edm::ValueMap<float>>(conf.getParameter<edm::InputTag>("unbiased"));
0093   }
0094   for (auto& weights : conf.getParameter<std::vector<std::string>>("ModelWeights")) {
0095     models_.push_back(createGBRForest(edm::FileInPath(weights)));
0096   }
0097   if (names_.size() != models_.size()) {
0098     throw cms::Exception("Incorrect configuration")
0099         << "'ModelNames' size (" << names_.size() << ") != 'ModelWeights' size (" << models_.size() << ").\n";
0100   }
0101   if (models_.size() != thresholds_.size()) {
0102     throw cms::Exception("Incorrect configuration")
0103         << "'ModelWeights' size (" << models_.size() << ") != 'ModelThresholds' size (" << thresholds_.size() << ").\n";
0104   }
0105   if (versionName_ == "V0") {
0106     version_ = Version::V0;
0107   } else if (versionName_ == "V1") {
0108     version_ = Version::V1;
0109   } else {
0110     throw cms::Exception("Incorrect configuration") << "Unknown Version: " << versionName_ << "\n";
0111   }
0112   putTokens_.reserve(names_.size());
0113   for (const auto& name : names_) {
0114     putTokens_.emplace_back(produces<edm::ValueMap<float>>(name));
0115   }
0116 }
0117 
0118 ////////////////////////////////////////////////////////////////////////////////
0119 //
0120 void LowPtGsfElectronIDProducer::produce(edm::StreamID, edm::Event& event, const edm::EventSetup& setup) const {
0121   // Get z-component of B field
0122   math::XYZVector zfield(setup.getData(fieldToken_).inTesla(GlobalPoint(0, 0, 0)));
0123 
0124   // Pileup
0125   edm::Handle<double> hRho;
0126   event.getByToken(rho_, hRho);
0127   if (!hRho.isValid()) {
0128     std::ostringstream os;
0129     os << "Problem accessing rho collection for low-pT electrons" << std::endl;
0130     throw cms::Exception("InvalidHandle", os.str());
0131   }
0132 
0133   // Retrieve GsfToTrack Association from Event
0134   edm::Handle<edm::Association<reco::TrackCollection>> gsf2trk;
0135   if (useGsfToTrack_) {
0136     event.getByToken(gsf2trk_, gsf2trk);
0137   }
0138 
0139   double rho = *hRho;
0140   // Retrieve pat::Electrons or reco::GsfElectrons from Event
0141   edm::Handle<pat::ElectronCollection> patElectrons;
0142   edm::Handle<reco::GsfElectronCollection> electrons;
0143   if (usePAT_) {
0144     auto const& electrons = event.getHandle(patElectrons_);
0145 
0146     const std::string kUnbiased("unbiased");
0147     doWork(
0148         rho,
0149         zfield.z(),
0150         electrons,
0151         [&](auto const& ele) {
0152           if (!ele.isElectronIDAvailable(kUnbiased)) {
0153             return std::numeric_limits<float>::max();
0154           }
0155           return ele.electronID(kUnbiased);
0156         },
0157         [&](auto const& ele) {  // trkFunctor ...
0158           if (useGsfToTrack_) {
0159             using PackedPtr = edm::Ptr<pat::PackedCandidate>;
0160             const PackedPtr* ptr1 = ele.template userData<PackedPtr>("ele2packed");
0161             const PackedPtr* ptr2 = ele.template userData<PackedPtr>("ele2lost");
0162             auto hasBestTrack = [](const PackedPtr* ptr) {
0163               return ptr != nullptr && ptr->isNonnull() && ptr->isAvailable() && ptr->get() != nullptr &&
0164                      ptr->get()->bestTrack() != nullptr;
0165             };
0166             if (hasBestTrack(ptr1)) {
0167               return ptr1->get()->bestTrack();
0168             } else if (hasBestTrack(ptr2)) {
0169               return ptr2->get()->bestTrack();
0170             }
0171           } else {
0172             reco::TrackRef ref = ele.closestCtfTrackRef();
0173             if (ref.isNonnull() && ref.isAvailable()) {
0174               return ref.get();
0175             }
0176           }
0177           return static_cast<const reco::Track*>(nullptr);
0178         },
0179         event);
0180   } else {
0181     auto const& electrons = event.getHandle(electrons_);
0182     // ElectronSeed unbiased BDT
0183     edm::ValueMap<float> const& unbiasedH = event.get(unbiased_);
0184     doWork(
0185         rho,
0186         zfield.z(),
0187         electrons,
0188         [&](auto const& ele) {
0189           if (ele.core().isNull()) {
0190             return std::numeric_limits<float>::max();
0191           }
0192           const auto& gsf = ele.core()->gsfTrack();  // reco::GsfTrackRef
0193           if (gsf.isNull()) {
0194             return std::numeric_limits<float>::max();
0195           }
0196           return unbiasedH[gsf];
0197         },
0198         [&](auto const& ele) {  // trkFunctor ...
0199           if (useGsfToTrack_) {
0200             const auto& gsf = ele.core()->gsfTrack();
0201             if (gsf.isNonnull() && gsf.isAvailable()) {
0202               auto const& ref = (*gsf2trk)[gsf];
0203               if (ref.isNonnull() && ref.isAvailable()) {
0204                 return ref.get();
0205               }
0206             }
0207           } else {
0208             reco::TrackRef ref = ele.closestCtfTrackRef();
0209             if (ref.isNonnull() && ref.isAvailable()) {
0210               return ref.get();
0211             }
0212           }
0213           return static_cast<const reco::Track*>(nullptr);
0214         },
0215         event);
0216   }
0217 }
0218 
0219 template <typename EL, typename FID, typename FTRK>
0220 void LowPtGsfElectronIDProducer::doWork(
0221     double rho, float bz, EL const& electrons, FID&& idFunctor, FTRK&& trkFunctor, edm::Event& event) const {
0222   auto nElectrons = electrons->size();
0223   std::vector<float> ids;
0224   ids.reserve(nElectrons);
0225   std::transform(electrons->begin(), electrons->end(), std::back_inserter(ids), idFunctor);
0226   std::vector<const reco::Track*> trks;
0227   trks.reserve(nElectrons);
0228   std::transform(electrons->begin(), electrons->end(), std::back_inserter(trks), trkFunctor);
0229 
0230   std::vector<float> output(nElectrons);  //resused for each model
0231   for (unsigned int index = 0; index < names_.size(); ++index) {
0232     // Iterate through Electrons, evaluate BDT, and store result
0233     for (unsigned int iele = 0; iele < nElectrons; iele++) {
0234       auto const& ele = (*electrons)[iele];
0235       if (ids[iele] != std::numeric_limits<float>::max()) {
0236         output[iele] = eval(*models_[index], ele, rho, ids[iele], bz, trks[iele]);
0237       } else {
0238         output[iele] = -999.;
0239       }
0240     }
0241     edm::ValueMap<float> valueMap;
0242     edm::ValueMap<float>::Filler filler(valueMap);
0243     filler.insert(electrons, output.begin(), output.end());
0244     filler.fill();
0245     event.emplace(putTokens_[index], std::move(valueMap));
0246   }
0247 }
0248 
0249 //////////////////////////////////////////////////////////////////////////////////////////
0250 //
0251 double LowPtGsfElectronIDProducer::eval(const GBRForest& model,
0252                                         const reco::GsfElectron& ele,
0253                                         double rho,
0254                                         float unbiased,
0255                                         float field_z,
0256                                         const reco::Track* trk) const {
0257   std::vector<float> inputs;
0258   if (version_ == Version::V0) {
0259     inputs = lowptgsfeleid::features_V0(ele, rho, unbiased);
0260   } else if (version_ == Version::V1) {
0261     inputs = lowptgsfeleid::features_V1(ele, rho, unbiased, field_z, trk);
0262   }
0263   return model.GetResponse(inputs.data());
0264 }
0265 
0266 //////////////////////////////////////////////////////////////////////////////////////////
0267 //
0268 void LowPtGsfElectronIDProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0269   edm::ParameterSetDescription desc;
0270   desc.add<bool>("useGsfToTrack", false);
0271   desc.add<bool>("usePAT", false);
0272   desc.add<edm::InputTag>("electrons", edm::InputTag("lowPtGsfElectrons"));
0273   desc.addOptional<edm::InputTag>("gsfToTrack", edm::InputTag("lowPtGsfToTrackLinks"));
0274   desc.addOptional<edm::InputTag>("unbiased", edm::InputTag("lowPtGsfElectronSeedValueMaps:unbiased"));
0275   desc.add<edm::InputTag>("rho", edm::InputTag("fixedGridRhoFastjetAll"));
0276   desc.add<std::vector<std::string>>("ModelNames", {""});
0277   desc.add<std::vector<std::string>>(
0278       "ModelWeights", {"RecoEgamma/ElectronIdentification/data/LowPtElectrons/LowPtElectrons_ID_2020Nov28.root"});
0279   desc.add<std::vector<double>>("ModelThresholds", {-99.});
0280   desc.add<bool>("PassThrough", false);
0281   desc.add<double>("MinPtThreshold", 0.5);
0282   desc.add<double>("MaxPtThreshold", 15.);
0283   desc.add<std::string>("Version", "V1");
0284   descriptions.add("defaultLowPtGsfElectronID", desc);
0285 }
0286 
0287 //////////////////////////////////////////////////////////////////////////////////////////
0288 //
// Register LowPtGsfElectronIDProducer as a framework module plugin
// (DEFINE_FWK_MODULE is provided by MakerMacros.h).
#include "FWCore/Framework/interface/MakerMacros.h"
DEFINE_FWK_MODULE(LowPtGsfElectronIDProducer);