File indexing completed on 2024-04-06 12:01:05
0001 #include "CommonTools/MVAUtils/interface/GBRForestTools.h"
0002 #include "CommonTools/MVAUtils/interface/TMVAZipReader.h"
0003 #include "FWCore/ParameterSet/interface/FileInPath.h"
0004 #include "FWCore/Utilities/interface/Exception.h"
0005
0006 #include "TFile.h"
0007
0008 #include <cstdio>
0009 #include <cstdlib>
0010 #include <RVersion.h>
0011 #include <cmath>
0012 #include <tinyxml2.h>
0013 #include <filesystem>
0014
0015 namespace {
0016
0017 size_t readVariables(tinyxml2::XMLElement* root, const char* key, std::vector<std::string>& names) {
0018 size_t n = 0;
0019 names.clear();
0020
0021 if (root != nullptr) {
0022 for (tinyxml2::XMLElement* e = root->FirstChildElement(key); e != nullptr; e = e->NextSiblingElement(key)) {
0023 names.push_back(e->Attribute("Expression"));
0024 ++n;
0025 }
0026 }
0027
0028 return n;
0029 }
0030
0031 bool isTerminal(tinyxml2::XMLElement* node) {
0032 bool is = true;
0033 for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
0034 is = false;
0035 }
0036 return is;
0037 }
0038
0039 unsigned int countIntermediateNodes(tinyxml2::XMLElement* node) {
0040 unsigned int count = 0;
0041 for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
0042 count += countIntermediateNodes(e);
0043 }
0044 return count > 0 ? count + 1 : 0;
0045 }
0046
0047 unsigned int countTerminalNodes(tinyxml2::XMLElement* node) {
0048 unsigned int count = 0;
0049 for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
0050 count += countTerminalNodes(e);
0051 }
0052 return count > 0 ? count : 1;
0053 }
0054
0055 void addNode(GBRTree& tree,
0056 tinyxml2::XMLElement* node,
0057 double scale,
0058 bool isRegression,
0059 bool useYesNoLeaf,
0060 bool adjustboundary,
0061 bool isAdaClassifier) {
0062 bool nodeIsTerminal = isTerminal(node);
0063 if (nodeIsTerminal) {
0064 double response = 0.;
0065 if (isRegression) {
0066 node->QueryDoubleAttribute("res", &response);
0067 } else {
0068 if (useYesNoLeaf) {
0069 node->QueryDoubleAttribute("nType", &response);
0070 } else {
0071 if (isAdaClassifier) {
0072 node->QueryDoubleAttribute("purity", &response);
0073 } else {
0074 node->QueryDoubleAttribute("res", &response);
0075 }
0076 }
0077 }
0078 response *= scale;
0079 tree.Responses().push_back(response);
0080 } else {
0081 int thisidx = tree.CutIndices().size();
0082
0083 int selector = 0;
0084 float cutval = 0.;
0085 bool ctype = false;
0086
0087 node->QueryIntAttribute("IVar", &selector);
0088 node->QueryFloatAttribute("Cut", &cutval);
0089 node->QueryBoolAttribute("cType", &ctype);
0090
0091 tree.CutIndices().push_back(static_cast<unsigned char>(selector));
0092
0093
0094
0095 if (adjustboundary) {
0096 cutval = std::nextafter(cutval, std::numeric_limits<float>::lowest());
0097 }
0098 tree.CutVals().push_back(cutval);
0099 tree.LeftIndices().push_back(0);
0100 tree.RightIndices().push_back(0);
0101
0102 tinyxml2::XMLElement* left = nullptr;
0103 tinyxml2::XMLElement* right = nullptr;
0104 for (tinyxml2::XMLElement* e = node->FirstChildElement("Node"); e != nullptr; e = e->NextSiblingElement("Node")) {
0105 if (*(e->Attribute("pos")) == 'l')
0106 left = e;
0107 else if (*(e->Attribute("pos")) == 'r')
0108 right = e;
0109 }
0110 if (!ctype) {
0111 std::swap(left, right);
0112 }
0113
0114 tree.LeftIndices()[thisidx] = isTerminal(left) ? -tree.Responses().size() : tree.CutIndices().size();
0115 addNode(tree, left, scale, isRegression, useYesNoLeaf, adjustboundary, isAdaClassifier);
0116
0117 tree.RightIndices()[thisidx] = isTerminal(right) ? -tree.Responses().size() : tree.CutIndices().size();
0118 addNode(tree, right, scale, isRegression, useYesNoLeaf, adjustboundary, isAdaClassifier);
0119 }
0120 }
0121
0122 std::unique_ptr<GBRForest> init(const std::string& weightsFileFullPath, std::vector<std::string>& varNames) {
0123
0124
0125
0126 if (reco::details::hasEnding(weightsFileFullPath, ".root")) {
0127 TFile gbrForestFile(weightsFileFullPath.c_str());
0128 std::unique_ptr<GBRForest> up(gbrForestFile.Get<GBRForest>("gbrForest"));
0129 std::unique_ptr<std::vector<std::string>> vars(gbrForestFile.Get<std::vector<std::string>>("variableNames"));
0130 gbrForestFile.Close("nodelete");
0131 if (vars) {
0132 varNames = std::move(*vars);
0133 }
0134 return up;
0135 }
0136
0137
0138
0139
0140 tinyxml2::XMLDocument xmlDoc;
0141
0142 using namespace reco::details;
0143
0144 if (hasEnding(weightsFileFullPath, ".xml")) {
0145 xmlDoc.LoadFile(weightsFileFullPath.c_str());
0146 } else if (hasEnding(weightsFileFullPath, ".gz") || hasEnding(weightsFileFullPath, ".gzip")) {
0147 char* buffer = readGzipFile(weightsFileFullPath);
0148 xmlDoc.Parse(buffer);
0149 free(buffer);
0150 }
0151
0152 tinyxml2::XMLElement* root = xmlDoc.FirstChildElement("MethodSetup");
0153 readVariables(root->FirstChildElement("Variables"), "Variable", varNames);
0154
0155
0156 std::map<std::string, std::string> info;
0157 tinyxml2::XMLElement* infoElem = xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement("GeneralInfo");
0158 if (infoElem == nullptr) {
0159 throw cms::Exception("XMLError") << "No GeneralInfo found in " << weightsFileFullPath << " !!\n";
0160 }
0161 for (tinyxml2::XMLElement* e = infoElem->FirstChildElement("Info"); e != nullptr;
0162 e = e->NextSiblingElement("Info")) {
0163 const char* name;
0164 const char* value;
0165 if (tinyxml2::XML_SUCCESS != e->QueryStringAttribute("name", &name)) {
0166 throw cms::Exception("XMLERROR") << "no 'name' attribute found in 'Info' element in " << weightsFileFullPath;
0167 }
0168 if (tinyxml2::XML_SUCCESS != e->QueryStringAttribute("value", &value)) {
0169 throw cms::Exception("XMLERROR") << "no 'value' attribute found in 'Info' element in " << weightsFileFullPath;
0170 }
0171 info[name] = value;
0172 }
0173
0174
0175 std::map<std::string, std::string> options;
0176 tinyxml2::XMLElement* optionsElem = xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement("Options");
0177 if (optionsElem == nullptr) {
0178 throw cms::Exception("XMLError") << "No Options found in " << weightsFileFullPath << " !!\n";
0179 }
0180 for (tinyxml2::XMLElement* e = optionsElem->FirstChildElement("Option"); e != nullptr;
0181 e = e->NextSiblingElement("Option")) {
0182 const char* name;
0183 if (tinyxml2::XML_SUCCESS != e->QueryStringAttribute("name", &name)) {
0184 throw cms::Exception("XMLERROR") << "no 'name' attribute found in 'Option' element in " << weightsFileFullPath;
0185 }
0186 options[name] = e->GetText();
0187 }
0188
0189
0190 int rootTrainingVersion(0);
0191 if (info.find("ROOT Release") != info.end()) {
0192 std::string s = info["ROOT Release"];
0193 rootTrainingVersion = std::stoi(s.substr(s.find('[') + 1, s.find(']') - s.find('[') - 1));
0194 }
0195
0196
0197 std::vector<double> boostWeights;
0198 tinyxml2::XMLElement* weightsElem = xmlDoc.FirstChildElement("MethodSetup")->FirstChildElement("Weights");
0199 if (weightsElem == nullptr) {
0200 throw cms::Exception("XMLError") << "No Weights found in " << weightsFileFullPath << " !!\n";
0201 }
0202 bool hasTrees = false;
0203 for (tinyxml2::XMLElement* e = weightsElem->FirstChildElement("BinaryTree"); e != nullptr;
0204 e = e->NextSiblingElement("BinaryTree")) {
0205 hasTrees = true;
0206 double w;
0207 if (tinyxml2::XML_SUCCESS != e->QueryDoubleAttribute("boostWeight", &w)) {
0208 throw cms::Exception("XMLERROR") << "problem with 'boostWeight' attribute found in 'BinaryTree' element in "
0209 << weightsFileFullPath;
0210 }
0211 boostWeights.push_back(w);
0212 }
0213 if (!hasTrees) {
0214 throw cms::Exception("XMLError") << "No BinaryTrees found in " << weightsFileFullPath << " !!\n";
0215 }
0216
0217 bool isRegression = info["AnalysisType"] == "Regression";
0218
0219
0220
0221 bool isAdaClassifier = !isRegression && options["BoostType"] != "Grad";
0222 bool useYesNoLeaf = isAdaClassifier && options["UseYesNoLeaf"] == "True";
0223
0224
0225
0226 bool adjustBoundaries =
0227 (rootTrainingVersion >= ROOT_VERSION(5, 34, 20) && rootTrainingVersion < ROOT_VERSION(6, 0, 0)) ||
0228 rootTrainingVersion >= ROOT_VERSION(6, 2, 0);
0229
0230 auto forest = std::make_unique<GBRForest>();
0231 forest->SetInitialResponse(isRegression ? boostWeights[0] : 0.);
0232
0233 double norm = 0;
0234 if (isAdaClassifier) {
0235 for (double w : boostWeights) {
0236 norm += w;
0237 }
0238 }
0239
0240 forest->Trees().reserve(boostWeights.size());
0241 size_t itree = 0;
0242
0243 for (tinyxml2::XMLElement* e = weightsElem->FirstChildElement("BinaryTree"); e != nullptr;
0244 e = e->NextSiblingElement("BinaryTree")) {
0245 double scale = isAdaClassifier ? boostWeights[itree] / norm : 1.0;
0246
0247 tinyxml2::XMLElement* root = e->FirstChildElement("Node");
0248 forest->Trees().push_back(GBRTree(countIntermediateNodes(root), countTerminalNodes(root)));
0249 auto& tree = forest->Trees().back();
0250
0251 addNode(tree, root, scale, isRegression, useYesNoLeaf, adjustBoundaries, isAdaClassifier);
0252
0253
0254 if (tree.CutIndices().empty()) {
0255 tree.CutIndices().push_back(0);
0256 tree.CutVals().push_back(0);
0257 tree.LeftIndices().push_back(0);
0258 tree.RightIndices().push_back(0);
0259 }
0260
0261 ++itree;
0262 }
0263
0264 return forest;
0265 }
0266
0267 }
0268
0269
0270 std::unique_ptr<const GBRForest> createGBRForest(const std::string& weightsFile) {
0271 std::vector<std::string> varNames;
0272 return createGBRForest(weightsFile, varNames);
0273 }
0274
0275 std::unique_ptr<const GBRForest> createGBRForest(const edm::FileInPath& weightsFile) {
0276 std::vector<std::string> varNames;
0277 return createGBRForest(weightsFile.fullPath(), varNames);
0278 }
0279
0280
0281 std::unique_ptr<const GBRForest> createGBRForest(const std::string& weightsFile, std::vector<std::string>& varNames) {
0282 std::unique_ptr<GBRForest> gbrForest;
0283
0284 if (weightsFile[0] == '/') {
0285 gbrForest = init(weightsFile, varNames);
0286 } else {
0287 edm::FileInPath weightsFileEdm(weightsFile);
0288 gbrForest = init(weightsFileEdm.fullPath(), varNames);
0289 }
0290 return gbrForest;
0291 }
0292
0293 std::unique_ptr<const GBRForest> createGBRForest(const edm::FileInPath& weightsFile,
0294 std::vector<std::string>& varNames) {
0295 return createGBRForest(weightsFile.fullPath(), varNames);
0296 }