1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
|
#ifndef EGAMMAOBJECTS_GBRForest
#define EGAMMAOBJECTS_GBRForest
//////////////////////////////////////////////////////////////////////////
// //
// GBRForest //
// //
// A fast minimal implementation of Gradient-Boosted Regression Trees //
// which has been especially optimized for size on disk and in memory. //
// //
// Designed to be built from TMVA-trained trees, but could also be //
// generalized to otherwise-trained trees, classification, //
// or other boosting methods in the future. //
// //
// Josh Bendavid - MIT //
//////////////////////////////////////////////////////////////////////////
#include "CondFormats/Serialization/interface/Serializable.h"
#include "CondFormats/GBRForest/interface/GBRTree.h"
#include <cmath>
#include <vector>
/// Additive ensemble of GBRTree objects plus a constant offset.
/// Evaluation starts from the offset (fInitialResponse) and adds the
/// response of every tree in fTrees.
class GBRForest {
public:
  GBRForest() {}

  /// Mutable access to the tree ensemble.
  std::vector<GBRTree>& Trees() { return fTrees; }
  /// Read-only access to the tree ensemble.
  const std::vector<GBRTree>& Trees() const { return fTrees; }

  /// Sets the constant term that GetResponse() starts accumulating from.
  void SetInitialResponse(double response) { fInitialResponse = response; }

  /// Raw ensemble output: the initial response plus the sum of each tree's
  /// response for the given input feature array.
  double GetResponse(const float* vector) const;

  /// Raw response squashed into the range [-1, 1] (defined out of line).
  double GetGradBoostClassifier(const float* vector) const;

  /// AdaBoost-style output; no transformation, identical to GetResponse().
  double GetAdaBoostClassifier(const float* vector) const { return GetResponse(vector); }

  /// Backwards-compatible alias that forwards to GetGradBoostClassifier().
  double GetClassifier(const float* vector) const { return GetGradBoostClassifier(vector); }

protected:
  // NOTE(review): keep the declaration order of these data members unchanged;
  // the serialization support attached via COND_SERIALIZABLE presumably
  // relies on the member layout — confirm before reordering.
  double fInitialResponse = 0.0;
  std::vector<GBRTree> fTrees;

  COND_SERIALIZABLE;
};
//_______________________________________________________________________
inline double GBRForest::GetResponse(const float* vector) const {
double response = fInitialResponse;
for (auto const& tree : fTrees) {
response += tree.GetResponse(vector);
}
return response;
}
//_______________________________________________________________________
// Gradient-boost classifier output: the raw forest response r mapped
// monotonically into [-1, 1].
//
// The historical formula 2 / (1 + exp(-2r)) - 1 is algebraically equal to
// tanh(r). Evaluating it literally loses precision for |r| near 0 (the
// final "- 1" cancels almost all significant digits), which is exactly the
// region around the decision boundary. std::tanh computes the same value
// with full relative precision and saturates identically to +/-1 for
// large |r| (including +/-infinity); NaN propagates as before.
inline double GBRForest::GetGradBoostClassifier(const float* vector) const {
  const double response = GetResponse(vector);
  return std::tanh(response);  //MVA output between -1 and 1
}
#endif
|