Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:02:11

0001 
0002 #ifndef EGAMMAOBJECTS_GBRTreeD
0003 #define EGAMMAOBJECTS_GBRTreeD
0004 
0005 //////////////////////////////////////////////////////////////////////////
0006 //                                                                      //
0007 // GBRForest                                                            //
0008 //                                                                      //
0009 // A fast minimal implementation of Gradient-Boosted Regression Trees   //
0010 // which has been especially optimized for size on disk and in memory.  //
0011 //                                                                      //
0012 // Designed to be built from TMVA-trained trees, but could also be      //
0013 // generalized to otherwise-trained trees, classification,              //
0014 //  or other boosting methods in the future                             //
0015 //                                                                      //
0016 //  Josh Bendavid - CERN                                                 //
0017 //////////////////////////////////////////////////////////////////////////
0018 
// The decision tree is implemented here as a set of two arrays: one for
// intermediate nodes, containing the variable index and cut value as well
// as the indices of the 'left' and 'right' daughter nodes, and one for
// terminal nodes, which are stored simply as a vector of regression
// responses.  Positive daughter indices indicate further intermediate nodes,
// whereas zero or negative indices indicate terminal nodes, whose position
// in the response vector is the negated index.
0024 
0025 #include "CondFormats/Serialization/interface/Serializable.h"
0026 
0027 #include <vector>
0028 
0029 class GBRTreeD {
0030 public:
0031   GBRTreeD() {}
0032   template <typename InputTreeT>
0033   GBRTreeD(const InputTreeT &tree);
0034 
0035   //double GetResponse(const float* vector) const;
0036   double GetResponse(int termidx) const { return fResponses[termidx]; }
0037   int TerminalIndex(const float *vector) const;
0038 
0039   std::vector<double> &Responses() { return fResponses; }
0040   const std::vector<double> &Responses() const { return fResponses; }
0041 
0042   std::vector<unsigned short> &CutIndices() { return fCutIndices; }
0043   const std::vector<unsigned short> &CutIndices() const { return fCutIndices; }
0044 
0045   std::vector<float> &CutVals() { return fCutVals; }
0046   const std::vector<float> &CutVals() const { return fCutVals; }
0047 
0048   std::vector<int> &LeftIndices() { return fLeftIndices; }
0049   const std::vector<int> &LeftIndices() const { return fLeftIndices; }
0050 
0051   std::vector<int> &RightIndices() { return fRightIndices; }
0052   const std::vector<int> &RightIndices() const { return fRightIndices; }
0053 
0054 protected:
0055   std::vector<unsigned short> fCutIndices;
0056   std::vector<float> fCutVals;
0057   std::vector<int> fLeftIndices;
0058   std::vector<int> fRightIndices;
0059   std::vector<double> fResponses;
0060 
0061   COND_SERIALIZABLE;
0062 };
0063 
0064 //_______________________________________________________________________
0065 inline int GBRTreeD::TerminalIndex(const float *vector) const {
0066   int index = 0;
0067 
0068   unsigned short cutindex = fCutIndices[0];
0069   float cutval = fCutVals[0];
0070 
0071   while (true) {
0072     if (vector[cutindex] > cutval) {
0073       index = fRightIndices[index];
0074     } else {
0075       index = fLeftIndices[index];
0076     }
0077 
0078     if (index > 0) {
0079       cutindex = fCutIndices[index];
0080       cutval = fCutVals[index];
0081     } else {
0082       return (-index);
0083     }
0084   }
0085 }
0086 
0087 //_______________________________________________________________________
0088 template <typename InputTreeT>
0089 GBRTreeD::GBRTreeD(const InputTreeT &tree)
0090     : fCutIndices(tree.CutIndices()),
0091       fCutVals(tree.CutVals()),
0092       fLeftIndices(tree.LeftIndices()),
0093       fRightIndices(tree.RightIndices()),
0094       fResponses(tree.Responses()) {}
0095 
0096 #endif