Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2022-05-13 01:32:12

0001 //////////////////////////////////////////////////////////////////////////
0002 //                            Tree.cxx                                  //
0003 // =====================================================================//
0004 // This is the object implementation of a decision tree.                //
0005 // References include                                                   //
0006 //    *Elements of Statistical Learning by Hastie,                      //
0007 //     Tibshirani, and Friedman.                                        //
0008 //    *Greedy Function Approximation: A Gradient Boosting Machine.      //
0009 //     Friedman. The Annals of Statistics, Vol. 29, No. 5. Oct 2001.    //
0010 //    *Inductive Learning of Tree-based Regression Models. Luis Torgo.  //
0011 //                                                                      //
0012 //////////////////////////////////////////////////////////////////////////
0013 
0014 ///////////////////////////////////////////////////////////////////////////
0015 // _______________________Includes_______________________________________//
0016 ///////////////////////////////////////////////////////////////////////////
0017 
#include "L1Trigger/L1TMuonEndCap/interface/bdt/Tree.h"

#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
0023 
0024 //////////////////////////////////////////////////////////////////////////
0025 // _______________________Constructor(s)________________________________//
0026 //////////////////////////////////////////////////////////////////////////
0027 
0028 using namespace emtf;
0029 
0030 Tree::Tree() {
0031   rootNode = new Node("root");
0032 
0033   terminalNodes.push_back(rootNode);
0034   numTerminalNodes = 1;
0035   boostWeight = 0;
0036   xmlVersion = 2017;
0037 }
0038 
0039 Tree::Tree(std::vector<std::vector<Event*>>& cEvents) {
0040   rootNode = new Node("root");
0041   rootNode->setEvents(cEvents);
0042 
0043   terminalNodes.push_back(rootNode);
0044   numTerminalNodes = 1;
0045   boostWeight = 0;
0046   xmlVersion = 2017;
0047 }
0048 //////////////////////////////////////////////////////////////////////////
0049 // _______________________Destructor____________________________________//
0050 //////////////////////////////////////////////////////////////////////////
0051 
0052 Tree::~Tree() {
0053   // When the tree is destroyed it will delete all of the nodes in the tree.
0054   // The deletion begins with the rootnode and continues recursively.
0055   if (rootNode)
0056     delete rootNode;
0057 }
0058 
0059 Tree::Tree(const Tree& tree) {
0060   // unfortunately, authors of these classes didn't use const qualifiers
0061   rootNode = copyFrom(const_cast<Tree&>(tree).getRootNode());
0062   numTerminalNodes = tree.numTerminalNodes;
0063   rmsError = tree.rmsError;
0064   boostWeight = tree.boostWeight;
0065   xmlVersion = tree.xmlVersion;
0066 
0067   terminalNodes.resize(0);
0068   // find new leafs
0069   findLeafs(rootNode, terminalNodes);
0070 
0071   ///    if( numTerminalNodes != terminalNodes.size() ) throw std::runtime_error();
0072 }
0073 
0074 Tree& Tree::operator=(const Tree& tree) {
0075   if (rootNode)
0076     delete rootNode;
0077   // unfortunately, authors of these classes didn't use const qualifiers
0078   rootNode = copyFrom(const_cast<Tree&>(tree).getRootNode());
0079   numTerminalNodes = tree.numTerminalNodes;
0080   rmsError = tree.rmsError;
0081   boostWeight = tree.boostWeight;
0082   xmlVersion = tree.xmlVersion;
0083 
0084   terminalNodes.resize(0);
0085   // find new leafs
0086   findLeafs(rootNode, terminalNodes);
0087 
0088   ///    if( numTerminalNodes != terminalNodes.size() ) throw std::runtime_error();
0089 
0090   return *this;
0091 }
0092 
0093 Node* Tree::copyFrom(const Node* local_root) {
0094   // end-case
0095   if (!local_root)
0096     return nullptr;
0097 
0098   Node* lr = const_cast<Node*>(local_root);
0099 
0100   // recursion
0101   Node* left_new_child = copyFrom(lr->getLeftDaughter());
0102   Node* right_new_child = copyFrom(lr->getRightDaughter());
0103 
0104   // performing main work at this level
0105   Node* new_local_root = new Node(lr->getName());
0106   if (left_new_child)
0107     left_new_child->setParent(new_local_root);
0108   if (right_new_child)
0109     right_new_child->setParent(new_local_root);
0110   new_local_root->setLeftDaughter(left_new_child);
0111   new_local_root->setRightDaughter(right_new_child);
0112   new_local_root->setErrorReduction(lr->getErrorReduction());
0113   new_local_root->setSplitValue(lr->getSplitValue());
0114   new_local_root->setSplitVariable(lr->getSplitVariable());
0115   new_local_root->setFitValue(lr->getFitValue());
0116   new_local_root->setTotalError(lr->getTotalError());
0117   new_local_root->setAvgError(lr->getAvgError());
0118   new_local_root->setNumEvents(lr->getNumEvents());
0119   //    new_local_root->setEvents( lr->getEvents() ); // no ownership assumed for the events anyways
0120 
0121   return new_local_root;
0122 }
0123 
0124 void Tree::findLeafs(Node* local_root, std::list<Node*>& tn) {
0125   if (!local_root->getLeftDaughter() && !local_root->getRightDaughter()) {
0126     // leaf or ternimal node found
0127     tn.push_back(local_root);
0128     return;
0129   }
0130 
0131   if (local_root->getLeftDaughter())
0132     findLeafs(local_root->getLeftDaughter(), tn);
0133 
0134   if (local_root->getRightDaughter())
0135     findLeafs(local_root->getRightDaughter(), tn);
0136 }
0137 
0138 Tree::Tree(Tree&& tree) {
0139   if (rootNode)
0140     delete rootNode;  // this line is the only reason not to use default move constructor
0141   rootNode = tree.rootNode;
0142   terminalNodes = std::move(tree.terminalNodes);
0143   numTerminalNodes = tree.numTerminalNodes;
0144   rmsError = tree.rmsError;
0145   boostWeight = tree.boostWeight;
0146   xmlVersion = tree.xmlVersion;
0147 }
0148 
0149 //////////////////////////////////////////////////////////////////////////
0150 // ______________________Get/Set________________________________________//
0151 //////////////////////////////////////////////////////////////////////////
0152 
0153 void Tree::setRootNode(Node* sRootNode) { rootNode = sRootNode; }
0154 
0155 Node* Tree::getRootNode() { return rootNode; }
0156 
0157 // ----------------------------------------------------------------------
0158 
0159 void Tree::setTerminalNodes(std::list<Node*>& sTNodes) { terminalNodes = sTNodes; }
0160 
0161 std::list<Node*>& Tree::getTerminalNodes() { return terminalNodes; }
0162 
0163 // ----------------------------------------------------------------------
0164 
0165 int Tree::getNumTerminalNodes() { return numTerminalNodes; }
0166 
0167 //////////////////////////////////////////////////////////////////////////
// ______________________Performance____________________________________//
0169 //////////////////////////////////////////////////////////////////////////
0170 
0171 void Tree::calcError() {
0172   // Loop through the separate predictive regions (terminal nodes) and
0173   // add up the errors to get the error of the entire space.
0174 
0175   double totalSquaredError = 0;
0176 
0177   for (std::list<Node*>::iterator it = terminalNodes.begin(); it != terminalNodes.end(); it++) {
0178     totalSquaredError += (*it)->getTotalError();
0179   }
0180   rmsError = sqrt(totalSquaredError / rootNode->getNumEvents());
0181 }
0182 
0183 // ----------------------------------------------------------------------
0184 
0185 void Tree::buildTree(int nodeLimit) {
0186   // We greedily pick the best terminal node to split.
0187   double bestNodeErrorReduction = -1;
0188   Node* nodeToSplit = nullptr;
0189 
0190   if (numTerminalNodes == 1) {
0191     rootNode->calcOptimumSplit();
0192     calcError();
0193     //        std::cout << std::endl << "  " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
0194   }
0195 
0196   for (std::list<Node*>::iterator it = terminalNodes.begin(); it != terminalNodes.end(); it++) {
0197     if ((*it)->getErrorReduction() > bestNodeErrorReduction) {
0198       bestNodeErrorReduction = (*it)->getErrorReduction();
0199       nodeToSplit = (*it);
0200     }
0201   }
0202 
0203   //std::cout << "nodeToSplit size = " << nodeToSplit->getNumEvents() << std::endl;
0204 
0205   // If all of the nodes have one event we can't add any more nodes and reduce the error.
0206   if (nodeToSplit == nullptr)
0207     return;
0208 
0209   // Create daughter nodes, and link the nodes together appropriately.
0210   nodeToSplit->theMiracleOfChildBirth();
0211 
0212   // Get left and right daughters for reference.
0213   Node* left = nodeToSplit->getLeftDaughter();
0214   Node* right = nodeToSplit->getRightDaughter();
0215 
0216   // Update the list of terminal nodes.
0217   terminalNodes.remove(nodeToSplit);
0218   terminalNodes.push_back(left);
0219   terminalNodes.push_back(right);
0220   numTerminalNodes++;
0221 
0222   // Filter the events from the parent into the daughters.
0223   nodeToSplit->filterEventsToDaughters();
0224 
0225   // Calculate the best splits for the new nodes.
0226   left->calcOptimumSplit();
0227   right->calcOptimumSplit();
0228 
0229   // See if the error reduces as we add more nodes.
0230   calcError();
0231 
0232   if (numTerminalNodes % 1 == 0) {
0233     //        std::cout << "  " << numTerminalNodes << " Nodes : " << rmsError << std::endl;
0234   }
0235 
0236   // Repeat until done.
0237   if (numTerminalNodes < nodeLimit)
0238     buildTree(nodeLimit);
0239 }
0240 
0241 // ----------------------------------------------------------------------
0242 
0243 void Tree::filterEvents(std::vector<Event*>& tEvents) {
0244   // Use trees which have already been built to fit a bunch of events
0245   // given by the tEvents vector.
0246 
0247   // Set the events to be filtered.
0248   rootNode->getEvents() = std::vector<std::vector<Event*>>(1);
0249   rootNode->getEvents()[0] = tEvents;
0250 
0251   // The tree now knows about the events it needs to fit.
0252   // Filter them into a predictive region (terminal node).
0253   filterEventsRecursive(rootNode);
0254 }
0255 
0256 // ----------------------------------------------------------------------
0257 
0258 void Tree::filterEventsRecursive(Node* node) {
0259   // Filter the events repeatedly into the daughter nodes until they
0260   // fall into a terminal node.
0261 
0262   Node* left = node->getLeftDaughter();
0263   Node* right = node->getRightDaughter();
0264 
0265   if (left == nullptr || right == nullptr)
0266     return;
0267 
0268   node->filterEventsToDaughters();
0269 
0270   filterEventsRecursive(left);
0271   filterEventsRecursive(right);
0272 }
0273 
0274 // ----------------------------------------------------------------------
0275 
0276 Node* Tree::filterEvent(Event* e) {
0277   // Use trees which have already been built to fit a bunch of events
0278   // given by the tEvents vector.
0279 
0280   // Filter the event into a predictive region (terminal node).
0281   Node* node = filterEventRecursive(rootNode, e);
0282   return node;
0283 }
0284 
0285 // ----------------------------------------------------------------------
0286 
0287 Node* Tree::filterEventRecursive(Node* node, Event* e) {
0288   // Filter the event repeatedly into the daughter nodes until it
0289   // falls into a terminal node.
0290 
0291   Node* nextNode = node->filterEventToDaughter(e);
0292   if (nextNode == nullptr)
0293     return node;
0294 
0295   return filterEventRecursive(nextNode, e);
0296 }
0297 
0298 // ----------------------------------------------------------------------
0299 
0300 void Tree::rankVariablesRecursive(Node* node, std::vector<double>& v) {
0301   // We recursively go through all of the nodes in the tree and find the
0302   // total error reduction for each variable. The one with the most
0303   // error reduction should be the most important.
0304 
0305   Node* left = node->getLeftDaughter();
0306   Node* right = node->getRightDaughter();
0307 
0308   // Terminal nodes don't contribute to error reduction.
0309   if (left == nullptr || right == nullptr)
0310     return;
0311 
0312   int sv = node->getSplitVariable();
0313   double er = node->getErrorReduction();
0314 
0315   //if(sv == -1)
0316   //{
0317   //std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
0318   //std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
0319   //std::cout << "rankVarRecursive Error Reduction = " << er << std::endl;
0320   //}
0321 
0322   // Add error reduction to the current total for the appropriate
0323   // variable.
0324   v[sv] += er;
0325 
0326   rankVariablesRecursive(left, v);
0327   rankVariablesRecursive(right, v);
0328 }
0329 
0330 // ----------------------------------------------------------------------
0331 
0332 void Tree::rankVariables(std::vector<double>& v) { rankVariablesRecursive(rootNode, v); }
0333 
0334 // ----------------------------------------------------------------------
0335 
0336 void Tree::getSplitValuesRecursive(Node* node, std::vector<std::vector<double>>& v) {
0337   // We recursively go through all of the nodes in the tree and find the
0338   // split points used for each split variable.
0339 
0340   Node* left = node->getLeftDaughter();
0341   Node* right = node->getRightDaughter();
0342 
0343   // Terminal nodes don't contribute.
0344   if (left == nullptr || right == nullptr)
0345     return;
0346 
0347   int sv = node->getSplitVariable();
0348   double sp = node->getSplitValue();
0349 
0350   if (sv == -1) {
0351     std::cout << "ERROR: negative split variable for nonterminal node." << std::endl;
0352     std::cout << "rankVarRecursive Split Variable = " << sv << std::endl;
0353   }
0354 
0355   // Add the split point to the list for the correct split variable.
0356   v[sv].push_back(sp);
0357 
0358   getSplitValuesRecursive(left, v);
0359   getSplitValuesRecursive(right, v);
0360 }
0361 
0362 // ----------------------------------------------------------------------
0363 
0364 void Tree::getSplitValues(std::vector<std::vector<double>>& v) { getSplitValuesRecursive(rootNode, v); }
0365 
0366 //////////////////////////////////////////////////////////////////////////
0367 // ______________________Storage/Retrieval______________________________//
0368 //////////////////////////////////////////////////////////////////////////
0369 
template <typename T>
std::string numToStr(T num) {
  // Convert a value to its textual form using operator<<.
  // Floating-point values are written with max_digits10 significant digits so
  // that parsing the text back yields exactly the same value. The default stream
  // precision (6) would silently truncate split/fit values written to XML.
  // Integers and non-numeric types are formatted exactly as before.
  std::ostringstream ss;
  if (std::numeric_limits<T>::is_specialized && !std::numeric_limits<T>::is_integer)
    ss.precision(std::numeric_limits<T>::max_digits10);
  ss << num;
  return ss.str();
}
0378 
0379 // ----------------------------------------------------------------------
0380 
0381 void Tree::addXMLAttributes(TXMLEngine* xml, Node* node, XMLNodePointer_t np) {
0382   // Convert Node members into XML attributes
0383   // and add them to the XMLEngine.
0384   xml->NewAttr(np, nullptr, "splitVar", numToStr(node->getSplitVariable()).c_str());
0385   xml->NewAttr(np, nullptr, "splitVal", numToStr(node->getSplitValue()).c_str());
0386   xml->NewAttr(np, nullptr, "fitVal", numToStr(node->getFitValue()).c_str());
0387 }
0388 
0389 // ----------------------------------------------------------------------
0390 
0391 void Tree::saveToXML(const char* c) {
0392   TXMLEngine* xml = new TXMLEngine();
0393 
0394   // Add the root node.
0395   XMLNodePointer_t root = xml->NewChild(nullptr, nullptr, rootNode->getName().c_str());
0396   addXMLAttributes(xml, rootNode, root);
0397 
0398   // Recursively write the tree to XML.
0399   saveToXMLRecursive(xml, rootNode, root);
0400 
0401   // Make the XML Document.
0402   XMLDocPointer_t xmldoc = xml->NewDoc();
0403   xml->DocSetRootElement(xmldoc, root);
0404 
0405   // Save to file.
0406   xml->SaveDoc(xmldoc, c);
0407 
0408   // Clean up.
0409   xml->FreeDoc(xmldoc);
0410   delete xml;
0411 }
0412 
0413 // ----------------------------------------------------------------------
0414 
0415 void Tree::saveToXMLRecursive(TXMLEngine* xml, Node* node, XMLNodePointer_t np) {
0416   Node* l = node->getLeftDaughter();
0417   Node* r = node->getRightDaughter();
0418 
0419   if (l == nullptr || r == nullptr)
0420     return;
0421 
0422   // Add children to the XMLEngine.
0423   XMLNodePointer_t left = xml->NewChild(np, nullptr, "left");
0424   XMLNodePointer_t right = xml->NewChild(np, nullptr, "right");
0425 
0426   // Add attributes to the children.
0427   addXMLAttributes(xml, l, left);
0428   addXMLAttributes(xml, r, right);
0429 
0430   // Recurse.
0431   saveToXMLRecursive(xml, l, left);
0432   saveToXMLRecursive(xml, r, right);
0433 }
0434 
0435 // ----------------------------------------------------------------------
0436 
0437 void Tree::loadFromXML(const char* filename) {
0438   // First create the engine.
0439   TXMLEngine* xml = new TXMLEngine;
0440 
0441   // Now try to parse xml file.
0442   XMLDocPointer_t xmldoc = xml->ParseFile(filename);
0443   if (xmldoc == nullptr) {
0444     delete xml;
0445     return;
0446   }
0447 
0448   // Get access to main node of the xml file.
0449   XMLNodePointer_t mainnode = xml->DocGetRootElement(xmldoc);
0450 
0451   // the original 2016 pT xmls define the source tree node to be the top-level xml node
0452   // while in 2017 TMVA's xmls every decision tree is wrapped in an extra block specifying boostWeight parameter
0453   // I choose to identify the format by checking the top xml node name that is a "BinaryTree" in 2017
0454   if (std::string("BinaryTree") == xml->GetNodeName(mainnode)) {
0455     XMLAttrPointer_t attr = xml->GetFirstAttr(mainnode);
0456     while (std::string("boostWeight") != xml->GetAttrName(attr)) {
0457       attr = xml->GetNextAttr(attr);
0458     }
0459     boostWeight = (attr ? strtod(xml->GetAttrValue(attr), nullptr) : 0);
0460     // step inside the top-level xml node
0461     mainnode = xml->GetChild(mainnode);
0462     xmlVersion = 2017;
0463   } else {
0464     boostWeight = 0;
0465     xmlVersion = 2016;
0466   }
0467   // Recursively connect nodes together.
0468   loadFromXMLRecursive(xml, mainnode, rootNode);
0469 
0470   // Release memory before exit
0471   xml->FreeDoc(xmldoc);
0472   delete xml;
0473 }
0474 
0475 // ----------------------------------------------------------------------
0476 
0477 void Tree::loadFromXMLRecursive(TXMLEngine* xml, XMLNodePointer_t xnode, Node* tnode) {
0478   // Get the split information from xml.
0479   XMLAttrPointer_t attr = xml->GetFirstAttr(xnode);
0480   std::vector<std::string> splitInfo(3);
0481   if (xmlVersion >= 2017) {
0482     for (unsigned int i = 0; i < 10; i++) {
0483       if (std::string("IVar") == xml->GetAttrName(attr)) {
0484         splitInfo[0] = xml->GetAttrValue(attr);
0485       }
0486       if (std::string("Cut") == xml->GetAttrName(attr)) {
0487         splitInfo[1] = xml->GetAttrValue(attr);
0488       }
0489       if (std::string("res") == xml->GetAttrName(attr)) {
0490         splitInfo[2] = xml->GetAttrValue(attr);
0491       }
0492       attr = xml->GetNextAttr(attr);
0493     }
0494   } else {
0495     for (unsigned int i = 0; i < 3; i++) {
0496       splitInfo[i] = xml->GetAttrValue(attr);
0497       attr = xml->GetNextAttr(attr);
0498     }
0499   }
0500 
0501   // Convert strings into numbers.
0502   std::stringstream converter;
0503   int splitVar;
0504   double splitVal;
0505   double fitVal;
0506 
0507   converter << splitInfo[0];
0508   converter >> splitVar;
0509   converter.str("");
0510   converter.clear();
0511 
0512   converter << splitInfo[1];
0513   converter >> splitVal;
0514   converter.str("");
0515   converter.clear();
0516 
0517   converter << splitInfo[2];
0518   converter >> fitVal;
0519   converter.str("");
0520   converter.clear();
0521 
0522   // Store gathered splitInfo into the node object.
0523   tnode->setSplitVariable(splitVar);
0524   tnode->setSplitValue(splitVal);
0525   tnode->setFitValue(fitVal);
0526 
0527   // Get the xml daughters of the current xml node.
0528   XMLNodePointer_t xleft = xml->GetChild(xnode);
0529   XMLNodePointer_t xright = xml->GetNext(xleft);
0530 
0531   // If there are no daughters we are done.
0532   if (xleft == nullptr || xright == nullptr)
0533     return;
0534 
0535   // If there are daughters link the node objects appropriately.
0536   tnode->theMiracleOfChildBirth();
0537   Node* tleft = tnode->getLeftDaughter();
0538   Node* tright = tnode->getRightDaughter();
0539 
0540   // Update the list of terminal nodes.
0541   terminalNodes.remove(tnode);
0542   terminalNodes.push_back(tleft);
0543   terminalNodes.push_back(tright);
0544   numTerminalNodes++;
0545 
0546   loadFromXMLRecursive(xml, xleft, tleft);
0547   loadFromXMLRecursive(xml, xright, tright);
0548 }
0549 
0550 void Tree::loadFromCondPayload(const L1TMuonEndCapForest::DTree& tree) {
0551   // start fresh in case this is not the only call to construct a tree
0552   if (rootNode)
0553     delete rootNode;
0554   rootNode = new Node("root");
0555 
0556   const L1TMuonEndCapForest::DTreeNode& mainnode = tree[0];
0557   loadFromCondPayloadRecursive(tree, mainnode, rootNode);
0558 }
0559 
0560 void Tree::loadFromCondPayloadRecursive(const L1TMuonEndCapForest::DTree& tree,
0561                                         const L1TMuonEndCapForest::DTreeNode& node,
0562                                         Node* tnode) {
0563   // Store gathered splitInfo into the node object.
0564   tnode->setSplitVariable(node.splitVar);
0565   tnode->setSplitValue(node.splitVal);
0566   tnode->setFitValue(node.fitVal);
0567 
0568   // If there are no daughters we are done.
0569   if (node.ileft == 0 || node.iright == 0)
0570     return;  // root cannot be anyone's child
0571   if (node.ileft >= tree.size() || node.iright >= tree.size())
0572     return;  // out of range addressing on purpose
0573 
0574   // If there are daughters link the node objects appropriately.
0575   tnode->theMiracleOfChildBirth();
0576   Node* tleft = tnode->getLeftDaughter();
0577   Node* tright = tnode->getRightDaughter();
0578 
0579   // Update the list of terminal nodes.
0580   terminalNodes.remove(tnode);
0581   terminalNodes.push_back(tleft);
0582   terminalNodes.push_back(tright);
0583   numTerminalNodes++;
0584 
0585   loadFromCondPayloadRecursive(tree, tree[node.ileft], tleft);
0586   loadFromCondPayloadRecursive(tree, tree[node.iright], tright);
0587 }