Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2025-03-23 23:40:23

0001 // Forest.h
0002 
0003 #ifndef L1Trigger_L1TMuonEndCap_emtf_Forest
0004 #define L1Trigger_L1TMuonEndCap_emtf_Forest
0005 
0006 #include <memory>
0007 #include "Tree.h"
0008 #include "LossFunctions.h"
0009 #include "CondFormats/L1TObjects/interface/L1TMuonEndCapForest.h"
0010 
0011 namespace emtf {
0012 
0013   class Forest {
0014   public:
0015     // Constructor(s)/Destructor
0016     Forest();
0017     Forest(std::vector<Event*>& trainingEvents);
0018     ~Forest() = default;
0019 
0020     Forest(const Forest& forest);
0021     Forest& operator=(const Forest& forest);
0022     Forest(Forest&& forest) = default;
0023     Forest& operator=(Forest&& forest) = default;
0024 
0025     // Get/Set
0026     void setTrainingEvents(std::vector<Event*>& trainingEvents);
0027     std::vector<Event*> getTrainingEvents();
0028 
0029     // Returns the number of trees in the forest.
0030     unsigned int size() const;
0031 
0032     // Get info on variable importance.
0033     void rankVariables(std::vector<int>& rank) const;
0034 
0035     // Output the list of split values used for each variable.
0036     void saveSplitValues(const char* savefilename) const;
0037 
0038     // Helpful operations
0039     void listEvents(std::vector<std::vector<Event*>>& e) const;
0040     void sortEventVectors(std::vector<std::vector<Event*>>& e) const;
0041     void generate(int numTrainEvents, int numTestEvents, double sigma);
0042     void loadForestFromXML(const char* directory, unsigned int numTrees);
0043     void loadFromCondPayload(const L1TMuonEndCapForest::DForest& payload);
0044 
0045     // Perform the regression
0046     void updateRegTargets(Tree* tree, double learningRate, LossFunction* l);
0047     void doRegression(int nodeLimit,
0048                       int treeLimit,
0049                       double learningRate,
0050                       LossFunction* l,
0051                       const char* savetreesdirectory,
0052                       bool saveTrees);
0053 
0054     // Stochastic Gradient Boosting
0055     void prepareRandomSubsample(double fraction);
0056     void doStochasticRegression(int nodeLimit, int treeLimit, double learningRate, double fraction, LossFunction* l);
0057 
0058     // Predict some events
0059     void updateEvents(Tree* tree);
0060     void appendCorrection(std::vector<Event*>& eventsp, int treenum);
0061     void predictEvents(std::vector<Event*>& eventsp, unsigned int trees);
0062     void appendCorrection(Event* e, int treenum);
0063     void predictEvent(Event* e, unsigned int trees);
0064 
0065     Tree* getTree(unsigned int i);
0066 
0067   private:
0068     std::vector<std::vector<Event*>> events;
0069     std::vector<std::vector<Event*>> subSample;
0070     std::vector<std::unique_ptr<Tree>> trees;
0071   };
0072 
0073 }  // namespace emtf
0074 
0075 #endif