File indexing completed on 2023-03-17 10:45:19
0001 #ifndef CommonTools_MVAUtils_TMVAEvaluator_h
0002 #define CommonTools_MVAUtils_TMVAEvaluator_h
0003
0004 #include <map>
0005 #include <memory>
0006 #include <mutex>
0007 #include <string>
0008 #include <vector>
0009
0010 #include "CondFormats/GBRForest/interface/GBRForest.h"
0011 #include "FWCore/Framework/interface/EventSetup.h"
0012 #include "FWCore/Utilities/interface/thread_safety_macros.h"
0013 #include "TMVA/IMethod.h"
0014 #include "TMVA/Reader.h"
0015
0016 class TMVAEvaluator {
0017 public:
0018 TMVAEvaluator();
0019
0020 void initialize(const std::string& options,
0021 const std::string& method,
0022 const std::string& weightFile,
0023 const std::vector<std::string>& variables,
0024 const std::vector<std::string>& spectators,
0025 bool useGBRForest = false,
0026 bool useAdaBoost = false);
0027
0028 void initializeGBRForest(const GBRForest* gbrForest,
0029 const std::vector<std::string>& variables,
0030 const std::vector<std::string>& spectators,
0031 bool useAdaBoost = false);
0032
0033 float evaluateTMVA(const std::map<std::string, float>& inputs, bool useSpectators) const;
0034 float evaluateGBRForest(const std::map<std::string, float>& inputs) const;
0035 float evaluate(const std::map<std::string, float>& inputs, bool useSpectators = false) const;
0036
0037 private:
0038 bool mIsInitialized;
0039 bool mUsingGBRForest;
0040 bool mUseAdaBoost;
0041
0042 std::string mMethod;
0043 mutable std::mutex m_mutex;
0044 CMS_THREAD_GUARD(m_mutex) std::unique_ptr<TMVA::Reader> mReader;
0045 std::shared_ptr<const GBRForest> mGBRForest;
0046
0047 CMS_THREAD_GUARD(m_mutex) mutable std::map<std::string, std::pair<size_t, float>> mVariables;
0048 CMS_THREAD_GUARD(m_mutex) mutable std::map<std::string, std::pair<size_t, float>> mSpectators;
0049 };
0050
0051 #endif