GBRForestD

Macros

Line Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71

#ifndef EGAMMAOBJECTS_GBRForestD
#define EGAMMAOBJECTS_GBRForestD

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// GBRForestD                                                           //
//                                                                      //
// A fast minimal implementation of Gradient-Boosted Regression Trees   //
// which has been especially optimized for size on disk and in memory.  //
//                                                                      //
// Designed to be built from the output of GBRLikelihood,               //
// but could also be  generalized to otherwise-trained trees            //
// classification, or other boosting methods in the future              //
//                                                                      //
//  Josh Bendavid - CERN                                                //
//////////////////////////////////////////////////////////////////////////

#include "CondFormats/Serialization/interface/Serializable.h"

#include "GBRTreeD.h"

#include <vector>

class GBRForestD {
public:
  typedef GBRTreeD TreeT;

  GBRForestD() {}
  template <typename InputForestT>
  GBRForestD(const InputForestT &forest);

  double GetResponse(const float *vector) const;

  double InitialResponse() const { return fInitialResponse; }
  void SetInitialResponse(double response) { fInitialResponse = response; }

  std::vector<GBRTreeD> &Trees() { return fTrees; }
  const std::vector<GBRTreeD> &Trees() const { return fTrees; }

protected:
  double fInitialResponse = 0.0;
  std::vector<GBRTreeD> fTrees;

  COND_SERIALIZABLE;
};

//_______________________________________________________________________
inline double GBRForestD::GetResponse(const float *vector) const {
  double response = fInitialResponse;
  for (std::vector<GBRTreeD>::const_iterator it = fTrees.begin(); it != fTrees.end(); ++it) {
    int termidx = it->TerminalIndex(vector);
    response += it->GetResponse(termidx);
  }
  return response;
}

//_______________________________________________________________________
template <typename InputForestT>
GBRForestD::GBRForestD(const InputForestT &forest) : fInitialResponse(forest.InitialResponse()) {
  //templated constructor to allow construction from Forest classes in GBRLikelihood
  //without creating an explicit dependency

  for (typename std::vector<typename InputForestT::TreeT>::const_iterator treeit = forest.Trees().begin();
       treeit != forest.Trees().end();
       ++treeit) {
    fTrees.push_back(GBRTreeD(*treeit));
  }
}

#endif