File indexing completed on 2024-04-06 12:01:05
0001 #include <memory>
0002
0003 #include "CommonTools/MVAUtils/interface/GBRForestTools.h"
0004 #include "CommonTools/MVAUtils/interface/TMVAEvaluator.h"
0005 #include "CommonTools/MVAUtils/interface/TMVAZipReader.h"
0006
0007 #include "CondFormats/DataRecord/interface/GBRWrapperRcd.h"
0008 #include "FWCore/Framework/interface/ESHandle.h"
0009 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0010
0011 TMVAEvaluator::TMVAEvaluator() : mIsInitialized(false), mUsingGBRForest(false), mUseAdaBoost(false) {}
0012
0013 void TMVAEvaluator::initialize(const std::string& options,
0014 const std::string& method,
0015 const std::string& weightFile,
0016 const std::vector<std::string>& variables,
0017 const std::vector<std::string>& spectators,
0018 bool useGBRForest,
0019 bool useAdaBoost) {
0020
0021 mReader = std::make_unique<TMVA::Reader>(options.c_str());
0022 mReader->SetVerbose(false);
0023 mMethod = method;
0024
0025
0026 for (std::vector<std::string>::const_iterator it = variables.begin(); it != variables.end(); ++it) {
0027 mVariables.insert(std::make_pair(*it, std::make_pair(it - variables.begin(), 0.)));
0028 mReader->AddVariable(it->c_str(), &(mVariables.at(*it).second));
0029 }
0030
0031
0032 for (std::vector<std::string>::const_iterator it = spectators.begin(); it != spectators.end(); ++it) {
0033 mSpectators.insert(std::make_pair(*it, std::make_pair(it - spectators.begin(), 0.)));
0034 mReader->AddSpectator(it->c_str(), &(mSpectators.at(*it).second));
0035 }
0036
0037
0038 reco::details::loadTMVAWeights(mReader.get(), mMethod, weightFile);
0039
0040 if (useGBRForest) {
0041 mGBRForest = createGBRForest(weightFile);
0042
0043
0044 mReader.reset(nullptr);
0045
0046 mUsingGBRForest = true;
0047 mUseAdaBoost = useAdaBoost;
0048 }
0049
0050 mIsInitialized = true;
0051 }
0052
0053 void TMVAEvaluator::initializeGBRForest(const GBRForest* gbrForest,
0054 const std::vector<std::string>& variables,
0055 const std::vector<std::string>& spectators,
0056 bool useAdaBoost) {
0057
0058 for (std::vector<std::string>::const_iterator it = variables.begin(); it != variables.end(); ++it)
0059 mVariables.insert(std::make_pair(*it, std::make_pair(it - variables.begin(), 0.)));
0060
0061
0062 for (std::vector<std::string>::const_iterator it = spectators.begin(); it != spectators.end(); ++it)
0063 mSpectators.insert(std::make_pair(*it, std::make_pair(it - spectators.begin(), 0.)));
0064
0065
0066 mGBRForest = std::shared_ptr<const GBRForest>(gbrForest, [](const GBRForest*) {});
0067
0068 mIsInitialized = true;
0069 mUsingGBRForest = true;
0070 mUseAdaBoost = useAdaBoost;
0071 }
0072
0073 float TMVAEvaluator::evaluateTMVA(const std::map<std::string, float>& inputs, bool useSpectators) const {
0074
0075 float value = -99.;
0076
0077
0078 std::lock_guard<std::mutex> lock(m_mutex);
0079
0080
0081 for (auto it = mVariables.begin(); it != mVariables.end(); ++it) {
0082 if (inputs.count(it->first) > 0)
0083 it->second.second = inputs.at(it->first);
0084 else
0085 edm::LogError("MissingInputVariable")
0086 << "Input variable " << it->first
0087 << " is missing from the list of inputs. The returned discriminator value might not be sensible.";
0088 }
0089
0090
0091 if (useSpectators) {
0092
0093 for (auto it = mSpectators.begin(); it != mSpectators.end(); ++it) {
0094 if (inputs.count(it->first) > 0)
0095 it->second.second = inputs.at(it->first);
0096 else
0097 edm::LogError("MissingSpectatorVariable")
0098 << "Spectator variable " << it->first
0099 << " is missing from the list of inputs. The returned discriminator value might not be sensible.";
0100 }
0101 }
0102
0103
0104 value = mReader->EvaluateMVA(mMethod.c_str());
0105
0106 return value;
0107 }
0108
0109 float TMVAEvaluator::evaluateGBRForest(const std::map<std::string, float>& inputs) const {
0110
0111 float value = -99.;
0112
0113 std::unique_ptr<float[]> vars(new float[mVariables.size()]);
0114
0115
0116 for (auto it = mVariables.begin(); it != mVariables.end(); ++it) {
0117 if (inputs.count(it->first) > 0)
0118 vars[it->second.first] = inputs.at(it->first);
0119 else
0120 edm::LogError("MissingInputVariable")
0121 << "Input variable " << it->first
0122 << " is missing from the list of inputs. The returned discriminator value might not be sensible.";
0123 }
0124
0125
0126 if (mUseAdaBoost)
0127 value = mGBRForest->GetAdaBoostClassifier(vars.get());
0128 else
0129 value = mGBRForest->GetGradBoostClassifier(vars.get());
0130
0131 return value;
0132 }
0133
0134 float TMVAEvaluator::evaluate(const std::map<std::string, float>& inputs, bool useSpectators) const {
0135
0136 float value = -99.;
0137
0138 if (!mIsInitialized) {
0139 edm::LogError("InitializationError") << "TMVAEvaluator not properly initialized.";
0140 return value;
0141 }
0142
0143 if (useSpectators && inputs.size() < (mVariables.size() + mSpectators.size())) {
0144 edm::LogError("MissingInputs") << "Too few inputs provided (" << inputs.size() << " provided but "
0145 << mVariables.size() << " input and " << mSpectators.size()
0146 << " spectator variables expected).";
0147 return value;
0148 } else if (inputs.size() < mVariables.size()) {
0149 edm::LogError("MissingInputVariable(s)") << "Too few input variables provided (" << inputs.size()
0150 << " provided but " << mVariables.size() << " expected).";
0151 return value;
0152 }
0153
0154 if (mUsingGBRForest) {
0155 if (useSpectators)
0156 edm::LogWarning("UnsupportedFunctionality")
0157 << "Use of spectator variables with GBRForest is not supported. Spectator variables will be ignored.";
0158 value = evaluateGBRForest(inputs);
0159 } else
0160 value = evaluateTMVA(inputs, useSpectators);
0161
0162 return value;
0163 }