// Implementation of the TMVAEvaluator class declared in CommonTools/MVAUtils/interface/TMVAEvaluator.h.
#include <memory>

#include "CommonTools/MVAUtils/interface/GBRForestTools.h"
#include "CommonTools/MVAUtils/interface/TMVAEvaluator.h"
#include "CommonTools/MVAUtils/interface/TMVAZipReader.h"

#include "CondFormats/DataRecord/interface/GBRWrapperRcd.h"
#include "FWCore/Framework/interface/ESHandle.h"
#include "FWCore/MessageLogger/interface/MessageLogger.h"

TMVAEvaluator::TMVAEvaluator() : mIsInitialized(false), mUsingGBRForest(false), mUseAdaBoost(false) {}

// Configure the evaluator from a TMVA weight file.
//
//   options      - option string forwarded to the TMVA::Reader constructor
//   method       - TMVA method name; stored and later passed to EvaluateMVA
//   weightFile   - weight file handed to loadTMVAWeights (and to createGBRForest
//                  when useGBRForest is set)
//   variables    - input variable names; the position in this vector is recorded
//                  as the variable's index
//   spectators   - spectator variable names, registered the same way
//   useGBRForest - when true, build a GBRForest from the weight file, release the
//                  TMVA reader and evaluate through the forest from then on
//   useAdaBoost  - only honoured in GBRForest mode; selects the AdaBoost
//                  classifier output instead of the gradient-boost one
//
// The function may be called again on an already initialized object: all state
// from the previous configuration is discarded first.
void TMVAEvaluator::initialize(const std::string& options,
                               const std::string& method,
                               const std::string& weightFile,
                               const std::vector<std::string>& variables,
                               const std::vector<std::string>& spectators,
                               bool useGBRForest,
                               bool useAdaBoost) {
  // Not usable until the new configuration is completely in place.
  mIsInitialized = false;

  // initialize the TMVA reader (assigning replaces and destroys any previous reader)
  mReader = std::make_unique<TMVA::Reader>(options.c_str());
  mReader->SetVerbose(false);
  mMethod = method;

  // Forget variables registered by an earlier initialization. The old reader,
  // which held pointers into these maps, has already been destroyed above.
  mVariables.clear();
  mSpectators.clear();

  // add input variables; the reader is given the address of the float slot
  // stored in the map, which evaluateTMVA() fills before each evaluation
  for (auto it = variables.begin(); it != variables.end(); ++it) {
    mVariables.insert(std::make_pair(*it, std::make_pair(it - variables.begin(), 0.)));
    mReader->AddVariable(it->c_str(), &(mVariables.at(*it).second));
  }

  // add spectator variables
  for (auto it = spectators.begin(); it != spectators.end(); ++it) {
    mSpectators.insert(std::make_pair(*it, std::make_pair(it - spectators.begin(), 0.)));
    mReader->AddSpectator(it->c_str(), &(mSpectators.at(*it).second));
  }

  // load the TMVA weights
  reco::details::loadTMVAWeights(mReader.get(), mMethod, weightFile);

  if (useGBRForest) {
    mGBRForest = createGBRForest(weightFile);

    // now can free some memory
    mReader.reset(nullptr);

    mUsingGBRForest = true;
    mUseAdaBoost = useAdaBoost;
  } else {
    // Make sure a previous GBRForest configuration cannot leak into TMVA mode.
    mGBRForest.reset();
    mUsingGBRForest = false;
    mUseAdaBoost = false;
  }

  mIsInitialized = true;
}

void TMVAEvaluator::initializeGBRForest(const GBRForest* gbrForest,
                                        const std::vector<std::string>& variables,
                                        const std::vector<std::string>& spectators,
                                        bool useAdaBoost) {
  // add input variables
  for (std::vector<std::string>::const_iterator it = variables.begin(); it != variables.end(); ++it)
    mVariables.insert(std::make_pair(*it, std::make_pair(it - variables.begin(), 0.)));

  // add spectator variables
  for (std::vector<std::string>::const_iterator it = spectators.begin(); it != spectators.end(); ++it)
    mSpectators.insert(std::make_pair(*it, std::make_pair(it - spectators.begin(), 0.)));

  // do not take ownership if getting GBRForest from an external source
  mGBRForest = std::shared_ptr<const GBRForest>(gbrForest, [](const GBRForest*) {});

  mIsInitialized = true;
  mUsingGBRForest = true;
  mUseAdaBoost = useAdaBoost;
}

