Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-06-18 02:20:34

0001 #include <algorithm>
0002 #include <cassert>
0003 #include <cmath>
0004 #include <cstdio>  // For std::snprintf
0005 #include <fstream>
0006 #include <iostream>
0007 #include <sstream>
0008 #include <stdexcept>
0009 #include <stdexcept>
0010 #include <vector>
0011 
0012 #include "PhysicsTools/XGBoost/interface/XGBooster.h"
0013 
0014 using namespace pat;
0015 
// Parse a flat JSON array of double-quoted strings, e.g. ["pt", "eta"], and
// return the strings in the order they appear.
//
// Only this minimal subset of JSON is understood: whitespace between tokens is
// skipped, commas are treated as separators, and a string runs up to the next
// '"' (escape sequences inside strings are NOT interpreted).
//
// Throws std::runtime_error when the content does not start with '[', when an
// unexpected character is found between elements, when a string is not
// terminated, or when the closing ']' is missing.
std::vector<std::string> read_features(const std::string& content) {
  std::vector<std::string> result;

  std::istringstream stream(content);
  char ch = '\0';

  // Expect opening '[' (a failed extraction, e.g. empty content, is an error too)
  if (!(stream >> ch) || ch != '[') {
    throw std::runtime_error("Expected '[' at the beginning of the JSON array!");
  }

  // Testing the extraction itself guarantees `ch` is never a stale value left
  // over from the previous iteration once the input is exhausted.
  while (stream >> ch) {
    if (ch == ']') {
      return result;
    }
    if (ch == ',') {
      continue;
    }
    if (ch != '"') {
      throw std::runtime_error("Unexpected character in the JSON array!");
    }

    std::string feature;
    std::getline(stream, feature, '"');
    // getline only reaches end-of-input when the closing quote was never found.
    if (stream.eof()) {
      throw std::runtime_error("Unterminated string in the JSON array!");
    }
    result.push_back(feature);
  }

  // Input ended before the closing bracket was seen.
  throw std::runtime_error("Expected ']' at the end of the JSON array!");
}
0046 
0047 XGBooster::XGBooster(std::string model_file) {
0048   int status = XGBoosterCreate(nullptr, 0, &booster_);
0049   if (status != 0)
0050     throw std::runtime_error("Failed to create XGBooster");
0051   status = XGBoosterLoadModel(booster_, model_file.c_str());
0052   if (status != 0)
0053     throw std::runtime_error("Failed to load XGBoost model");
0054   XGBoosterSetParam(booster_, "nthread", "1");
0055 }
0056 
0057 XGBooster::XGBooster(std::string model_file, std::string model_features) : XGBooster(model_file) {
0058   std::ifstream file(model_features);
0059   if (!file.is_open())
0060     throw std::runtime_error("Failed to open file: " + model_features);
0061 
0062   std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
0063   file.close();
0064 
0065   std::vector<std::string> features = read_features(content);
0066 
0067   for (const auto& feature : features) {
0068     addFeature(feature);
0069   }
0070 }
0071 
0072 void XGBooster::reset() { std::fill(features_.begin(), features_.end(), std::nan("")); }
0073 
0074 void XGBooster::addFeature(std::string name) {
0075   features_.push_back(0);
0076   feature_name_to_index_[name] = features_.size() - 1;
0077 }
0078 
0079 void XGBooster::set(std::string name, float value) { features_.at(feature_name_to_index_[name]) = value; }
0080 
0081 float XGBooster::predict(const int iterationEnd) {
0082   // check if all feature values are set properly
0083   for (unsigned int i = 0; i < features_.size(); ++i)
0084     if (std::isnan(features_.at(i))) {
0085       std::string feature_name;
0086       for (const auto& pair : feature_name_to_index_) {
0087         if (pair.second == i) {
0088           feature_name = pair.first;
0089           break;
0090         }
0091       }
0092       throw std::runtime_error("Feature is not set: " + feature_name);
0093     }
0094 
0095   float const ret = predict(features_, iterationEnd);
0096 
0097   reset();
0098 
0099   return ret;
0100 }
0101 
0102 float XGBooster::predict(const std::vector<float>& features, const int iterationEnd) const {
0103   float result{-999.};
0104 
0105   if (features.empty()) {
0106     throw std::runtime_error("Vector of input features is empty");
0107   }
0108 
0109   if (feature_name_to_index_.size() != features.size())
0110     throw std::runtime_error("Feature size mismatch");
0111 
0112   DMatrixHandle dvalues;
0113   XGDMatrixCreateFromMat(&features[0], 1, features.size(), 9e99, &dvalues);
0114 
0115   bst_ulong out_len = 0;
0116   const float* score = nullptr;
0117 
0118   char json[256];  // Make sure the buffer is large enough to hold the resulting JSON string
0119 
0120   // Use snprintf to format the JSON string with the external value
0121   std::snprintf(json,
0122                 sizeof(json),
0123                 R"({
0124     "type": 0,
0125     "training": false,
0126     "iteration_begin": 0,
0127     "iteration_end": %d,
0128     "strict_shape": false
0129    })",
0130                 iterationEnd);
0131 
0132   // Shape of output prediction
0133   bst_ulong const* out_shape = nullptr;
0134 
0135   auto ret = XGBoosterPredictFromDMatrix(booster_, dvalues, json, &out_shape, &out_len, &score);
0136 
0137   if (ret == 0) {
0138     assert(out_len == 1 && "Unexpected prediction format");
0139     result = score[0];
0140   }
0141 
0142   XGDMatrixFree(dvalues);
0143 
0144   return result;
0145 }