File indexing completed on 2025-03-23 23:40:24
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 #include "L1Trigger/L1TMuonEndCap/interface/bdt/Tree.h"
0019
0020 #include <iostream>
0021 #include <sstream>
0022 #include <cmath>
0023
0024
0025
0026
0027
0028 using namespace emtf;
0029
0030 Tree::Tree() {
0031 rootNode = std::make_unique<Node>("root");
0032
0033 terminalNodes.push_back(rootNode.get());
0034 numTerminalNodes = 1;
0035 boostWeight = 0;
0036 xmlVersion = 2017;
0037 }
0038
0039 Tree::Tree(std::vector<std::vector<Event*>>& cEvents) {
0040 rootNode = std::make_unique<Node>("root");
0041 rootNode->setEvents(cEvents);
0042
0043 terminalNodes.push_back(rootNode.get());
0044 numTerminalNodes = 1;
0045 boostWeight = 0;
0046 xmlVersion = 2017;
0047 }
0048
0049 Tree::Tree(const Tree& tree) {
0050 rootNode = copyFrom(tree.getRootNode());
0051 numTerminalNodes = tree.numTerminalNodes;
0052 rmsError = tree.rmsError;
0053 boostWeight = tree.boostWeight;
0054 xmlVersion = tree.xmlVersion;
0055
0056 terminalNodes.resize(0);
0057
0058 findLeafs(rootNode.get(), terminalNodes);
0059
0060
0061 }
0062
0063 Tree& Tree::operator=(const Tree& tree) {
0064 rootNode = copyFrom(tree.getRootNode());
0065 numTerminalNodes = tree.numTerminalNodes;
0066 rmsError = tree.rmsError;
0067 boostWeight = tree.boostWeight;
0068 xmlVersion = tree.xmlVersion;
0069
0070 terminalNodes.resize(0);
0071
0072 findLeafs(rootNode.get(), terminalNodes);
0073
0074
0075
0076 return *this;
0077 }
0078
0079 std::unique_ptr<Node> Tree::copyFrom(const Node* local_root) {
0080
0081 if (!local_root)
0082 return nullptr;
0083
0084 const Node* lr = local_root;
0085
0086
0087 auto left_new_child = copyFrom(lr->getLeftDaughter());
0088 auto right_new_child = copyFrom(lr->getRightDaughter());
0089
0090
0091 auto new_local_root = std::make_unique<Node>(lr->getName());
0092 if (left_new_child)
0093 left_new_child->setParent(new_local_root.get());
0094 if (right_new_child)
0095 right_new_child->setParent(new_local_root.get());
0096 new_local_root->setLeftDaughter(std::move(left_new_child));
0097 new_local_root->setRightDaughter(std::move(right_new_child));
0098 new_local_root->setErrorReduction(lr->getErrorReduction());
0099 new_local_root->setSplitValue(lr->getSplitValue());
0100 new_local_root->setSplitVariable(lr->getSplitVariable());
0101 new_local_root->setFitValue(lr->getFitValue());
0102 new_local_root->setTotalError(lr->getTotalError());
0103 new_local_root->setAvgError(lr->getAvgError());
0104 new_local_root->setNumEvents(lr->getNumEvents());
0105
0106
0107 return new_local_root;
0108 }
0109
0110 void Tree::findLeafs(Node* local_root, std::list<Node*>& tn) {
0111 if (!local_root->getLeftDaughter() && !local_root->getRightDaughter()) {
0112
0113 tn.push_back(local_root);
0114 return;
0115 }
0116
0117 if (local_root->getLeftDaughter())
0118 findLeafs(local_root->getLeftDaughter(), tn);
0119
0120 if (local_root->getRightDaughter())
0121 findLeafs(local_root->getRightDaughter(), tn);
0122 }
0123
0124
0125
0126
0127
0128 void Tree::setRootNode(Node* sRootNode) { rootNode.reset(sRootNode); }
0129
0130 Node* Tree::getRootNode() { return rootNode.get(); }
0131 const Node* Tree::getRootNode() const { return rootNode.get(); }
0132
0133
0134
0135 void Tree::setTerminalNodes(std::list<Node*>& sTNodes) { terminalNodes = sTNodes; }
0136
0137 std::list<Node*>& Tree::getTerminalNodes() { return terminalNodes; }
0138
0139
0140
0141 int Tree::getNumTerminalNodes() { return numTerminalNodes; }
0142
0143
0144
0145
0146
0147 void Tree::calcError() {
0148
0149
0150
0151 double totalSquaredError = 0;
0152
0153 for (std::list<Node*>::iterator it = terminalNodes.begin(); it != terminalNodes.end(); it++) {
0154 totalSquaredError += (*it)->getTotalError();
0155 }
0156 rmsError = sqrt(totalSquaredError / rootNode->getNumEvents());
0157 }
0158
0159
0160
0161 void Tree::buildTree(int nodeLimit) {
0162
0163 double bestNodeErrorReduction = -1;
0164 Node* nodeToSplit = nullptr;
0165
0166 if (numTerminalNodes == 1) {
0167 rootNode->calcOptimumSplit();
0168 calcError();
0169
0170 }
0171
0172 for (std::list<Node*>::iterator it = terminalNodes.begin(); it != terminalNodes.end(); it++) {
0173 if ((*it)->getErrorReduction() > bestNodeErrorReduction) {
0174 bestNodeErrorReduction = (*it)->getErrorReduction();
0175 nodeToSplit = (*it);
0176 }
0177 }
0178
0179
0180
0181
0182 if (nodeToSplit == nullptr)
0183 return;
0184
0185
0186 nodeToSplit->theMiracleOfChildBirth();
0187
0188
0189 Node* left = nodeToSplit->getLeftDaughter();
0190 Node* right = nodeToSplit->getRightDaughter();
0191
0192
0193 terminalNodes.remove(nodeToSplit);
0194 terminalNodes.push_back(left);
0195 terminalNodes.push_back(right);
0196 numTerminalNodes++;
0197
0198
0199 nodeToSplit->filterEventsToDaughters();
0200
0201
0202 left->calcOptimumSplit();
0203 right->calcOptimumSplit();
0204
0205
0206 calcError();
0207
0208 if (numTerminalNodes % 1 == 0) {
0209
0210 }
0211
0212
0213 if (numTerminalNodes < nodeLimit)
0214 buildTree(nodeLimit);
0215 }
0216
0217
0218
0219 void Tree::filterEvents(std::vector<Event*>& tEvents) {
0220
0221
0222
0223
0224 rootNode->getEvents() = std::vector<std::vector<Event*>>(1);
0225 rootNode->getEvents()[0] = tEvents;
0226
0227
0228
0229 filterEventsRecursive(rootNode.get());
0230 }
0231
0232
0233
0234 void Tree::filterEventsRecursive(Node* node) {
0235
0236
0237
0238 Node* left = node->getLeftDaughter();
0239 Node* right = node->getRightDaughter();
0240
0241 if (left == nullptr || right == nullptr)
0242 return;
0243
0244 node->filterEventsToDaughters();
0245
0246 filterEventsRecursive(left);
0247 filterEventsRecursive(right);
0248 }
0249
0250
0251
0252 Node* Tree::filterEvent(Event* e) {
0253
0254
0255
0256
0257 Node* node = filterEventRecursive(rootNode.get(), e);
0258 return node;
0259 }
0260
0261
0262
0263 Node* Tree::filterEventRecursive(Node* node, Event* e) {
0264
0265
0266
0267 Node* nextNode = node->filterEventToDaughter(e);
0268 if (nextNode == nullptr)
0269 return node;
0270
0271 return filterEventRecursive(nextNode, e);
0272 }
0273
0274
0275
0276 void Tree::rankVariablesRecursive(Node* node, std::vector<double>& v) const {
0277
0278
0279
0280
0281 Node* left = node->getLeftDaughter();
0282 Node* right = node->getRightDaughter();
0283
0284
0285 if (left == nullptr || right == nullptr)
0286 return;
0287
0288 int sv = node->getSplitVariable();
0289 double er = node->getErrorReduction();
0290
0291
0292
0293
0294
0295
0296
0297
0298
0299
0300 v[sv] += er;
0301
0302 rankVariablesRecursive(left, v);
0303 rankVariablesRecursive(right, v);
0304 }
0305
0306
0307
0308 void Tree::rankVariables(std::vector<double>& v) const { rankVariablesRecursive(rootNode.get(), v); }
0309
0310
0311
0312 void Tree::getSplitValuesRecursive(Node* node, std::vector<std::vector<double>>& v) const {
0313
0314
0315
0316 Node* left = node->getLeftDaughter();
0317 Node* right = node->getRightDaughter();
0318
0319
0320 if (left == nullptr || right == nullptr)
0321 return;
0322
0323 int sv = node->getSplitVariable();
0324 double sp = node->getSplitValue();
0325
0326 if (sv == -1) {
0327 std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
0328 std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
0329 }
0330
0331
0332 v[sv].push_back(sp);
0333
0334 getSplitValuesRecursive(left, v);
0335 getSplitValuesRecursive(right, v);
0336 }
0337
0338
0339
0340 void Tree::getSplitValues(std::vector<std::vector<double>>& v) const { getSplitValuesRecursive(rootNode.get(), v); }
0341
0342
0343
0344
0345
namespace {
  // Formats a value as text using stream insertion with the stream's default
  // formatting flags (so floating-point values get the default precision).
  template <typename T>
  std::string numToStr(T num) {
    std::ostringstream out;
    out << num;
    return out.str();
  }
}  // namespace
0356
0357
0358
0359 void Tree::addXMLAttributes(TXMLEngine* xml, Node* node, XMLNodePointer_t np) {
0360
0361
0362 xml->NewAttr(np, nullptr, "splitVar", numToStr(node->getSplitVariable()).c_str());
0363 xml->NewAttr(np, nullptr, "splitVal", numToStr(node->getSplitValue()).c_str());
0364 xml->NewAttr(np, nullptr, "fitVal", numToStr(node->getFitValue()).c_str());
0365 }
0366
0367
0368
0369 void Tree::saveToXML(const char* c) {
0370 TXMLEngine* xml = new TXMLEngine();
0371
0372
0373 XMLNodePointer_t root = xml->NewChild(nullptr, nullptr, rootNode->getName().c_str());
0374 addXMLAttributes(xml, rootNode.get(), root);
0375
0376
0377 saveToXMLRecursive(xml, rootNode.get(), root);
0378
0379
0380 XMLDocPointer_t xmldoc = xml->NewDoc();
0381 xml->DocSetRootElement(xmldoc, root);
0382
0383
0384 xml->SaveDoc(xmldoc, c);
0385
0386
0387 xml->FreeDoc(xmldoc);
0388 delete xml;
0389 }
0390
0391
0392
0393 void Tree::saveToXMLRecursive(TXMLEngine* xml, Node* node, XMLNodePointer_t np) {
0394 Node* l = node->getLeftDaughter();
0395 Node* r = node->getRightDaughter();
0396
0397 if (l == nullptr || r == nullptr)
0398 return;
0399
0400
0401 XMLNodePointer_t left = xml->NewChild(np, nullptr, "left");
0402 XMLNodePointer_t right = xml->NewChild(np, nullptr, "right");
0403
0404
0405 addXMLAttributes(xml, l, left);
0406 addXMLAttributes(xml, r, right);
0407
0408
0409 saveToXMLRecursive(xml, l, left);
0410 saveToXMLRecursive(xml, r, right);
0411 }
0412
0413
0414
0415 void Tree::loadFromXML(const char* filename) {
0416
0417 TXMLEngine* xml = new TXMLEngine;
0418
0419
0420 XMLDocPointer_t xmldoc = xml->ParseFile(filename);
0421 if (xmldoc == nullptr) {
0422 delete xml;
0423 return;
0424 }
0425
0426
0427 XMLNodePointer_t mainnode = xml->DocGetRootElement(xmldoc);
0428
0429
0430
0431
0432 if (std::string("BinaryTree") == xml->GetNodeName(mainnode)) {
0433 XMLAttrPointer_t attr = xml->GetFirstAttr(mainnode);
0434 while (std::string("boostWeight") != xml->GetAttrName(attr)) {
0435 attr = xml->GetNextAttr(attr);
0436 }
0437 boostWeight = (attr ? strtod(xml->GetAttrValue(attr), nullptr) : 0);
0438
0439 mainnode = xml->GetChild(mainnode);
0440 xmlVersion = 2017;
0441 } else {
0442 boostWeight = 0;
0443 xmlVersion = 2016;
0444 }
0445
0446 loadFromXMLRecursive(xml, mainnode, rootNode.get());
0447
0448
0449 xml->FreeDoc(xmldoc);
0450 delete xml;
0451 }
0452
0453
0454
0455 void Tree::loadFromXMLRecursive(TXMLEngine* xml, XMLNodePointer_t xnode, Node* tnode) {
0456
0457 XMLAttrPointer_t attr = xml->GetFirstAttr(xnode);
0458 std::vector<std::string> splitInfo(3);
0459 if (xmlVersion >= 2017) {
0460 for (unsigned int i = 0; i < 10; i++) {
0461 if (std::string("IVar") == xml->GetAttrName(attr)) {
0462 splitInfo[0] = xml->GetAttrValue(attr);
0463 }
0464 if (std::string("Cut") == xml->GetAttrName(attr)) {
0465 splitInfo[1] = xml->GetAttrValue(attr);
0466 }
0467 if (std::string("res") == xml->GetAttrName(attr)) {
0468 splitInfo[2] = xml->GetAttrValue(attr);
0469 }
0470 attr = xml->GetNextAttr(attr);
0471 }
0472 } else {
0473 for (unsigned int i = 0; i < 3; i++) {
0474 splitInfo[i] = xml->GetAttrValue(attr);
0475 attr = xml->GetNextAttr(attr);
0476 }
0477 }
0478
0479
0480 std::stringstream converter;
0481 int splitVar;
0482 double splitVal;
0483 double fitVal;
0484
0485 converter << splitInfo[0];
0486 converter >> splitVar;
0487 converter.str("");
0488 converter.clear();
0489
0490 converter << splitInfo[1];
0491 converter >> splitVal;
0492 converter.str("");
0493 converter.clear();
0494
0495 converter << splitInfo[2];
0496 converter >> fitVal;
0497 converter.str("");
0498 converter.clear();
0499
0500
0501 tnode->setSplitVariable(splitVar);
0502 tnode->setSplitValue(splitVal);
0503 tnode->setFitValue(fitVal);
0504
0505
0506 XMLNodePointer_t xleft = xml->GetChild(xnode);
0507 XMLNodePointer_t xright = xml->GetNext(xleft);
0508
0509
0510 if (xleft == nullptr || xright == nullptr)
0511 return;
0512
0513
0514 tnode->theMiracleOfChildBirth();
0515 Node* tleft = tnode->getLeftDaughter();
0516 Node* tright = tnode->getRightDaughter();
0517
0518
0519 terminalNodes.remove(tnode);
0520 terminalNodes.push_back(tleft);
0521 terminalNodes.push_back(tright);
0522 numTerminalNodes++;
0523
0524 loadFromXMLRecursive(xml, xleft, tleft);
0525 loadFromXMLRecursive(xml, xright, tright);
0526 }
0527
0528 void Tree::loadFromCondPayload(const L1TMuonEndCapForest::DTree& tree) {
0529
0530 rootNode = std::make_unique<Node>("root");
0531
0532 const L1TMuonEndCapForest::DTreeNode& mainnode = tree[0];
0533 loadFromCondPayloadRecursive(tree, mainnode, rootNode.get());
0534 }
0535
0536 void Tree::loadFromCondPayloadRecursive(const L1TMuonEndCapForest::DTree& tree,
0537 const L1TMuonEndCapForest::DTreeNode& node,
0538 Node* tnode) {
0539
0540 tnode->setSplitVariable(node.splitVar);
0541 tnode->setSplitValue(node.splitVal);
0542 tnode->setFitValue(node.fitVal);
0543
0544
0545 if (node.ileft == 0 || node.iright == 0)
0546 return;
0547 if (node.ileft >= tree.size() || node.iright >= tree.size())
0548 return;
0549
0550
0551 tnode->theMiracleOfChildBirth();
0552 Node* tleft = tnode->getLeftDaughter();
0553 Node* tright = tnode->getRightDaughter();
0554
0555
0556 terminalNodes.remove(tnode);
0557 terminalNodes.push_back(tleft);
0558 terminalNodes.push_back(tright);
0559 numTerminalNodes++;
0560
0561 loadFromCondPayloadRecursive(tree, tree[node.ileft], tleft);
0562 loadFromCondPayloadRecursive(tree, tree[node.iright], tright);
0563 }