Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-05-04 04:04:14

0001 
0002 #ifndef EGAMMAOBJECTS_GBRTree
0003 #define EGAMMAOBJECTS_GBRTree
0004 
0005 //////////////////////////////////////////////////////////////////////////
0006 //                                                                      //
0007 // GBRForest                                                            //
0008 //                                                                      //
0009 // A fast minimal implementation of Gradient-Boosted Regression Trees   //
0010 // which has been especially optimized for size on disk and in memory.  //
0011 //                                                                      //
0012 // Designed to be built from TMVA-trained trees, but could also be      //
0013 // generalized to otherwise-trained trees, classification,              //
0014 //  or other boosting methods in the future                             //
0015 //                                                                      //
0016 //  Josh Bendavid - MIT                                                 //
0017 //////////////////////////////////////////////////////////////////////////
0018 
// The decision tree is implemented here as a set of parallel arrays.  For
// intermediate nodes they hold the variable index and cut value, as well as
// the indices of the 'left' and 'right' daughter nodes.  A positive daughter
// index refers to a further intermediate node, whereas an index d <= 0 refers
// to terminal node -d; terminal nodes are stored simply as a vector of
// regression responses.
0024 
0025 #include "CondFormats/Serialization/interface/Serializable.h"
0026 
0027 #include <vector>
0028 
0029 class GBRTree {
0030 public:
0031   GBRTree() {}
0032   explicit GBRTree(int nIntermediate, int nTerminal);
0033 
0034   double GetResponse(const float *vector) const;
0035 
0036   std::vector<float> &Responses() { return fResponses; }
0037   const std::vector<float> &Responses() const { return fResponses; }
0038 
0039   std::vector<unsigned char> &CutIndices() { return fCutIndices; }
0040   const std::vector<unsigned char> &CutIndices() const { return fCutIndices; }
0041 
0042   std::vector<float> &CutVals() { return fCutVals; }
0043   const std::vector<float> &CutVals() const { return fCutVals; }
0044 
0045   std::vector<int> &LeftIndices() { return fLeftIndices; }
0046   const std::vector<int> &LeftIndices() const { return fLeftIndices; }
0047 
0048   std::vector<int> &RightIndices() { return fRightIndices; }
0049   const std::vector<int> &RightIndices() const { return fRightIndices; }
0050 
0051 protected:
0052   std::vector<unsigned char> fCutIndices;
0053   std::vector<float> fCutVals;
0054   std::vector<int> fLeftIndices;
0055   std::vector<int> fRightIndices;
0056   std::vector<float> fResponses;
0057 
0058   COND_SERIALIZABLE;
0059 };
0060 
0061 //_______________________________________________________________________
0062 inline double GBRTree::GetResponse(const float *vector) const {
0063   int index = 0;
0064   do {
0065     auto r = fRightIndices[index];
0066     auto l = fLeftIndices[index];
0067     unsigned int x = vector[fCutIndices[index]] > fCutVals[index] ? ~0 : 0;
0068     index = (x & r) | ((~x) & l);
0069   } while (index > 0);
0070   return fResponses[-index];
0071 }
0072 
0073 #endif