Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:01:05

0001 #include <memory>
0002 
0003 #include "CommonTools/MVAUtils/interface/GBRForestTools.h"
0004 #include "CommonTools/MVAUtils/interface/TMVAEvaluator.h"
0005 #include "CommonTools/MVAUtils/interface/TMVAZipReader.h"
0006 
0007 #include "CondFormats/DataRecord/interface/GBRWrapperRcd.h"
0008 #include "FWCore/Framework/interface/ESHandle.h"
0009 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0010 
0011 TMVAEvaluator::TMVAEvaluator() : mIsInitialized(false), mUsingGBRForest(false), mUseAdaBoost(false) {}
0012 
0013 void TMVAEvaluator::initialize(const std::string& options,
0014                                const std::string& method,
0015                                const std::string& weightFile,
0016                                const std::vector<std::string>& variables,
0017                                const std::vector<std::string>& spectators,
0018                                bool useGBRForest,
0019                                bool useAdaBoost) {
0020   // initialize the TMVA reader
0021   mReader = std::make_unique<TMVA::Reader>(options.c_str());
0022   mReader->SetVerbose(false);
0023   mMethod = method;
0024 
0025   // add input variables
0026   for (std::vector<std::string>::const_iterator it = variables.begin(); it != variables.end(); ++it) {
0027     mVariables.insert(std::make_pair(*it, std::make_pair(it - variables.begin(), 0.)));
0028     mReader->AddVariable(it->c_str(), &(mVariables.at(*it).second));
0029   }
0030 
0031   // add spectator variables
0032   for (std::vector<std::string>::const_iterator it = spectators.begin(); it != spectators.end(); ++it) {
0033     mSpectators.insert(std::make_pair(*it, std::make_pair(it - spectators.begin(), 0.)));
0034     mReader->AddSpectator(it->c_str(), &(mSpectators.at(*it).second));
0035   }
0036 
0037   // load the TMVA weights
0038   reco::details::loadTMVAWeights(mReader.get(), mMethod, weightFile);
0039 
0040   if (useGBRForest) {
0041     mGBRForest = createGBRForest(weightFile);
0042 
0043     // now can free some memory
0044     mReader.reset(nullptr);
0045 
0046     mUsingGBRForest = true;
0047     mUseAdaBoost = useAdaBoost;
0048   }
0049 
0050   mIsInitialized = true;
0051 }
0052 
0053 void TMVAEvaluator::initializeGBRForest(const GBRForest* gbrForest,
0054                                         const std::vector<std::string>& variables,
0055                                         const std::vector<std::string>& spectators,
0056                                         bool useAdaBoost) {
0057   // add input variables
0058   for (std::vector<std::string>::const_iterator it = variables.begin(); it != variables.end(); ++it)
0059     mVariables.insert(std::make_pair(*it, std::make_pair(it - variables.begin(), 0.)));
0060 
0061   // add spectator variables
0062   for (std::vector<std::string>::const_iterator it = spectators.begin(); it != spectators.end(); ++it)
0063     mSpectators.insert(std::make_pair(*it, std::make_pair(it - spectators.begin(), 0.)));
0064 
0065   // do not take ownership if getting GBRForest from an external source
0066   mGBRForest = std::shared_ptr<const GBRForest>(gbrForest, [](const GBRForest*) {});
0067 
0068   mIsInitialized = true;
0069   mUsingGBRForest = true;
0070   mUseAdaBoost = useAdaBoost;
0071 }
0072 
0073 float TMVAEvaluator::evaluateTMVA(const std::map<std::string, float>& inputs, bool useSpectators) const {
0074   // default value
0075   float value = -99.;
0076 
0077   // TMVA::Reader is not thread safe
0078   std::lock_guard<std::mutex> lock(m_mutex);
0079 
0080   // set the input variable values
0081   for (auto it = mVariables.begin(); it != mVariables.end(); ++it) {
0082     if (inputs.count(it->first) > 0)
0083       it->second.second = inputs.at(it->first);
0084     else
0085       edm::LogError("MissingInputVariable")
0086           << "Input variable " << it->first
0087           << " is missing from the list of inputs. The returned discriminator value might not be sensible.";
0088   }
0089 
0090   // if using spectator variables
0091   if (useSpectators) {
0092     // set the spectator variable values
0093     for (auto it = mSpectators.begin(); it != mSpectators.end(); ++it) {
0094       if (inputs.count(it->first) > 0)
0095         it->second.second = inputs.at(it->first);
0096       else
0097         edm::LogError("MissingSpectatorVariable")
0098             << "Spectator variable " << it->first
0099             << " is missing from the list of inputs. The returned discriminator value might not be sensible.";
0100     }
0101   }
0102 
0103   // evaluate the MVA
0104   value = mReader->EvaluateMVA(mMethod.c_str());
0105 
0106   return value;
0107 }
0108 
0109 float TMVAEvaluator::evaluateGBRForest(const std::map<std::string, float>& inputs) const {
0110   // default value
0111   float value = -99.;
0112 
0113   std::unique_ptr<float[]> vars(new float[mVariables.size()]);  // allocate n floats
0114 
0115   // set the input variable values
0116   for (auto it = mVariables.begin(); it != mVariables.end(); ++it) {
0117     if (inputs.count(it->first) > 0)
0118       vars[it->second.first] = inputs.at(it->first);
0119     else
0120       edm::LogError("MissingInputVariable")
0121           << "Input variable " << it->first
0122           << " is missing from the list of inputs. The returned discriminator value might not be sensible.";
0123   }
0124 
0125   // evaluate the MVA
0126   if (mUseAdaBoost)
0127     value = mGBRForest->GetAdaBoostClassifier(vars.get());
0128   else
0129     value = mGBRForest->GetGradBoostClassifier(vars.get());
0130 
0131   return value;
0132 }
0133 
0134 float TMVAEvaluator::evaluate(const std::map<std::string, float>& inputs, bool useSpectators) const {
0135   // default value
0136   float value = -99.;
0137 
0138   if (!mIsInitialized) {
0139     edm::LogError("InitializationError") << "TMVAEvaluator not properly initialized.";
0140     return value;
0141   }
0142 
0143   if (useSpectators && inputs.size() < (mVariables.size() + mSpectators.size())) {
0144     edm::LogError("MissingInputs") << "Too few inputs provided (" << inputs.size() << " provided but "
0145                                    << mVariables.size() << " input and " << mSpectators.size()
0146                                    << " spectator variables expected).";
0147     return value;
0148   } else if (inputs.size() < mVariables.size()) {
0149     edm::LogError("MissingInputVariable(s)") << "Too few input variables provided (" << inputs.size()
0150                                              << " provided but " << mVariables.size() << " expected).";
0151     return value;
0152   }
0153 
0154   if (mUsingGBRForest) {
0155     if (useSpectators)
0156       edm::LogWarning("UnsupportedFunctionality")
0157           << "Use of spectator variables with GBRForest is not supported. Spectator variables will be ignored.";
0158     value = evaluateGBRForest(inputs);
0159   } else
0160     value = evaluateTMVA(inputs, useSpectators);
0161 
0162   return value;
0163 }