Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2025-03-23 23:40:24

0001 //////////////////////////////////////////////////////////////////////////
0002 //                            Tree.cxx                                  //
0003 // =====================================================================//
0004 // This is the object implementation of a decision tree.                //
0005 // References include                                                   //
0006 //    *Elements of Statistical Learning by Hastie,                      //
0007 //     Tibshirani, and Friedman.                                        //
0008 //    *Greedy Function Approximation: A Gradient Boosting Machine.      //
0009 //     Friedman. The Annals of Statistics, Vol. 29, No. 5. Oct 2001.    //
0010 //    *Inductive Learning of Tree-based Regression Models. Luis Torgo.  //
0011 //                                                                      //
0012 //////////////////////////////////////////////////////////////////////////
0013 
0014 ///////////////////////////////////////////////////////////////////////////
0015 // _______________________Includes_______________________________________//
0016 ///////////////////////////////////////////////////////////////////////////
0017 
#include "L1Trigger/L1TMuonEndCap/interface/bdt/Tree.h"

#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
0023 
0024 //////////////////////////////////////////////////////////////////////////
0025 // _______________________Constructor(s)________________________________//
0026 //////////////////////////////////////////////////////////////////////////
0027 
0028 using namespace emtf;
0029 
0030 Tree::Tree() {
0031   rootNode = std::make_unique<Node>("root");
0032 
0033   terminalNodes.push_back(rootNode.get());
0034   numTerminalNodes = 1;
0035   boostWeight = 0;
0036   xmlVersion = 2017;
0037 }
0038 
0039 Tree::Tree(std::vector<std::vector<Event*>>& cEvents) {
0040   rootNode = std::make_unique<Node>("root");
0041   rootNode->setEvents(cEvents);
0042 
0043   terminalNodes.push_back(rootNode.get());
0044   numTerminalNodes = 1;
0045   boostWeight = 0;
0046   xmlVersion = 2017;
0047 }
0048 
0049 Tree::Tree(const Tree& tree) {
0050   rootNode = copyFrom(tree.getRootNode());
0051   numTerminalNodes = tree.numTerminalNodes;
0052   rmsError = tree.rmsError;
0053   boostWeight = tree.boostWeight;
0054   xmlVersion = tree.xmlVersion;
0055 
0056   terminalNodes.resize(0);
0057   // find new leafs
0058   findLeafs(rootNode.get(), terminalNodes);
0059 
0060   ///    if( numTerminalNodes != terminalNodes.size() ) throw std::runtime_error();
0061 }
0062 
0063 Tree& Tree::operator=(const Tree& tree) {
0064   rootNode = copyFrom(tree.getRootNode());
0065   numTerminalNodes = tree.numTerminalNodes;
0066   rmsError = tree.rmsError;
0067   boostWeight = tree.boostWeight;
0068   xmlVersion = tree.xmlVersion;
0069 
0070   terminalNodes.resize(0);
0071   // find new leafs
0072   findLeafs(rootNode.get(), terminalNodes);
0073 
0074   ///    if( numTerminalNodes != terminalNodes.size() ) throw std::runtime_error();
0075 
0076   return *this;
0077 }
0078 
0079 std::unique_ptr<Node> Tree::copyFrom(const Node* local_root) {
0080   // end-case
0081   if (!local_root)
0082     return nullptr;
0083 
0084   const Node* lr = local_root;
0085 
0086   // recursion
0087   auto left_new_child = copyFrom(lr->getLeftDaughter());
0088   auto right_new_child = copyFrom(lr->getRightDaughter());
0089 
0090   // performing main work at this level
0091   auto new_local_root = std::make_unique<Node>(lr->getName());
0092   if (left_new_child)
0093     left_new_child->setParent(new_local_root.get());
0094   if (right_new_child)
0095     right_new_child->setParent(new_local_root.get());
0096   new_local_root->setLeftDaughter(std::move(left_new_child));
0097   new_local_root->setRightDaughter(std::move(right_new_child));
0098   new_local_root->setErrorReduction(lr->getErrorReduction());
0099   new_local_root->setSplitValue(lr->getSplitValue());
0100   new_local_root->setSplitVariable(lr->getSplitVariable());
0101   new_local_root->setFitValue(lr->getFitValue());
0102   new_local_root->setTotalError(lr->getTotalError());
0103   new_local_root->setAvgError(lr->getAvgError());
0104   new_local_root->setNumEvents(lr->getNumEvents());
0105   //    new_local_root->setEvents( lr->getEvents() ); // no ownership assumed for the events anyways
0106 
0107   return new_local_root;
0108 }
0109 
0110 void Tree::findLeafs(Node* local_root, std::list<Node*>& tn) {
0111   if (!local_root->getLeftDaughter() && !local_root->getRightDaughter()) {
0112     // leaf or ternimal node found
0113     tn.push_back(local_root);
0114     return;
0115   }
0116 
0117   if (local_root->getLeftDaughter())
0118     findLeafs(local_root->getLeftDaughter(), tn);
0119 
0120   if (local_root->getRightDaughter())
0121     findLeafs(local_root->getRightDaughter(), tn);
0122 }
0123 
0124 //////////////////////////////////////////////////////////////////////////
0125 // ______________________Get/Set________________________________________//
0126 //////////////////////////////////////////////////////////////////////////
0127 
0128 void Tree::setRootNode(Node* sRootNode) { rootNode.reset(sRootNode); }
0129 
0130 Node* Tree::getRootNode() { return rootNode.get(); }
0131 const Node* Tree::getRootNode() const { return rootNode.get(); }
0132 
0133 // ----------------------------------------------------------------------
0134 
0135 void Tree::setTerminalNodes(std::list<Node*>& sTNodes) { terminalNodes = sTNodes; }
0136 
0137 std::list<Node*>& Tree::getTerminalNodes() { return terminalNodes; }
0138 
0139 // ----------------------------------------------------------------------
0140 
0141 int Tree::getNumTerminalNodes() { return numTerminalNodes; }
0142 
0143 //////////////////////////////////////////////////////////////////////////
// ______________________Performance____________________________________//
0145 //////////////////////////////////////////////////////////////////////////
0146 
0147 void Tree::calcError() {
0148   // Loop through the separate predictive regions (terminal nodes) and
0149   // add up the errors to get the error of the entire space.
0150 
0151   double totalSquaredError = 0;
0152 
0153   for (std::list<Node*>::iterator it = terminalNodes.begin(); it != terminalNodes.end(); it++) {
0154     totalSquaredError += (*it)->getTotalError();
0155   }
0156   rmsError = sqrt(totalSquaredError / rootNode->getNumEvents());
0157 }
0158 
0159 // ----------------------------------------------------------------------
0160 
0161 void Tree::buildTree(int nodeLimit) {
0162   // We greedily pick the best terminal node to split.
0163   double bestNodeErrorReduction = -1;
0164   Node* nodeToSplit = nullptr;
0165 
0166   if (numTerminalNodes == 1) {
0167     rootNode->calcOptimumSplit();
0168     calcError();
0169     //        std::cout << std::endl << "  " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
0170   }
0171 
0172   for (std::list<Node*>::iterator it = terminalNodes.begin(); it != terminalNodes.end(); it++) {
0173     if ((*it)->getErrorReduction() > bestNodeErrorReduction) {
0174       bestNodeErrorReduction = (*it)->getErrorReduction();
0175       nodeToSplit = (*it);
0176     }
0177   }
0178 
0179   //std::cout << "nodeToSplit size = " << nodeToSplit->getNumEvents() << std::endl;
0180 
0181   // If all of the nodes have one event we can't add any more nodes and reduce the error.
0182   if (nodeToSplit == nullptr)
0183     return;
0184 
0185   // Create daughter nodes, and link the nodes together appropriately.
0186   nodeToSplit->theMiracleOfChildBirth();
0187 
0188   // Get left and right daughters for reference.
0189   Node* left = nodeToSplit->getLeftDaughter();
0190   Node* right = nodeToSplit->getRightDaughter();
0191 
0192   // Update the list of terminal nodes.
0193   terminalNodes.remove(nodeToSplit);
0194   terminalNodes.push_back(left);
0195   terminalNodes.push_back(right);
0196   numTerminalNodes++;
0197 
0198   // Filter the events from the parent into the daughters.
0199   nodeToSplit->filterEventsToDaughters();
0200 
0201   // Calculate the best splits for the new nodes.
0202   left->calcOptimumSplit();
0203   right->calcOptimumSplit();
0204 
0205   // See if the error reduces as we add more nodes.
0206   calcError();
0207 
0208   if (numTerminalNodes % 1 == 0) {
0209     //        std::cout << "  " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
0210   }
0211 
0212   // Repeat until done.
0213   if (numTerminalNodes < nodeLimit)
0214     buildTree(nodeLimit);
0215 }
0216 
0217 // ----------------------------------------------------------------------
0218 
0219 void Tree::filterEvents(std::vector<Event*>& tEvents) {
0220   // Use trees which have already been built to fit a bunch of events
0221   // given by the tEvents vector.
0222 
0223   // Set the events to be filtered.
0224   rootNode->getEvents() = std::vector<std::vector<Event*>>(1);
0225   rootNode->getEvents()[0] = tEvents;
0226 
0227   // The tree now knows about the events it needs to fit.
0228   // Filter them into a predictive region (terminal node).
0229   filterEventsRecursive(rootNode.get());
0230 }
0231 
0232 // ----------------------------------------------------------------------
0233 
0234 void Tree::filterEventsRecursive(Node* node) {
0235   // Filter the events repeatedly into the daughter nodes until they
0236   // fall into a terminal node.
0237 
0238   Node* left = node->getLeftDaughter();
0239   Node* right = node->getRightDaughter();
0240 
0241   if (left == nullptr || right == nullptr)
0242     return;
0243 
0244   node->filterEventsToDaughters();
0245 
0246   filterEventsRecursive(left);
0247   filterEventsRecursive(right);
0248 }
0249 
0250 // ----------------------------------------------------------------------
0251 
0252 Node* Tree::filterEvent(Event* e) {
0253   // Use trees which have already been built to fit a bunch of events
0254   // given by the tEvents vector.
0255 
0256   // Filter the event into a predictive region (terminal node).
0257   Node* node = filterEventRecursive(rootNode.get(), e);
0258   return node;
0259 }
0260 
0261 // ----------------------------------------------------------------------
0262 
0263 Node* Tree::filterEventRecursive(Node* node, Event* e) {
0264   // Filter the event repeatedly into the daughter nodes until it
0265   // falls into a terminal node.
0266 
0267   Node* nextNode = node->filterEventToDaughter(e);
0268   if (nextNode == nullptr)
0269     return node;
0270 
0271   return filterEventRecursive(nextNode, e);
0272 }
0273 
0274 // ----------------------------------------------------------------------
0275 
0276 void Tree::rankVariablesRecursive(Node* node, std::vector<double>& v) const {
0277   // We recursively go through all of the nodes in the tree and find the
0278   // total error reduction for each variable. The one with the most
0279   // error reduction should be the most important.
0280 
0281   Node* left = node->getLeftDaughter();
0282   Node* right = node->getRightDaughter();
0283 
0284   // Terminal nodes don't contribute to error reduction.
0285   if (left == nullptr || right == nullptr)
0286     return;
0287 
0288   int sv = node->getSplitVariable();
0289   double er = node->getErrorReduction();
0290 
0291   //if(sv == -1)
0292   //{
0293   //std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
0294   //std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
0295   //std::cout << "rankVarRecursive Error Reduction = " << er << std::endl;
0296   //}
0297 
0298   // Add error reduction to the current total for the appropriate
0299   // variable.
0300   v[sv] += er;
0301 
0302   rankVariablesRecursive(left, v);
0303   rankVariablesRecursive(right, v);
0304 }
0305 
0306 // ----------------------------------------------------------------------
0307 
0308 void Tree::rankVariables(std::vector<double>& v) const { rankVariablesRecursive(rootNode.get(), v); }
0309 
0310 // ----------------------------------------------------------------------
0311 
0312 void Tree::getSplitValuesRecursive(Node* node, std::vector<std::vector<double>>& v) const {
0313   // We recursively go through all of the nodes in the tree and find the
0314   // split points used for each split variable.
0315 
0316   Node* left = node->getLeftDaughter();
0317   Node* right = node->getRightDaughter();
0318 
0319   // Terminal nodes don't contribute.
0320   if (left == nullptr || right == nullptr)
0321     return;
0322 
0323   int sv = node->getSplitVariable();
0324   double sp = node->getSplitValue();
0325 
0326   if (sv == -1) {
0327     std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
0328     std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
0329   }
0330 
0331   // Add the split point to the list for the correct split variable.
0332   v[sv].push_back(sp);
0333 
0334   getSplitValuesRecursive(left, v);
0335   getSplitValuesRecursive(right, v);
0336 }
0337 
0338 // ----------------------------------------------------------------------
0339 
0340 void Tree::getSplitValues(std::vector<std::vector<double>>& v) const { getSplitValuesRecursive(rootNode.get(), v); }
0341 
0342 //////////////////////////////////////////////////////////////////////////
0343 // ______________________Storage/Retrieval______________________________//
0344 //////////////////////////////////////////////////////////////////////////
0345 
namespace {
  // Convert a number to its text form for XML attributes.
  // Floating-point values are written with max_digits10 significant digits,
  // so reading the string back with a stream extraction gives exactly the
  // same value. The default stream precision (6) would silently round split
  // and fit values when a tree is saved and reloaded.
  template <typename T>
  std::string numToStr(T num) {
    std::ostringstream out;
    // Precision only affects floating-point output; integers print unchanged.
    out.precision(std::numeric_limits<T>::max_digits10);
    out << num;
    return out.str();
  }
}  // namespace
0356 
0357 // ----------------------------------------------------------------------
0358 
0359 void Tree::addXMLAttributes(TXMLEngine* xml, Node* node, XMLNodePointer_t np) {
0360   // Convert Node members into XML attributes
0361   // and add them to the XMLEngine.
0362   xml->NewAttr(np, nullptr, "splitVar", numToStr(node->getSplitVariable()).c_str());
0363   xml->NewAttr(np, nullptr, "splitVal", numToStr(node->getSplitValue()).c_str());
0364   xml->NewAttr(np, nullptr, "fitVal", numToStr(node->getFitValue()).c_str());
0365 }
0366 
0367 // ----------------------------------------------------------------------
0368 
0369 void Tree::saveToXML(const char* c) {
0370   TXMLEngine* xml = new TXMLEngine();
0371 
0372   // Add the root node.
0373   XMLNodePointer_t root = xml->NewChild(nullptr, nullptr, rootNode->getName().c_str());
0374   addXMLAttributes(xml, rootNode.get(), root);
0375 
0376   // Recursively write the tree to XML.
0377   saveToXMLRecursive(xml, rootNode.get(), root);
0378 
0379   // Make the XML Document.
0380   XMLDocPointer_t xmldoc = xml->NewDoc();
0381   xml->DocSetRootElement(xmldoc, root);
0382 
0383   // Save to file.
0384   xml->SaveDoc(xmldoc, c);
0385 
0386   // Clean up.
0387   xml->FreeDoc(xmldoc);
0388   delete xml;
0389 }
0390 
0391 // ----------------------------------------------------------------------
0392 
0393 void Tree::saveToXMLRecursive(TXMLEngine* xml, Node* node, XMLNodePointer_t np) {
0394   Node* l = node->getLeftDaughter();
0395   Node* r = node->getRightDaughter();
0396 
0397   if (l == nullptr || r == nullptr)
0398     return;
0399 
0400   // Add children to the XMLEngine.
0401   XMLNodePointer_t left = xml->NewChild(np, nullptr, "left");
0402   XMLNodePointer_t right = xml->NewChild(np, nullptr, "right");
0403 
0404   // Add attributes to the children.
0405   addXMLAttributes(xml, l, left);
0406   addXMLAttributes(xml, r, right);
0407 
0408   // Recurse.
0409   saveToXMLRecursive(xml, l, left);
0410   saveToXMLRecursive(xml, r, right);
0411 }
0412 
0413 // ----------------------------------------------------------------------
0414 
0415 void Tree::loadFromXML(const char* filename) {
0416   // First create the engine.
0417   TXMLEngine* xml = new TXMLEngine;
0418 
0419   // Now try to parse xml file.
0420   XMLDocPointer_t xmldoc = xml->ParseFile(filename);
0421   if (xmldoc == nullptr) {
0422     delete xml;
0423     return;
0424   }
0425 
0426   // Get access to main node of the xml file.
0427   XMLNodePointer_t mainnode = xml->DocGetRootElement(xmldoc);
0428 
0429   // the original 2016 pT xmls define the source tree node to be the top-level xml node
0430   // while in 2017 TMVA's xmls every decision tree is wrapped in an extra block specifying boostWeight parameter
0431   // I choose to identify the format by checking the top xml node name that is a "BinaryTree" in 2017
0432   if (std::string("BinaryTree") == xml->GetNodeName(mainnode)) {
0433     XMLAttrPointer_t attr = xml->GetFirstAttr(mainnode);
0434     while (std::string("boostWeight") != xml->GetAttrName(attr)) {
0435       attr = xml->GetNextAttr(attr);
0436     }
0437     boostWeight = (attr ? strtod(xml->GetAttrValue(attr), nullptr) : 0);
0438     // step inside the top-level xml node
0439     mainnode = xml->GetChild(mainnode);
0440     xmlVersion = 2017;
0441   } else {
0442     boostWeight = 0;
0443     xmlVersion = 2016;
0444   }
0445   // Recursively connect nodes together.
0446   loadFromXMLRecursive(xml, mainnode, rootNode.get());
0447 
0448   // Release memory before exit
0449   xml->FreeDoc(xmldoc);
0450   delete xml;
0451 }
0452 
0453 // ----------------------------------------------------------------------
0454 
0455 void Tree::loadFromXMLRecursive(TXMLEngine* xml, XMLNodePointer_t xnode, Node* tnode) {
0456   // Get the split information from xml.
0457   XMLAttrPointer_t attr = xml->GetFirstAttr(xnode);
0458   std::vector<std::string> splitInfo(3);
0459   if (xmlVersion >= 2017) {
0460     for (unsigned int i = 0; i < 10; i++) {
0461       if (std::string("IVar") == xml->GetAttrName(attr)) {
0462         splitInfo[0] = xml->GetAttrValue(attr);
0463       }
0464       if (std::string("Cut") == xml->GetAttrName(attr)) {
0465         splitInfo[1] = xml->GetAttrValue(attr);
0466       }
0467       if (std::string("res") == xml->GetAttrName(attr)) {
0468         splitInfo[2] = xml->GetAttrValue(attr);
0469       }
0470       attr = xml->GetNextAttr(attr);
0471     }
0472   } else {
0473     for (unsigned int i = 0; i < 3; i++) {
0474       splitInfo[i] = xml->GetAttrValue(attr);
0475       attr = xml->GetNextAttr(attr);
0476     }
0477   }
0478 
0479   // Convert strings into numbers.
0480   std::stringstream converter;
0481   int splitVar;
0482   double splitVal;
0483   double fitVal;
0484 
0485   converter << splitInfo[0];
0486   converter >> splitVar;
0487   converter.str("");
0488   converter.clear();
0489 
0490   converter << splitInfo[1];
0491   converter >> splitVal;
0492   converter.str("");
0493   converter.clear();
0494 
0495   converter << splitInfo[2];
0496   converter >> fitVal;
0497   converter.str("");
0498   converter.clear();
0499 
0500   // Store gathered splitInfo into the node object.
0501   tnode->setSplitVariable(splitVar);
0502   tnode->setSplitValue(splitVal);
0503   tnode->setFitValue(fitVal);
0504 
0505   // Get the xml daughters of the current xml node.
0506   XMLNodePointer_t xleft = xml->GetChild(xnode);
0507   XMLNodePointer_t xright = xml->GetNext(xleft);
0508 
0509   // If there are no daughters we are done.
0510   if (xleft == nullptr || xright == nullptr)
0511     return;
0512 
0513   // If there are daughters link the node objects appropriately.
0514   tnode->theMiracleOfChildBirth();
0515   Node* tleft = tnode->getLeftDaughter();
0516   Node* tright = tnode->getRightDaughter();
0517 
0518   // Update the list of terminal nodes.
0519   terminalNodes.remove(tnode);
0520   terminalNodes.push_back(tleft);
0521   terminalNodes.push_back(tright);
0522   numTerminalNodes++;
0523 
0524   loadFromXMLRecursive(xml, xleft, tleft);
0525   loadFromXMLRecursive(xml, xright, tright);
0526 }
0527 
0528 void Tree::loadFromCondPayload(const L1TMuonEndCapForest::DTree& tree) {
0529   // start fresh in case this is not the only call to construct a tree
0530   rootNode = std::make_unique<Node>("root");
0531 
0532   const L1TMuonEndCapForest::DTreeNode& mainnode = tree[0];
0533   loadFromCondPayloadRecursive(tree, mainnode, rootNode.get());
0534 }
0535 
0536 void Tree::loadFromCondPayloadRecursive(const L1TMuonEndCapForest::DTree& tree,
0537                                         const L1TMuonEndCapForest::DTreeNode& node,
0538                                         Node* tnode) {
0539   // Store gathered splitInfo into the node object.
0540   tnode->setSplitVariable(node.splitVar);
0541   tnode->setSplitValue(node.splitVal);
0542   tnode->setFitValue(node.fitVal);
0543 
0544   // If there are no daughters we are done.
0545   if (node.ileft == 0 || node.iright == 0)
0546     return;  // root cannot be anyone's child
0547   if (node.ileft >= tree.size() || node.iright >= tree.size())
0548     return;  // out of range addressing on purpose
0549 
0550   // If there are daughters link the node objects appropriately.
0551   tnode->theMiracleOfChildBirth();
0552   Node* tleft = tnode->getLeftDaughter();
0553   Node* tright = tnode->getRightDaughter();
0554 
0555   // Update the list of terminal nodes.
0556   terminalNodes.remove(tnode);
0557   terminalNodes.push_back(tleft);
0558   terminalNodes.push_back(tright);
0559   numTerminalNodes++;
0560 
0561   loadFromCondPayloadRecursive(tree, tree[node.ileft], tleft);
0562   loadFromCondPayloadRecursive(tree, tree[node.iright], tright);
0563 }