GBRForest

Macros

Line Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
#ifndef EGAMMAOBJECTS_GBRForest
#define EGAMMAOBJECTS_GBRForest

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// GBRForest                                                            //
//                                                                      //
// A fast minimal implementation of Gradient-Boosted Regression Trees   //
// which has been especially optimized for size on disk and in memory.  //
//                                                                      //
// Designed to be built from TMVA-trained trees, but could also be      //
// generalized in the future to trees trained by other means, to        //
// classification, or to other boosting methods.                        //
//                                                                      //
//  Josh Bendavid - MIT                                                 //
//////////////////////////////////////////////////////////////////////////

#include "CondFormats/Serialization/interface/Serializable.h"
#include "CondFormats/GBRForest/interface/GBRTree.h"

#include <cmath>
#include <vector>

class GBRForest {
public:
  // Constructs an empty forest: no trees and an initial response of 0.
  GBRForest() {}

  // ---- Evaluation ------------------------------------------------------
  // Every evaluator reads the input features through `vector`.
  // NOTE(review): the required length/ordering of `vector` is defined by the
  // trees (see GBRTree), not by this class — confirm against the producer.

  // Raw additive response: the initial response plus the response of every
  // tree in Trees().
  double GetResponse(const float* vector) const;

  // Gradient-boost classifier output: the raw response squashed into the
  // range [-1, 1].
  double GetGradBoostClassifier(const float* vector) const;

  // AdaBoost classifier output; this is exactly the raw GetResponse().
  double GetAdaBoostClassifier(const float* vector) const {
    return GetResponse(vector);
  }

  // Retained for backwards compatibility; forwards to GetGradBoostClassifier().
  double GetClassifier(const float* vector) const {
    return GetGradBoostClassifier(vector);
  }

  // ---- Configuration / access -------------------------------------------
  // Sets the constant offset that the per-tree responses are added to.
  void SetInitialResponse(double response) {
    fInitialResponse = response;
  }

  // Mutable access to the tree ensemble.
  std::vector<GBRTree>& Trees() {
    return fTrees;
  }

  // Read-only access to the tree ensemble.
  const std::vector<GBRTree>& Trees() const {
    return fTrees;
  }

protected:
  double fInitialResponse = 0.0;  // constant offset of the summed response
  std::vector<GBRTree> fTrees;    // ensemble whose responses are summed

  COND_SERIALIZABLE;
};

//_______________________________________________________________________
inline double GBRForest::GetResponse(const float* vector) const {
  double response = fInitialResponse;
  for (auto const& tree : fTrees) {
    response += tree.GetResponse(vector);
  }
  return response;
}

//_______________________________________________________________________
inline double GBRForest::GetGradBoostClassifier(const float* vector) const {
  double response = GetResponse(vector);
  return 2.0 / (1.0 + std::exp(-2.0 * response)) - 1;  //MVA output between -1 and 1
}

#endif