Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2021-02-14 14:23:20

0001 //////////////////////////////////////////////////////////////////////////
0002 //                            Node.cxx                                  //
0003 // =====================================================================//
0004 // This is the object implementation of a node, which is the            //
0005 // fundamental unit of a decision tree.                                 //
0006 // References include                                                   //
0007 //    *Elements of Statistical Learning by Hastie,                      //
0008 //     Tibshirani, and Friedman.                                        //
0009 //    *Greedy Function Approximation: A Gradient Boosting Machine.      //
0010 //     Friedman. The Annals of Statistics, Vol. 29, No. 5. Oct 2001.    //
0011 //    *Inductive Learning of Tree-based Regression Models. Luis Torgo.  //
0012 //                                                                      //
0013 //////////////////////////////////////////////////////////////////////////
0014 
0015 ///////////////////////////////////////////////////////////////////////////
0016 // _______________________Includes_______________________________________//
0017 ///////////////////////////////////////////////////////////////////////////
0018 
0019 #include "L1Trigger/L1TMuonEndCap/interface/bdt/Node.h"
0020 
0021 #include "TRandom3.h"
0022 #include "TStopwatch.h"
0023 #include <iostream>
0024 #include <fstream>
0025 
0026 //////////////////////////////////////////////////////////////////////////
0027 // _______________________Constructor(s)________________________________//
0028 //////////////////////////////////////////////////////////////////////////
0029 
0030 using namespace emtf;
0031 
0032 Node::Node() {
0033   name = "";
0034   leftDaughter = nullptr;
0035   rightDaughter = nullptr;
0036   parent = nullptr;
0037   splitValue = -99;
0038   splitVariable = -1;
0039   avgError = -1;
0040   totalError = -1;
0041   errorReduction = -1;
0042 }
0043 
0044 Node::Node(std::string cName) {
0045   name = cName;
0046   leftDaughter = nullptr;
0047   rightDaughter = nullptr;
0048   parent = nullptr;
0049   splitValue = -99;
0050   splitVariable = -1;
0051   avgError = -1;
0052   totalError = -1;
0053   errorReduction = -1;
0054 }
0055 
0056 //////////////////////////////////////////////////////////////////////////
0057 // _______________________Destructor____________________________________//
0058 //////////////////////////////////////////////////////////////////////////
0059 
0060 Node::~Node() {
0061   // Recursively delete all nodes in the tree.
0062   if (leftDaughter)
0063     delete leftDaughter;
0064   if (rightDaughter)
0065     delete rightDaughter;
0066 }
0067 
0068 //////////////////////////////////////////////////////////////////////////
0069 // ______________________Get/Set________________________________________//
0070 //////////////////////////////////////////////////////////////////////////
0071 
0072 void Node::setName(std::string sName) { name = sName; }
0073 
0074 std::string Node::getName() { return name; }
0075 
0076 // ----------------------------------------------------------------------
0077 
0078 void Node::setErrorReduction(double sErrorReduction) { errorReduction = sErrorReduction; }
0079 
0080 double Node::getErrorReduction() { return errorReduction; }
0081 
0082 // ----------------------------------------------------------------------
0083 
0084 void Node::setLeftDaughter(Node* sLeftDaughter) { leftDaughter = sLeftDaughter; }
0085 
0086 Node* Node::getLeftDaughter() { return leftDaughter; }
0087 
0088 void Node::setRightDaughter(Node* sRightDaughter) { rightDaughter = sRightDaughter; }
0089 
0090 Node* Node::getRightDaughter() { return rightDaughter; }
0091 
0092 // ----------------------------------------------------------------------
0093 
0094 void Node::setParent(Node* sParent) { parent = sParent; }
0095 
0096 Node* Node::getParent() { return parent; }
0097 
0098 // ----------------------------------------------------------------------
0099 
0100 void Node::setSplitValue(double sSplitValue) { splitValue = sSplitValue; }
0101 
0102 double Node::getSplitValue() { return splitValue; }
0103 
0104 void Node::setSplitVariable(int sSplitVar) { splitVariable = sSplitVar; }
0105 
0106 int Node::getSplitVariable() { return splitVariable; }
0107 
0108 // ----------------------------------------------------------------------
0109 
0110 void Node::setFitValue(double sFitValue) { fitValue = sFitValue; }
0111 
0112 double Node::getFitValue() { return fitValue; }
0113 
0114 // ----------------------------------------------------------------------
0115 
0116 void Node::setTotalError(double sTotalError) { totalError = sTotalError; }
0117 
0118 double Node::getTotalError() { return totalError; }
0119 
0120 void Node::setAvgError(double sAvgError) { avgError = sAvgError; }
0121 
0122 double Node::getAvgError() { return avgError; }
0123 
0124 // ----------------------------------------------------------------------
0125 
0126 void Node::setNumEvents(int sNumEvents) { numEvents = sNumEvents; }
0127 
0128 int Node::getNumEvents() { return numEvents; }
0129 
0130 // ----------------------------------------------------------------------
0131 
0132 std::vector<std::vector<Event*> >& Node::getEvents() { return events; }
0133 
0134 void Node::setEvents(std::vector<std::vector<Event*> >& sEvents) {
0135   events = sEvents;
0136   numEvents = events[0].size();
0137 }
0138 
0139 ///////////////////////////////////////////////////////////////////////////
// ______________________Performance_Functions__________________________//
0141 //////////////////////////////////////////////////////////////////////////
0142 
0143 void Node::calcOptimumSplit() {
0144   // Determines the split variable and split point which would most reduce the error for the given node (region).
0145   // In the process we calculate the fitValue and Error. The general aglorithm is based upon  Luis Torgo's thesis.
0146   // Check out the reference for a more in depth outline. This part is chapter 3.
0147 
0148   // Intialize some variables.
0149   double bestSplitValue = 0;
0150   int bestSplitVariable = -1;
0151   double bestErrorReduction = -1;
0152 
0153   double SUM = 0;
0154   double SSUM = 0;
0155   numEvents = events[0].size();
0156 
0157   double candidateErrorReduction = -1;
0158 
0159   // Calculate the sum of the target variables and the sum of
0160   // the target variables squared. We use these later.
0161   for (unsigned int i = 0; i < events[0].size(); i++) {
0162     double target = events[0][i]->data[0];
0163     SUM += target;
0164     SSUM += target * target;
0165   }
0166 
0167   unsigned int numVars = events.size();
0168 
0169   // Calculate the best split point for each variable
0170   for (unsigned int variableToCheck = 1; variableToCheck < numVars; variableToCheck++) {
0171     // The sum of the target variables in the left, right nodes
0172     double SUMleft = 0;
0173     double SUMright = SUM;
0174 
0175     // The number of events in the left, right nodes
0176     int nleft = 1;
0177     int nright = events[variableToCheck].size() - 1;
0178 
0179     int candidateSplitVariable = variableToCheck;
0180 
0181     std::vector<Event*>& v = events[variableToCheck];
0182 
0183     // Find the best split point for this variable
0184     for (unsigned int i = 1; i < v.size(); i++) {
0185       // As the candidate split point interates, the number of events in the
0186       // left/right node increases/decreases and SUMleft/right increases/decreases.
0187 
0188       SUMleft = SUMleft + v[i - 1]->data[0];
0189       SUMright = SUMright - v[i - 1]->data[0];
0190 
0191       // No need to check the split point if x on both sides is equal
0192       if (v[i - 1]->data[candidateSplitVariable] < v[i]->data[candidateSplitVariable]) {
0193         // Finding the maximum error reduction for Least Squares boils down to maximizing
0194         // the following statement.
0195         candidateErrorReduction = SUMleft * SUMleft / nleft + SUMright * SUMright / nright - SUM * SUM / numEvents;
0196         //                std::cout << "candidateErrorReduction= " << candidateErrorReduction << std::endl << std::endl;
0197 
0198         // if the new candidate is better than the current best, then we have a new overall best.
0199         if (candidateErrorReduction > bestErrorReduction) {
0200           bestErrorReduction = candidateErrorReduction;
0201           bestSplitValue = (v[i - 1]->data[candidateSplitVariable] + v[i]->data[candidateSplitVariable]) / 2;
0202           bestSplitVariable = candidateSplitVariable;
0203         }
0204       }
0205 
0206       nright = nright - 1;
0207       nleft = nleft + 1;
0208     }
0209   }
0210 
0211   // Store the information gained from our computations.
0212 
0213   // The fit value is the average for least squares.
0214   fitValue = SUM / numEvents;
0215   //    std::cout << "fitValue= " << fitValue << std::endl;
0216 
0217   // n*[ <y^2>-k^2 ]
0218   totalError = SSUM - SUM * SUM / numEvents;
0219   //    std::cout << "totalError= " << totalError << std::endl;
0220 
0221   // [ <y^2>-k^2 ]
0222   avgError = totalError / numEvents;
0223   //    std::cout << "avgError= " << avgError << std::endl;
0224 
0225   errorReduction = bestErrorReduction;
0226   //    std::cout << "errorReduction= " << errorReduction << std::endl;
0227 
0228   splitVariable = bestSplitVariable;
0229   //    std::cout << "splitVariable= " << splitVariable << std::endl;
0230 
0231   splitValue = bestSplitValue;
0232   //    std::cout << "splitValue= " << splitValue << std::endl;
0233 
0234   //if(bestSplitVariable == -1) std::cout << "splitVar = -1. numEvents = " << numEvents << ". errRed = " << errorReduction << std::endl;
0235 }
0236 
0237 // ----------------------------------------------------------------------
0238 
0239 void Node::listEvents() {
0240   std::cout << std::endl << "Listing Events... " << std::endl;
0241 
0242   for (unsigned int i = 0; i < events.size(); i++) {
0243     std::cout << std::endl << "Variable " << i << " vector contents: " << std::endl;
0244     for (unsigned int j = 0; j < events[i].size(); j++) {
0245       events[i][j]->outputEvent();
0246     }
0247     std::cout << std::endl;
0248   }
0249 }
0250 
0251 // ----------------------------------------------------------------------
0252 
0253 void Node::theMiracleOfChildBirth() {
0254   // Create Daughter Nodes
0255   Node* left = new Node(name + " left");
0256   Node* right = new Node(name + " right");
0257 
0258   // Link the Nodes Appropriately
0259   leftDaughter = left;
0260   rightDaughter = right;
0261   left->setParent(this);
0262   right->setParent(this);
0263 }
0264 
0265 // ----------------------------------------------------------------------
0266 
0267 void Node::filterEventsToDaughters() {
0268   // Keeping sorted copies of the event vectors allows us to save on
0269   // computation time. That way we don't have to resort the events
0270   // each time we calculate the splitpoint for a node. We sort them once.
0271   // Every time we split a node, we simply filter them down correctly
0272   // preserving the order. This way we have O(n) efficiency instead
0273   // of O(nlogn) efficiency.
0274 
0275   // Anyways, this function takes events from the parent node
0276   // and filters an event into the left or right daughter
0277   // node depending on whether it is < or > the split point
0278   // for the given split variable.
0279 
0280   unsigned int sv = splitVariable;
0281   double sp = splitValue;
0282 
0283   Node* left = leftDaughter;
0284   Node* right = rightDaughter;
0285 
0286   std::vector<std::vector<Event*> > l(events.size());
0287   std::vector<std::vector<Event*> > r(events.size());
0288 
0289   for (unsigned int i = 0; i < events.size(); i++) {
0290     for (unsigned int j = 0; j < events[i].size(); j++) {
0291       Event* e = events[i][j];
0292       // Prevent out-of-bounds access
0293       if (sv >= e->data.size())
0294         continue;
0295       if (e->data[sv] < sp)
0296         l[i].push_back(e);
0297       if (e->data[sv] > sp)
0298         r[i].push_back(e);
0299     }
0300   }
0301 
0302   events = std::vector<std::vector<Event*> >();
0303 
0304   left->getEvents().swap(l);
0305   right->getEvents().swap(r);
0306 
0307   // Set the number of events in the node.
0308   left->setNumEvents(left->getEvents()[0].size());
0309   right->setNumEvents(right->getEvents()[0].size());
0310 }
0311 
0312 // ----------------------------------------------------------------------
0313 
0314 Node* Node::filterEventToDaughter(Event* e) {
0315   // Anyways, this function takes an event from the parent node
0316   // and filters an event into the left or right daughter
0317   // node depending on whether it is < or > the split point
0318   // for the given split variable.
0319 
0320   unsigned int sv = splitVariable;
0321   double sp = splitValue;
0322 
0323   Node* left = leftDaughter;
0324   Node* right = rightDaughter;
0325   Node* nextNode = nullptr;
0326 
0327   // Prevent out-of-bounds access
0328   if (left == nullptr || right == nullptr || sv >= e->data.size())
0329     return nullptr;
0330 
0331   if (e->data[sv] < sp)
0332     nextNode = left;
0333   if (e->data[sv] >= sp)
0334     nextNode = right;
0335 
0336   return nextNode;
0337 }