Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:02:11

0001 #ifndef EGAMMAOBJECTS_GBRForest
0002 #define EGAMMAOBJECTS_GBRForest
0003 
0004 //////////////////////////////////////////////////////////////////////////
0005 //                                                                      //
0006 // GBRForest                                                            //
0007 //                                                                      //
0008 // A fast minimal implementation of Gradient-Boosted Regression Trees   //
0009 // which has been especially optimized for size on disk and in memory.  //
0010 //                                                                      //
0011 // Designed to be built from TMVA-trained trees, but could also be      //
0012 // generalized to otherwise-trained trees, classification,              //
0013 //  or other boosting methods in the future                             //
0014 //                                                                      //
0015 //  Josh Bendavid - MIT                                                 //
0016 //////////////////////////////////////////////////////////////////////////
0017 
0018 #include "CondFormats/Serialization/interface/Serializable.h"
0019 #include "CondFormats/GBRForest/interface/GBRTree.h"
0020 
0021 #include <cmath>
0022 #include <vector>
0023 
0024 class GBRForest {
0025 public:
0026   GBRForest() {}
0027 
0028   double GetResponse(const float* vector) const;
0029   double GetGradBoostClassifier(const float* vector) const;
0030   double GetAdaBoostClassifier(const float* vector) const { return GetResponse(vector); }
0031 
0032   //for backwards-compatibility
0033   double GetClassifier(const float* vector) const { return GetGradBoostClassifier(vector); }
0034 
0035   void SetInitialResponse(double response) { fInitialResponse = response; }
0036 
0037   std::vector<GBRTree>& Trees() { return fTrees; }
0038   const std::vector<GBRTree>& Trees() const { return fTrees; }
0039 
0040 protected:
0041   double fInitialResponse = 0.0;
0042   std::vector<GBRTree> fTrees;
0043 
0044   COND_SERIALIZABLE;
0045 };
0046 
0047 //_______________________________________________________________________
0048 inline double GBRForest::GetResponse(const float* vector) const {
0049   double response = fInitialResponse;
0050   for (auto const& tree : fTrees) {
0051     response += tree.GetResponse(vector);
0052   }
0053   return response;
0054 }
0055 
0056 //_______________________________________________________________________
0057 inline double GBRForest::GetGradBoostClassifier(const float* vector) const {
0058   double response = GetResponse(vector);
0059   return 2.0 / (1.0 + std::exp(-2.0 * response)) - 1;  //MVA output between -1 and 1
0060 }
0061 
0062 #endif