Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:02:11

0001 
0002 #ifndef EGAMMAOBJECTS_GBRTree2D
0003 #define EGAMMAOBJECTS_GBRTree2D
0004 
0005 //////////////////////////////////////////////////////////////////////////
0006 //                                                                      //
0007 // GBRForest                                                            //
0008 //                                                                      //
0009 // A fast minimal implementation of Gradient-Boosted Regression Trees   //
0010 // which has been especially optimized for size on disk and in memory.  //
0011 //                                                                      //
0012 // Designed to be built from TMVA-trained trees, but could also be      //
0013 // generalized to otherwise-trained trees, classification,              //
0014 //  or other boosting methods in the future                             //
0015 //                                                                      //
0016 //  Josh Bendavid - MIT                                                 //
0017 //////////////////////////////////////////////////////////////////////////
0018 
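// The decision tree is implemented here as two sets of parallel arrays.  The
// first set describes the intermediate nodes: the variable index and cut value,
// as well as the indices of the 'left' and 'right' daughter nodes.  An input is
// sent to the right daughter when vector[variable index] is strictly greater
// than the cut value, and to the left daughter otherwise.  Positive daughter
// indices indicate further intermediate nodes, whereas zero or negative indices
// indicate terminal nodes, found at position -index of the second set: two
// parallel vectors holding the (x, y) regression responses of each terminal node.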
0024 
0025 #include "CondFormats/Serialization/interface/Serializable.h"
0026 
0027 #include <vector>
0028 
0029 class GBRTree2D {
0030 public:
0031   GBRTree2D() {}
0032 
0033   void GetResponse(const float *vector, double &x, double &y) const;
0034   int TerminalIndex(const float *vector) const;
0035 
0036   std::vector<float> &ResponsesX() { return fResponsesX; }
0037   const std::vector<float> &ResponsesX() const { return fResponsesX; }
0038 
0039   std::vector<float> &ResponsesY() { return fResponsesY; }
0040   const std::vector<float> &ResponsesY() const { return fResponsesY; }
0041 
0042   std::vector<unsigned short> &CutIndices() { return fCutIndices; }
0043   const std::vector<unsigned short> &CutIndices() const { return fCutIndices; }
0044 
0045   std::vector<float> &CutVals() { return fCutVals; }
0046   const std::vector<float> &CutVals() const { return fCutVals; }
0047 
0048   std::vector<int> &LeftIndices() { return fLeftIndices; }
0049   const std::vector<int> &LeftIndices() const { return fLeftIndices; }
0050 
0051   std::vector<int> &RightIndices() { return fRightIndices; }
0052   const std::vector<int> &RightIndices() const { return fRightIndices; }
0053 
0054 protected:
0055   std::vector<unsigned short> fCutIndices;
0056   std::vector<float> fCutVals;
0057   std::vector<int> fLeftIndices;
0058   std::vector<int> fRightIndices;
0059   std::vector<float> fResponsesX;
0060   std::vector<float> fResponsesY;
0061 
0062   COND_SERIALIZABLE;
0063 };
0064 
0065 //_______________________________________________________________________
0066 inline void GBRTree2D::GetResponse(const float *vector, double &x, double &y) const {
0067   int index = 0;
0068 
0069   unsigned short cutindex = fCutIndices[0];
0070   float cutval = fCutVals[0];
0071 
0072   while (true) {
0073     if (vector[cutindex] > cutval) {
0074       index = fRightIndices[index];
0075     } else {
0076       index = fLeftIndices[index];
0077     }
0078 
0079     if (index > 0) {
0080       cutindex = fCutIndices[index];
0081       cutval = fCutVals[index];
0082     } else {
0083       x = fResponsesX[-index];
0084       y = fResponsesY[-index];
0085       return;
0086     }
0087   }
0088 }
0089 
0090 //_______________________________________________________________________
0091 inline int GBRTree2D::TerminalIndex(const float *vector) const {
0092   int index = 0;
0093 
0094   unsigned short cutindex = fCutIndices[0];
0095   float cutval = fCutVals[0];
0096 
0097   while (true) {
0098     if (vector[cutindex] > cutval) {
0099       index = fRightIndices[index];
0100     } else {
0101       index = fLeftIndices[index];
0102     }
0103 
0104     if (index > 0) {
0105       cutindex = fCutIndices[index];
0106       cutval = fCutVals[index];
0107     } else {
0108       return (-index);
0109     }
0110   }
0111 }
0112 
0113 #endif