File indexing completed on 2024-04-06 12:20:55
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
#include "L1Trigger/L1TMuonEndCap/interface/bdt/Tree.h"

#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <sstream>
0023
0024
0025
0026
0027
0028 using namespace emtf;
0029
0030 Tree::Tree() {
0031 rootNode = new Node("root");
0032
0033 terminalNodes.push_back(rootNode);
0034 numTerminalNodes = 1;
0035 boostWeight = 0;
0036 xmlVersion = 2017;
0037 }
0038
0039 Tree::Tree(std::vector<std::vector<Event*>>& cEvents) {
0040 rootNode = new Node("root");
0041 rootNode->setEvents(cEvents);
0042
0043 terminalNodes.push_back(rootNode);
0044 numTerminalNodes = 1;
0045 boostWeight = 0;
0046 xmlVersion = 2017;
0047 }
0048
0049
0050
0051
0052 Tree::~Tree() {
0053
0054
0055 if (rootNode)
0056 delete rootNode;
0057 }
0058
0059 Tree::Tree(const Tree& tree) {
0060
0061 rootNode = copyFrom(const_cast<Tree&>(tree).getRootNode());
0062 numTerminalNodes = tree.numTerminalNodes;
0063 rmsError = tree.rmsError;
0064 boostWeight = tree.boostWeight;
0065 xmlVersion = tree.xmlVersion;
0066
0067 terminalNodes.resize(0);
0068
0069 findLeafs(rootNode, terminalNodes);
0070
0071
0072 }
0073
0074 Tree& Tree::operator=(const Tree& tree) {
0075 if (rootNode)
0076 delete rootNode;
0077
0078 rootNode = copyFrom(const_cast<Tree&>(tree).getRootNode());
0079 numTerminalNodes = tree.numTerminalNodes;
0080 rmsError = tree.rmsError;
0081 boostWeight = tree.boostWeight;
0082 xmlVersion = tree.xmlVersion;
0083
0084 terminalNodes.resize(0);
0085
0086 findLeafs(rootNode, terminalNodes);
0087
0088
0089
0090 return *this;
0091 }
0092
0093 Node* Tree::copyFrom(const Node* local_root) {
0094
0095 if (!local_root)
0096 return nullptr;
0097
0098 Node* lr = const_cast<Node*>(local_root);
0099
0100
0101 Node* left_new_child = copyFrom(lr->getLeftDaughter());
0102 Node* right_new_child = copyFrom(lr->getRightDaughter());
0103
0104
0105 Node* new_local_root = new Node(lr->getName());
0106 if (left_new_child)
0107 left_new_child->setParent(new_local_root);
0108 if (right_new_child)
0109 right_new_child->setParent(new_local_root);
0110 new_local_root->setLeftDaughter(left_new_child);
0111 new_local_root->setRightDaughter(right_new_child);
0112 new_local_root->setErrorReduction(lr->getErrorReduction());
0113 new_local_root->setSplitValue(lr->getSplitValue());
0114 new_local_root->setSplitVariable(lr->getSplitVariable());
0115 new_local_root->setFitValue(lr->getFitValue());
0116 new_local_root->setTotalError(lr->getTotalError());
0117 new_local_root->setAvgError(lr->getAvgError());
0118 new_local_root->setNumEvents(lr->getNumEvents());
0119
0120
0121 return new_local_root;
0122 }
0123
0124 void Tree::findLeafs(Node* local_root, std::list<Node*>& tn) {
0125 if (!local_root->getLeftDaughter() && !local_root->getRightDaughter()) {
0126
0127 tn.push_back(local_root);
0128 return;
0129 }
0130
0131 if (local_root->getLeftDaughter())
0132 findLeafs(local_root->getLeftDaughter(), tn);
0133
0134 if (local_root->getRightDaughter())
0135 findLeafs(local_root->getRightDaughter(), tn);
0136 }
0137
0138 Tree::Tree(Tree&& tree) {
0139 if (rootNode)
0140 delete rootNode;
0141 rootNode = tree.rootNode;
0142 terminalNodes = std::move(tree.terminalNodes);
0143 numTerminalNodes = tree.numTerminalNodes;
0144 rmsError = tree.rmsError;
0145 boostWeight = tree.boostWeight;
0146 xmlVersion = tree.xmlVersion;
0147 }
0148
0149
0150
0151
0152
0153 void Tree::setRootNode(Node* sRootNode) { rootNode = sRootNode; }
0154
0155 Node* Tree::getRootNode() { return rootNode; }
0156
0157
0158
0159 void Tree::setTerminalNodes(std::list<Node*>& sTNodes) { terminalNodes = sTNodes; }
0160
0161 std::list<Node*>& Tree::getTerminalNodes() { return terminalNodes; }
0162
0163
0164
0165 int Tree::getNumTerminalNodes() { return numTerminalNodes; }
0166
0167
0168
0169
0170
0171 void Tree::calcError() {
0172
0173
0174
0175 double totalSquaredError = 0;
0176
0177 for (std::list<Node*>::iterator it = terminalNodes.begin(); it != terminalNodes.end(); it++) {
0178 totalSquaredError += (*it)->getTotalError();
0179 }
0180 rmsError = sqrt(totalSquaredError / rootNode->getNumEvents());
0181 }
0182
0183
0184
0185 void Tree::buildTree(int nodeLimit) {
0186
0187 double bestNodeErrorReduction = -1;
0188 Node* nodeToSplit = nullptr;
0189
0190 if (numTerminalNodes == 1) {
0191 rootNode->calcOptimumSplit();
0192 calcError();
0193
0194 }
0195
0196 for (std::list<Node*>::iterator it = terminalNodes.begin(); it != terminalNodes.end(); it++) {
0197 if ((*it)->getErrorReduction() > bestNodeErrorReduction) {
0198 bestNodeErrorReduction = (*it)->getErrorReduction();
0199 nodeToSplit = (*it);
0200 }
0201 }
0202
0203
0204
0205
0206 if (nodeToSplit == nullptr)
0207 return;
0208
0209
0210 nodeToSplit->theMiracleOfChildBirth();
0211
0212
0213 Node* left = nodeToSplit->getLeftDaughter();
0214 Node* right = nodeToSplit->getRightDaughter();
0215
0216
0217 terminalNodes.remove(nodeToSplit);
0218 terminalNodes.push_back(left);
0219 terminalNodes.push_back(right);
0220 numTerminalNodes++;
0221
0222
0223 nodeToSplit->filterEventsToDaughters();
0224
0225
0226 left->calcOptimumSplit();
0227 right->calcOptimumSplit();
0228
0229
0230 calcError();
0231
0232 if (numTerminalNodes % 1 == 0) {
0233
0234 }
0235
0236
0237 if (numTerminalNodes < nodeLimit)
0238 buildTree(nodeLimit);
0239 }
0240
0241
0242
0243 void Tree::filterEvents(std::vector<Event*>& tEvents) {
0244
0245
0246
0247
0248 rootNode->getEvents() = std::vector<std::vector<Event*>>(1);
0249 rootNode->getEvents()[0] = tEvents;
0250
0251
0252
0253 filterEventsRecursive(rootNode);
0254 }
0255
0256
0257
0258 void Tree::filterEventsRecursive(Node* node) {
0259
0260
0261
0262 Node* left = node->getLeftDaughter();
0263 Node* right = node->getRightDaughter();
0264
0265 if (left == nullptr || right == nullptr)
0266 return;
0267
0268 node->filterEventsToDaughters();
0269
0270 filterEventsRecursive(left);
0271 filterEventsRecursive(right);
0272 }
0273
0274
0275
0276 Node* Tree::filterEvent(Event* e) {
0277
0278
0279
0280
0281 Node* node = filterEventRecursive(rootNode, e);
0282 return node;
0283 }
0284
0285
0286
0287 Node* Tree::filterEventRecursive(Node* node, Event* e) {
0288
0289
0290
0291 Node* nextNode = node->filterEventToDaughter(e);
0292 if (nextNode == nullptr)
0293 return node;
0294
0295 return filterEventRecursive(nextNode, e);
0296 }
0297
0298
0299
0300 void Tree::rankVariablesRecursive(Node* node, std::vector<double>& v) {
0301
0302
0303
0304
0305 Node* left = node->getLeftDaughter();
0306 Node* right = node->getRightDaughter();
0307
0308
0309 if (left == nullptr || right == nullptr)
0310 return;
0311
0312 int sv = node->getSplitVariable();
0313 double er = node->getErrorReduction();
0314
0315
0316
0317
0318
0319
0320
0321
0322
0323
0324 v[sv] += er;
0325
0326 rankVariablesRecursive(left, v);
0327 rankVariablesRecursive(right, v);
0328 }
0329
0330
0331
0332 void Tree::rankVariables(std::vector<double>& v) { rankVariablesRecursive(rootNode, v); }
0333
0334
0335
0336 void Tree::getSplitValuesRecursive(Node* node, std::vector<std::vector<double>>& v) {
0337
0338
0339
0340 Node* left = node->getLeftDaughter();
0341 Node* right = node->getRightDaughter();
0342
0343
0344 if (left == nullptr || right == nullptr)
0345 return;
0346
0347 int sv = node->getSplitVariable();
0348 double sp = node->getSplitValue();
0349
0350 if (sv == -1) {
0351 std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
0352 std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
0353 }
0354
0355
0356 v[sv].push_back(sp);
0357
0358 getSplitValuesRecursive(left, v);
0359 getSplitValuesRecursive(right, v);
0360 }
0361
0362
0363
0364 void Tree::getSplitValues(std::vector<std::vector<double>>& v) { getSplitValuesRecursive(rootNode, v); }
0365
0366
0367
0368
0369
// Formats a value with operator<< for use as an XML attribute value.
// Floating-point values are written with max_digits10 significant digits, so
// they parse back to exactly the same value. The default stream precision
// (6 digits) would silently truncate split and fit values.
template <typename T>
std::string numToStr(T num) {
  std::stringstream ss;
  if (std::numeric_limits<T>::is_specialized && !std::numeric_limits<T>::is_integer)
    ss.precision(std::numeric_limits<T>::max_digits10);
  ss << num;
  return ss.str();
}
0378
0379
0380
0381 void Tree::addXMLAttributes(TXMLEngine* xml, Node* node, XMLNodePointer_t np) {
0382
0383
0384 xml->NewAttr(np, nullptr, "splitVar", numToStr(node->getSplitVariable()).c_str());
0385 xml->NewAttr(np, nullptr, "splitVal", numToStr(node->getSplitValue()).c_str());
0386 xml->NewAttr(np, nullptr, "fitVal", numToStr(node->getFitValue()).c_str());
0387 }
0388
0389
0390
0391 void Tree::saveToXML(const char* c) {
0392 TXMLEngine* xml = new TXMLEngine();
0393
0394
0395 XMLNodePointer_t root = xml->NewChild(nullptr, nullptr, rootNode->getName().c_str());
0396 addXMLAttributes(xml, rootNode, root);
0397
0398
0399 saveToXMLRecursive(xml, rootNode, root);
0400
0401
0402 XMLDocPointer_t xmldoc = xml->NewDoc();
0403 xml->DocSetRootElement(xmldoc, root);
0404
0405
0406 xml->SaveDoc(xmldoc, c);
0407
0408
0409 xml->FreeDoc(xmldoc);
0410 delete xml;
0411 }
0412
0413
0414
0415 void Tree::saveToXMLRecursive(TXMLEngine* xml, Node* node, XMLNodePointer_t np) {
0416 Node* l = node->getLeftDaughter();
0417 Node* r = node->getRightDaughter();
0418
0419 if (l == nullptr || r == nullptr)
0420 return;
0421
0422
0423 XMLNodePointer_t left = xml->NewChild(np, nullptr, "left");
0424 XMLNodePointer_t right = xml->NewChild(np, nullptr, "right");
0425
0426
0427 addXMLAttributes(xml, l, left);
0428 addXMLAttributes(xml, r, right);
0429
0430
0431 saveToXMLRecursive(xml, l, left);
0432 saveToXMLRecursive(xml, r, right);
0433 }
0434
0435
0436
0437 void Tree::loadFromXML(const char* filename) {
0438
0439 TXMLEngine* xml = new TXMLEngine;
0440
0441
0442 XMLDocPointer_t xmldoc = xml->ParseFile(filename);
0443 if (xmldoc == nullptr) {
0444 delete xml;
0445 return;
0446 }
0447
0448
0449 XMLNodePointer_t mainnode = xml->DocGetRootElement(xmldoc);
0450
0451
0452
0453
0454 if (std::string("BinaryTree") == xml->GetNodeName(mainnode)) {
0455 XMLAttrPointer_t attr = xml->GetFirstAttr(mainnode);
0456 while (std::string("boostWeight") != xml->GetAttrName(attr)) {
0457 attr = xml->GetNextAttr(attr);
0458 }
0459 boostWeight = (attr ? strtod(xml->GetAttrValue(attr), nullptr) : 0);
0460
0461 mainnode = xml->GetChild(mainnode);
0462 xmlVersion = 2017;
0463 } else {
0464 boostWeight = 0;
0465 xmlVersion = 2016;
0466 }
0467
0468 loadFromXMLRecursive(xml, mainnode, rootNode);
0469
0470
0471 xml->FreeDoc(xmldoc);
0472 delete xml;
0473 }
0474
0475
0476
0477 void Tree::loadFromXMLRecursive(TXMLEngine* xml, XMLNodePointer_t xnode, Node* tnode) {
0478
0479 XMLAttrPointer_t attr = xml->GetFirstAttr(xnode);
0480 std::vector<std::string> splitInfo(3);
0481 if (xmlVersion >= 2017) {
0482 for (unsigned int i = 0; i < 10; i++) {
0483 if (std::string("IVar") == xml->GetAttrName(attr)) {
0484 splitInfo[0] = xml->GetAttrValue(attr);
0485 }
0486 if (std::string("Cut") == xml->GetAttrName(attr)) {
0487 splitInfo[1] = xml->GetAttrValue(attr);
0488 }
0489 if (std::string("res") == xml->GetAttrName(attr)) {
0490 splitInfo[2] = xml->GetAttrValue(attr);
0491 }
0492 attr = xml->GetNextAttr(attr);
0493 }
0494 } else {
0495 for (unsigned int i = 0; i < 3; i++) {
0496 splitInfo[i] = xml->GetAttrValue(attr);
0497 attr = xml->GetNextAttr(attr);
0498 }
0499 }
0500
0501
0502 std::stringstream converter;
0503 int splitVar;
0504 double splitVal;
0505 double fitVal;
0506
0507 converter << splitInfo[0];
0508 converter >> splitVar;
0509 converter.str("");
0510 converter.clear();
0511
0512 converter << splitInfo[1];
0513 converter >> splitVal;
0514 converter.str("");
0515 converter.clear();
0516
0517 converter << splitInfo[2];
0518 converter >> fitVal;
0519 converter.str("");
0520 converter.clear();
0521
0522
0523 tnode->setSplitVariable(splitVar);
0524 tnode->setSplitValue(splitVal);
0525 tnode->setFitValue(fitVal);
0526
0527
0528 XMLNodePointer_t xleft = xml->GetChild(xnode);
0529 XMLNodePointer_t xright = xml->GetNext(xleft);
0530
0531
0532 if (xleft == nullptr || xright == nullptr)
0533 return;
0534
0535
0536 tnode->theMiracleOfChildBirth();
0537 Node* tleft = tnode->getLeftDaughter();
0538 Node* tright = tnode->getRightDaughter();
0539
0540
0541 terminalNodes.remove(tnode);
0542 terminalNodes.push_back(tleft);
0543 terminalNodes.push_back(tright);
0544 numTerminalNodes++;
0545
0546 loadFromXMLRecursive(xml, xleft, tleft);
0547 loadFromXMLRecursive(xml, xright, tright);
0548 }
0549
0550 void Tree::loadFromCondPayload(const L1TMuonEndCapForest::DTree& tree) {
0551
0552 if (rootNode)
0553 delete rootNode;
0554 rootNode = new Node("root");
0555
0556 const L1TMuonEndCapForest::DTreeNode& mainnode = tree[0];
0557 loadFromCondPayloadRecursive(tree, mainnode, rootNode);
0558 }
0559
0560 void Tree::loadFromCondPayloadRecursive(const L1TMuonEndCapForest::DTree& tree,
0561 const L1TMuonEndCapForest::DTreeNode& node,
0562 Node* tnode) {
0563
0564 tnode->setSplitVariable(node.splitVar);
0565 tnode->setSplitValue(node.splitVal);
0566 tnode->setFitValue(node.fitVal);
0567
0568
0569 if (node.ileft == 0 || node.iright == 0)
0570 return;
0571 if (node.ileft >= tree.size() || node.iright >= tree.size())
0572 return;
0573
0574
0575 tnode->theMiracleOfChildBirth();
0576 Node* tleft = tnode->getLeftDaughter();
0577 Node* tright = tnode->getRightDaughter();
0578
0579
0580 terminalNodes.remove(tnode);
0581 terminalNodes.push_back(tleft);
0582 terminalNodes.push_back(tright);
0583 numTerminalNodes++;
0584
0585 loadFromCondPayloadRecursive(tree, tree[node.ileft], tleft);
0586 loadFromCondPayloadRecursive(tree, tree[node.iright], tright);
0587 }