GBRForest2D

Macros

Line Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60

#ifndef EGAMMAOBJECTS_GBRForest2D
#define EGAMMAOBJECTS_GBRForest2D

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// GBRForest2D                                                          //
//                                                                      //
// A fast minimal implementation of Gradient-Boosted Regression Trees   //
// which has been especially optimized for size on disk and in memory.  //
//                                                                      //
// Designed to be built from TMVA-trained trees, but could also be      //
// generalized to otherwise-trained trees, classification,              //
//  or other boosting methods in the future                             //
//                                                                      //
//  Josh Bendavid - MIT                                                 //
//////////////////////////////////////////////////////////////////////////

#include "CondFormats/Serialization/interface/Serializable.h"

#include "GBRTree2D.h"

#include <vector>

// Forest of GBRTree2D regression trees with a two-component (x, y) response.
//
// The response returned by GetResponse() is the stored initial (x, y) offset
// plus the sum of the (x, y) responses of every tree held in the forest.
class GBRForest2D {
public:
  // Creates an empty forest whose initial responses are both 0.0.
  // Defaulted instead of an empty user-provided body: the in-class member
  // initializers below already establish the complete state.
  GBRForest2D() = default;

  // Evaluates the forest for one input.
  //   vector : input values, forwarded unchanged to each tree's GetResponse
  //            (the expected length/layout is defined by GBRTree2D, which is
  //            not visible in this header).
  //   x, y   : outputs; overwritten with the initial response plus the sum of
  //            all per-tree responses.
  void GetResponse(const float *vector, double &x, double &y) const;

  // Sets the constant (x, y) offset that GetResponse() starts from before the
  // per-tree responses are added.
  void SetInitialResponse(double x, double y) {
    fInitialResponseX = x;
    fInitialResponseY = y;
  }

  // Mutable access to the tree container.
  std::vector<GBRTree2D> &Trees() { return fTrees; }
  // Read-only access to the tree container.
  const std::vector<GBRTree2D> &Trees() const { return fTrees; }

protected:
  // NOTE(review): these members are presumably picked up by the serialization
  // machinery behind COND_SERIALIZABLE - confirm before renaming or
  // reordering them.
  double fInitialResponseX = 0.0;  // X offset added to every response
  double fInitialResponseY = 0.0;  // Y offset added to every response
  std::vector<GBRTree2D> fTrees;   // trees whose responses are summed

  // Macro provided by CondFormats/Serialization/interface/Serializable.h.
  COND_SERIALIZABLE;
};

//_______________________________________________________________________
inline void GBRForest2D::GetResponse(const float *vector, double &x, double &y) const {
  x = fInitialResponseX;
  y = fInitialResponseY;
  double tx, ty;
  for (std::vector<GBRTree2D>::const_iterator it = fTrees.begin(); it != fTrees.end(); ++it) {
    it->GetResponse(vector, tx, ty);
    x += tx;
    y += ty;
  }
  return;
}

#endif