Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:25:08

0001 #include "RecoEgamma/ElectronIdentification/interface/ElectronMVAEstimatorRun2.h"
0002 
0003 #include "CommonTools/MVAUtils/interface/GBRForestTools.h"
0004 #include "DataFormats/PatCandidates/interface/Electron.h"
0005 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0006 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0007 #include "RecoEgamma/EgammaTools/interface/MVAVariableHelper.h"
0008 
0009 ElectronMVAEstimatorRun2::ElectronMVAEstimatorRun2(const edm::ParameterSet& conf)
0010     : AnyMVAEstimatorRun2Base(conf),
0011       mvaVarMngr_{conf.getParameter<std::string>("variableDefinition"), MVAVariableHelper::indexMap()} {
0012   const auto weightFileNames = conf.getParameter<std::vector<std::string> >("weightFileNames");
0013   const auto categoryCutStrings = conf.getParameter<std::vector<std::string> >("categoryCuts");
0014 
0015   if ((int)(categoryCutStrings.size()) != getNCategories())
0016     throw cms::Exception("MVA config failure: ")
0017         << "wrong number of category cuts in ElectronMVAEstimatorRun2" << getTag() << std::endl;
0018 
0019   for (int i = 0; i < getNCategories(); ++i) {
0020     categoryFunctions_.emplace_back(categoryCutStrings[i]);
0021   }
0022 
0023   // Initialize GBRForests from weight files
0024   init(weightFileNames);
0025 }
0026 
0027 ElectronMVAEstimatorRun2::ElectronMVAEstimatorRun2(const std::string& mvaTag,
0028                                                    const std::string& mvaName,
0029                                                    int nCategories,
0030                                                    const std::string& variableDefinition,
0031                                                    const std::vector<std::string>& categoryCutStrings,
0032                                                    const std::vector<std::string>& weightFileNames,
0033                                                    bool debug)
0034     : AnyMVAEstimatorRun2Base(mvaName, mvaTag, nCategories, debug),
0035       mvaVarMngr_{variableDefinition, MVAVariableHelper::indexMap()} {
0036   if ((int)(categoryCutStrings.size()) != getNCategories())
0037     throw cms::Exception("MVA config failure: ")
0038         << "wrong number of category cuts in " << getName() << getTag() << std::endl;
0039 
0040   for (auto const& cut : categoryCutStrings)
0041     categoryFunctions_.emplace_back(cut);
0042   init(weightFileNames);
0043 }
0044 
0045 void ElectronMVAEstimatorRun2::init(const std::vector<std::string>& weightFileNames) {
0046   if (isDebug()) {
0047     std::cout << " *** Inside ElectronMVAEstimatorRun2" << getTag() << std::endl;
0048   }
0049 
0050   // Initialize GBRForests
0051   if ((int)(weightFileNames.size()) != getNCategories())
0052     throw cms::Exception("MVA config failure: ")
0053         << "wrong number of weightfiles in ElectronMVAEstimatorRun2" << getTag() << std::endl;
0054 
0055   // Create a TMVA reader object for each category
0056   for (int i = 0; i < getNCategories(); i++) {
0057     std::vector<int> variablesInCategory;
0058 
0059     // Use unique_ptr so that all readers are properly cleaned up
0060     // when the vector clear() is called in the destructor
0061 
0062     std::vector<std::string> variableNamesInCategory;
0063     gbrForests_.push_back(createGBRForest(weightFileNames[i], variableNamesInCategory));
0064 
0065     nVariables_.push_back(variableNamesInCategory.size());
0066 
0067     variables_.push_back(variablesInCategory);
0068 
0069     if (isDebug()) {
0070       std::cout << " *** Inside ElectronMVAEstimatorRun2" << getTag() << std::endl;
0071       std::cout << " category " << i << " with nVariables " << nVariables_[i] << std::endl;
0072     }
0073 
0074     for (int j = 0; j < nVariables_[i]; ++j) {
0075       int index = mvaVarMngr_.getVarIndex(variableNamesInCategory[j]);
0076       if (index == -1) {
0077         throw cms::Exception("MVA config failure: ")
0078             << "Concerning ElectronMVAEstimatorRun2" << getTag() << std::endl
0079             << "Variable " << variableNamesInCategory[j] << " not found in variable definition file!" << std::endl;
0080       }
0081       variables_[i].push_back(index);
0082     }
0083   }
0084 }
0085 
0086 float ElectronMVAEstimatorRun2::mvaValue(const reco::Candidate* candidate,
0087                                          const std::vector<float>& auxVariables,
0088                                          int& iCategory) const {
0089   const reco::GsfElectron* electron = dynamic_cast<const reco::GsfElectron*>(candidate);
0090 
0091   if (electron == nullptr) {
0092     throw cms::Exception("MVA failure: ")
0093         << " given particle is expected to be reco::GsfElectron or pat::Electron," << std::endl
0094         << " but appears to be neither" << std::endl;
0095   }
0096 
0097   iCategory = findCategory(electron);
0098 
0099   if (iCategory < 0)
0100     return -999;
0101 
0102   std::vector<float> vars;
0103 
0104   vars.reserve(nVariables_[iCategory]);
0105   for (int i = 0; i < nVariables_[iCategory]; ++i) {
0106     vars.push_back(mvaVarMngr_.getValue(variables_[iCategory][i], *electron, auxVariables));
0107   }
0108 
0109   if (isDebug()) {
0110     std::cout << " *** Inside ElectronMVAEstimatorRun2" << getTag() << std::endl;
0111     std::cout << " category " << iCategory << std::endl;
0112     for (int i = 0; i < nVariables_[iCategory]; ++i) {
0113       std::cout << " " << mvaVarMngr_.getName(variables_[iCategory][i]) << " " << vars[i] << std::endl;
0114     }
0115   }
0116   const float response = gbrForests_.at(iCategory)->GetResponse(vars.data());  // The BDT score
0117 
0118   if (isDebug()) {
0119     std::cout << " ### MVA " << response << std::endl << std::endl;
0120   }
0121 
0122   return response;
0123 }
0124 
0125 int ElectronMVAEstimatorRun2::findCategory(const reco::Candidate* candidate) const {
0126   const reco::GsfElectron* electron = dynamic_cast<reco::GsfElectron const*>(candidate);
0127 
0128   if (electron == nullptr) {
0129     throw cms::Exception("MVA failure: ")
0130         << " given particle is expected to be reco::GsfElectron or pat::Electron," << std::endl
0131         << " but appears to be neither" << std::endl;
0132   }
0133 
0134   return findCategory(*electron);
0135 }
0136 
0137 int ElectronMVAEstimatorRun2::findCategory(reco::GsfElectron const& electron) const {
0138   for (int i = 0; i < getNCategories(); ++i) {
0139     if (categoryFunctions_[i](electron))
0140       return i;
0141   }
0142 
0143   edm::LogWarning("MVA warning") << "category not defined for particle with pt " << electron.pt() << " GeV, eta "
0144                                  << electron.superCluster()->eta() << " in ElectronMVAEstimatorRun2" << getTag();
0145 
0146   return -1;
0147 }