float TMVAEvaluator::evaluateTMVA(const std::map<std::string, float>& inputs, bool useSpectators) const {
  // default value
  float value = -99.;

  // TMVA::Reader is not thread safe
  std::lock_guard<std::mutex> lock(m_mutex);

  // set the input variable values
  for (auto it = mVariables.begin(); it != mVariables.end(); ++it) {
    if (inputs.count(it->first) > 0)
      it->second.second = inputs.at(it->first);
    else
      edm::LogError("MissingInputVariable")
          << "Input variable " << it->first
          << " is missing from the list of inputs. The returned discriminator value might not be sensible.";
  }

  // if using spectator variables
  if (useSpectators) {
    // set the spectator variable values
    for (auto it = mSpectators.begin(); it != mSpectators.end(); ++it) {
      if (inputs.count(it->first) > 0)
        it->second.second = inputs.at(it->first);
      else
        edm::LogError("MissingSpectatorVariable")
            << "Spectator variable " << it->first
            << " is missing from the list of inputs. The returned discriminator value might not be sensible.";
    }
  }

  // evaluate the MVA
  value = mReader->EvaluateMVA(mMethod.c_str());

  return value;
}

// Evaluate the discriminator through the GBRForest.
//
// The registered input variables are looked up by name in `inputs` and written
// into a temporary float array at the index recorded for each variable during
// initialization. A name that is absent from `inputs` is reported via
// edm::LogError and its array element stays at 0. Depending on mUseAdaBoost the
// AdaBoost or the gradient-boost classifier output of the forest is returned.
// Returns -99 (after logging an error) if no forest is available.
float TMVAEvaluator::evaluateGBRForest(const std::map<std::string, float>& inputs) const {
  // default value
  float value = -99.;

  if (!mGBRForest) {
    edm::LogError("InitializationError") << "TMVAEvaluator has no GBRForest to evaluate.";
    return value;
  }

  // make_unique<float[]> value-initializes the array, so a variable that is
  // missing from `inputs` contributes a well-defined 0 instead of an
  // indeterminate value
  auto vars = std::make_unique<float[]>(mVariables.size());

  // set the input variable values
  for (const auto& variable : mVariables) {
    const auto found = inputs.find(variable.first);
    if (found != inputs.end())
      vars[variable.second.first] = found->second;
    else
      edm::LogError("MissingInputVariable")
          << "Input variable " << variable.first
          << " is missing from the list of inputs. The returned discriminator value might not be sensible.";
  }

  // evaluate the MVA
  if (mUseAdaBoost)
    value = mGBRForest->GetAdaBoostClassifier(vars.get());
  else
    value = mGBRForest->GetGradBoostClassifier(vars.get());

  return value;
}

// Public entry point: validate the request and dispatch to the TMVA reader or
// the GBRForest back end.
//
// Returns -99 (after logging an error) when the evaluator has not been
// initialized or when `inputs` holds fewer entries than the number of
// registered input variables (plus registered spectators if useSpectators is
// set). In GBRForest mode spectators are not supported; a warning is logged
// and they are ignored.
float TMVAEvaluator::evaluate(const std::map<std::string, float>& inputs, bool useSpectators) const {
  // value handed back whenever the evaluation cannot be carried out
  const float fallback = -99.;

  if (!mIsInitialized) {
    edm::LogError("InitializationError") << "TMVAEvaluator not properly initialized.";
    return fallback;
  }

  // only the number of provided entries is checked here; the per-name lookup
  // happens in the back-end specific evaluation functions
  const auto nProvided = inputs.size();
  if (useSpectators) {
    if (nProvided < (mVariables.size() + mSpectators.size())) {
      edm::LogError("MissingInputs") << "Too few inputs provided (" << nProvided << " provided but "
                                     << mVariables.size() << " input and " << mSpectators.size()
                                     << " spectator variables expected).";
      return fallback;
    }
  } else {
    if (nProvided < mVariables.size()) {
      edm::LogError("MissingInputVariable(s)") << "Too few input variables provided (" << nProvided
                                               << " provided but " << mVariables.size() << " expected).";
      return fallback;
    }
  }

  if (!mUsingGBRForest)
    return evaluateTMVA(inputs, useSpectators);

  if (useSpectators) {
    edm::LogWarning("UnsupportedFunctionality")
        << "Use of spectator variables with GBRForest is not supported. Spectator variables will be ignored.";
  }
  return evaluateGBRForest(inputs);
}