Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:20:54

0001 // Forest.h
0002 
0003 #ifndef L1Trigger_L1TMuonEndCap_emtf_Forest
0004 #define L1Trigger_L1TMuonEndCap_emtf_Forest
0005 
0006 #include "Tree.h"
0007 #include "LossFunctions.h"
0008 #include "CondFormats/L1TObjects/interface/L1TMuonEndCapForest.h"
0009 
0010 namespace emtf {
0011 
0012   class Forest {
0013   public:
0014     // Constructor(s)/Destructor
0015     Forest();
0016     Forest(std::vector<Event*>& trainingEvents);
0017     ~Forest();
0018 
0019     Forest(const Forest& forest);
0020     Forest& operator=(const Forest& forest);
0021     Forest(Forest&& forest) = default;
0022 
0023     // Get/Set
0024     void setTrainingEvents(std::vector<Event*>& trainingEvents);
0025     std::vector<Event*> getTrainingEvents();
0026 
0027     // Returns the number of trees in the forest.
0028     unsigned int size();
0029 
0030     // Get info on variable importance.
0031     void rankVariables(std::vector<int>& rank);
0032 
0033     // Output the list of split values used for each variable.
0034     void saveSplitValues(const char* savefilename);
0035 
0036     // Helpful operations
0037     void listEvents(std::vector<std::vector<Event*> >& e);
0038     void sortEventVectors(std::vector<std::vector<Event*> >& e);
0039     void generate(int numTrainEvents, int numTestEvents, double sigma);
0040     void loadForestFromXML(const char* directory, unsigned int numTrees);
0041     void loadFromCondPayload(const L1TMuonEndCapForest::DForest& payload);
0042 
0043     // Perform the regression
0044     void updateRegTargets(Tree* tree, double learningRate, LossFunction* l);
0045     void doRegression(int nodeLimit,
0046                       int treeLimit,
0047                       double learningRate,
0048                       LossFunction* l,
0049                       const char* savetreesdirectory,
0050                       bool saveTrees);
0051 
0052     // Stochastic Gradient Boosting
0053     void prepareRandomSubsample(double fraction);
0054     void doStochasticRegression(int nodeLimit, int treeLimit, double learningRate, double fraction, LossFunction* l);
0055 
0056     // Predict some events
0057     void updateEvents(Tree* tree);
0058     void appendCorrection(std::vector<Event*>& eventsp, int treenum);
0059     void predictEvents(std::vector<Event*>& eventsp, unsigned int trees);
0060     void appendCorrection(Event* e, int treenum);
0061     void predictEvent(Event* e, unsigned int trees);
0062 
0063     Tree* getTree(unsigned int i);
0064 
0065   private:
0066     std::vector<std::vector<Event*> > events;
0067     std::vector<std::vector<Event*> > subSample;
0068     std::vector<Tree*> trees;
0069   };
0070 
0071 }  // namespace emtf
0072 
0073 #endif