Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-06-18 02:20:34

0001 #ifndef PhysicsTools_XGBoost_XGBooster_h
0002 #define PhysicsTools_XGBoost_XGBooster_h
0003 
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <xgboost/c_api.h>
0009 
0010 namespace pat {
0011   class XGBooster {
0012   public:
0013     XGBooster(std::string model_file);
0014     XGBooster(std::string model_file, std::string model_features);
0015 
0016     /// Features need to be entered in the order they are used
0017     /// in the model
0018     void addFeature(std::string name);
0019 
0020     /// Reset feature values
0021     void reset();
0022 
0023     void set(std::string name, float value);
0024 
0025     float predict(const int iterationEnd = 0);
0026     float predict(const std::vector<float>& features, const int iterationEnd = 0) const;
0027 
0028   private:
0029     std::vector<float> features_;
0030     std::map<std::string, unsigned int> feature_name_to_index_;
0031     BoosterHandle booster_;
0032   };
0033 }  // namespace pat
0034 
0035 #endif