Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2023-03-17 10:45:19

0001 #ifndef CommonTools_MVAUtils_TMVAEvaluator_h
0002 #define CommonTools_MVAUtils_TMVAEvaluator_h
0003 
0004 #include <map>
0005 #include <memory>
0006 #include <mutex>
0007 #include <string>
0008 #include <vector>
0009 
0010 #include "CondFormats/GBRForest/interface/GBRForest.h"
0011 #include "FWCore/Framework/interface/EventSetup.h"
0012 #include "FWCore/Utilities/interface/thread_safety_macros.h"
0013 #include "TMVA/IMethod.h"
0014 #include "TMVA/Reader.h"
0015 
0016 class TMVAEvaluator {
0017 public:
0018   TMVAEvaluator();
0019 
0020   void initialize(const std::string& options,
0021                   const std::string& method,
0022                   const std::string& weightFile,
0023                   const std::vector<std::string>& variables,
0024                   const std::vector<std::string>& spectators,
0025                   bool useGBRForest = false,
0026                   bool useAdaBoost = false);
0027 
0028   void initializeGBRForest(const GBRForest* gbrForest,
0029                            const std::vector<std::string>& variables,
0030                            const std::vector<std::string>& spectators,
0031                            bool useAdaBoost = false);
0032 
0033   float evaluateTMVA(const std::map<std::string, float>& inputs, bool useSpectators) const;
0034   float evaluateGBRForest(const std::map<std::string, float>& inputs) const;
0035   float evaluate(const std::map<std::string, float>& inputs, bool useSpectators = false) const;
0036 
0037 private:
0038   bool mIsInitialized;
0039   bool mUsingGBRForest;
0040   bool mUseAdaBoost;
0041 
0042   std::string mMethod;
0043   mutable std::mutex m_mutex;
0044   CMS_THREAD_GUARD(m_mutex) std::unique_ptr<TMVA::Reader> mReader;
0045   std::shared_ptr<const GBRForest> mGBRForest;
0046 
0047   CMS_THREAD_GUARD(m_mutex) mutable std::map<std::string, std::pair<size_t, float>> mVariables;
0048   CMS_THREAD_GUARD(m_mutex) mutable std::map<std::string, std::pair<size_t, float>> mSpectators;
0049 };
0050 
0051 #endif  // CommonTools_MVAUtils_TMVAEvaluator_h