Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-09-07 04:37:30

0001 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0002 #include "CommonTools/MVAUtils/interface/GBRForestTools.h"
0003 #include "FWCore/Framework/interface/ConsumesCollector.h"
0004 #include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
0005 #include "DataFormats/Common/interface/ValueMap.h"
0006 #include "RecoEgamma/EgammaTools/interface/AnyMVAEstimatorRun2Base.h"
0007 #include "RecoEgamma/EgammaTools/interface/AnyMVAEstimatorRun2Factory.h"
0008 #include "DataFormats/EgammaCandidates/interface/Photon.h"
0009 #include "CommonTools/Egamma/interface/EffectiveAreas.h"
0010 #include "CondFormats/GBRForest/interface/GBRForest.h"
0011 #include "RecoEgamma/EgammaTools/interface/MVAVariableHelper.h"
0012 #include "RecoEgamma/EgammaTools/interface/MVAVariableManager.h"
0013 #include "CommonTools/Utils/interface/StringCutObjectSelector.h"
0014 #include "CommonTools/Utils/interface/ThreadSafeFunctor.h"
0015 
0016 class PhotonMVAEstimator : public AnyMVAEstimatorRun2Base {
0017 public:
0018   // Constructor and destructor
0019   PhotonMVAEstimator(const edm::ParameterSet& conf);
0020   ~PhotonMVAEstimator() override {}
0021 
0022   // Calculation of the MVA value
0023   float mvaValue(const reco::Candidate* candPtr, std::vector<float> const& auxVars, int& iCategory) const override;
0024 
0025   int findCategory(const reco::Candidate* candPtr) const override;
0026 
0027   static void fillDescriptions(edm::ConfigurationDescriptions& descriptions);
0028 
0029 private:
0030   int findCategory(reco::Photon const& photon) const;
0031 
0032   // The number of categories and number of variables per category
0033   int nCategories_;
0034   std::vector<ThreadSafeFunctor<StringCutObjectSelector<reco::Photon>>> categoryFunctions_;
0035   std::vector<int> nVariables_;
0036 
0037   // Data members
0038   std::vector<std::unique_ptr<const GBRForest>> gbrForests_;
0039 
0040   // There might be different variables for each category, so the variables
0041   // names vector is itself a vector of length nCategories
0042   std::vector<std::vector<int>> variables_;
0043 
0044   // The variable manager which stores how to obtain the variables
0045   MVAVariableManager<reco::Photon> mvaVarMngr_;
0046 
0047   // Other objects needed by the MVA
0048   std::unique_ptr<EffectiveAreas> effectiveAreas_;
0049   std::vector<double> phoIsoPtScalingCoeff_;
0050   double phoIsoCutoff_;
0051 };
0052 
0053 PhotonMVAEstimator::PhotonMVAEstimator(const edm::ParameterSet& conf)
0054     : AnyMVAEstimatorRun2Base(conf),
0055       mvaVarMngr_(conf.getParameter<std::string>("variableDefinition"), MVAVariableHelper::indexMap()) {
0056   //
0057   // Construct the MVA estimators
0058   //
0059   if (getTag() == "Run2Spring16NonTrigV1") {
0060     effectiveAreas_ =
0061         std::make_unique<EffectiveAreas>((conf.getParameter<edm::FileInPath>("effAreasConfigFile")).fullPath());
0062     phoIsoPtScalingCoeff_ = conf.getParameter<std::vector<double>>("phoIsoPtScalingCoeff");
0063     phoIsoCutoff_ = conf.getParameter<double>("phoIsoCutoff");
0064   }
0065 
0066   const auto weightFileNames = conf.getParameter<std::vector<std::string>>("weightFileNames");
0067   const auto categoryCutStrings = conf.getParameter<std::vector<std::string>>("categoryCuts");
0068 
0069   if ((int)(categoryCutStrings.size()) != getNCategories())
0070     throw cms::Exception("MVA config failure: ")
0071         << "wrong number of category cuts in PhotonMVAEstimator" << getTag() << std::endl;
0072 
0073   for (auto const& cut : categoryCutStrings)
0074     categoryFunctions_.emplace_back(cut);
0075 
0076   // Initialize GBRForests
0077   if (static_cast<int>(weightFileNames.size()) != getNCategories())
0078     throw cms::Exception("MVA config failure: ")
0079         << "wrong number of weightfiles in PhotonMVAEstimator" << getTag() << std::endl;
0080 
0081   gbrForests_.clear();
0082   // Create a TMVA reader object for each category
0083   for (int i = 0; i < getNCategories(); i++) {
0084     std::vector<int> variablesInCategory;
0085 
0086     std::vector<std::string> variableNamesInCategory;
0087     gbrForests_.push_back(createGBRForest(weightFileNames[i], variableNamesInCategory));
0088 
0089     nVariables_.push_back(variableNamesInCategory.size());
0090 
0091     variables_.push_back(variablesInCategory);
0092 
0093     for (int j = 0; j < nVariables_[i]; ++j) {
0094       int index = mvaVarMngr_.getVarIndex(variableNamesInCategory[j]);
0095       if (index == -1) {
0096         throw cms::Exception("MVA config failure: ")
0097             << "Concerning PhotonMVAEstimator" << getTag() << std::endl
0098             << "Variable " << variableNamesInCategory[j] << " not found in variable definition file!" << std::endl;
0099       }
0100       variables_[i].push_back(index);
0101     }
0102   }
0103 }
0104 
0105 float PhotonMVAEstimator::mvaValue(const reco::Candidate* candPtr,
0106                                    std::vector<float> const& auxVars,
0107                                    int& iCategory) const {
0108   const reco::Photon* phoPtr = dynamic_cast<const reco::Photon*>(candPtr);
0109   if (phoPtr == nullptr) {
0110     throw cms::Exception("MVA failure: ")
0111         << " given particle is expected to be reco::Photon or pat::Photon," << std::endl
0112         << " but appears to be neither" << std::endl;
0113   }
0114 
0115   iCategory = findCategory(phoPtr);
0116 
0117   std::vector<float> vars;
0118 
0119   vars.reserve(nVariables_[iCategory]);
0120   for (int i = 0; i < nVariables_[iCategory]; ++i) {
0121     vars.push_back(mvaVarMngr_.getValue(variables_[iCategory][i], *phoPtr, auxVars));
0122   }
0123 
0124   // Special case for Spring16!
0125   if (getTag() == "Run2Spring16NonTrigV1" and iCategory == 1) {  // Endcap category
0126     // Raw value for EB only, because of loss of transparency in EE
0127     // for endcap MVA only in 2016
0128     double eA = effectiveAreas_->getEffectiveArea(std::abs(phoPtr->superCluster()->eta()));
0129     double phoIsoCorr = vars[10] - eA * (double)vars[9] - phoIsoPtScalingCoeff_.at(1) * phoPtr->pt();
0130     vars[10] = std::max(phoIsoCorr, phoIsoCutoff_);
0131   }
0132 
0133   if (isDebug()) {
0134     std::cout << " *** Inside PhotonMVAEstimator" << getTag() << std::endl;
0135     std::cout << " category " << iCategory << std::endl;
0136     for (int i = 0; i < nVariables_[iCategory]; ++i) {
0137       std::cout << " " << mvaVarMngr_.getName(variables_[iCategory][i]) << " " << vars[i] << std::endl;
0138     }
0139   }
0140 
0141   const float response = gbrForests_.at(iCategory)->GetResponse(vars.data());
0142 
0143   if (isDebug()) {
0144     std::cout << " ### MVA " << response << std::endl << std::endl;
0145   }
0146 
0147   return response;
0148 }
0149 
0150 int PhotonMVAEstimator::findCategory(const reco::Candidate* candPtr) const {
0151   const reco::Photon* phoPtr = dynamic_cast<const reco::Photon*>(candPtr);
0152   if (phoPtr == nullptr) {
0153     throw cms::Exception("MVA failure: ")
0154         << " given particle is expected to be reco::Photon or pat::Photon," << std::endl
0155         << " but appears to be neither" << std::endl;
0156   }
0157 
0158   return findCategory(*phoPtr);
0159 }
0160 
0161 int PhotonMVAEstimator::findCategory(reco::Photon const& photon) const {
0162   for (int i = 0; i < getNCategories(); ++i) {
0163     if (categoryFunctions_[i](photon))
0164       return i;
0165   }
0166 
0167   edm::LogWarning("MVA warning") << "category not defined for particle with pt " << photon.pt() << " GeV, eta "
0168                                  << photon.superCluster()->eta() << " in PhotonMVAEstimator" << getTag();
0169 
0170   return -1;
0171 }
0172 
// Register PhotonMVAEstimator with the AnyMVAEstimatorRun2Factory plugin factory
// under the name "PhotonMVAEstimator".
DEFINE_EDM_PLUGIN(AnyMVAEstimatorRun2Factory, PhotonMVAEstimator, "PhotonMVAEstimator");