File indexing completed on 2024-04-06 12:02:11
0001 #ifndef EGAMMAOBJECTS_GBRForest
0002 #define EGAMMAOBJECTS_GBRForest
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 #include "CondFormats/Serialization/interface/Serializable.h"
0019 #include "CondFormats/GBRForest/interface/GBRTree.h"
0020
0021 #include <cmath>
0022 #include <vector>
0023
0024 class GBRForest {
0025 public:
0026 GBRForest() {}
0027
0028 double GetResponse(const float* vector) const;
0029 double GetGradBoostClassifier(const float* vector) const;
0030 double GetAdaBoostClassifier(const float* vector) const { return GetResponse(vector); }
0031
0032
0033 double GetClassifier(const float* vector) const { return GetGradBoostClassifier(vector); }
0034
0035 void SetInitialResponse(double response) { fInitialResponse = response; }
0036
0037 std::vector<GBRTree>& Trees() { return fTrees; }
0038 const std::vector<GBRTree>& Trees() const { return fTrees; }
0039
0040 protected:
0041 double fInitialResponse = 0.0;
0042 std::vector<GBRTree> fTrees;
0043
0044 COND_SERIALIZABLE;
0045 };
0046
0047
0048 inline double GBRForest::GetResponse(const float* vector) const {
0049 double response = fInitialResponse;
0050 for (auto const& tree : fTrees) {
0051 response += tree.GetResponse(vector);
0052 }
0053 return response;
0054 }
0055
0056
0057 inline double GBRForest::GetGradBoostClassifier(const float* vector) const {
0058 double response = GetResponse(vector);
0059 return 2.0 / (1.0 + std::exp(-2.0 * response)) - 1;
0060 }
0061
0062 #endif