Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:24:49

0001 #include "LowPtGsfElectronSeedHeavyObjectCache.h"
0002 #include "CommonTools/MVAUtils/interface/GBRForestTools.h"
0003 #include "DataFormats/BeamSpot/interface/BeamSpot.h"
0004 #include "DataFormats/ParticleFlowReco/interface/PFCluster.h"
0005 #include "DataFormats/ParticleFlowReco/interface/PFClusterFwd.h"
0006 #include "DataFormats/ParticleFlowReco/interface/PreId.h"
0007 #include "DataFormats/TrackReco/interface/Track.h"
0008 #include "DataFormats/TrackReco/interface/TrackFwd.h"
0009 #include "FWCore/ParameterSet/interface/FileInPath.h"
0010 #include "RecoEgamma/EgammaElectronProducers/interface/LowPtGsfElectronFeatures.h"
0011 
0012 #include <string>
0013 
0014 namespace lowptgsfeleseed {
0015 
0016   ////////////////////////////////////////////////////////////////////////////////
0017   //
0018   HeavyObjectCache::HeavyObjectCache(const edm::ParameterSet& conf) {
0019     for (auto& name : conf.getParameter<std::vector<std::string> >("ModelNames")) {
0020       names_.push_back(name);
0021     }
0022     for (auto& weights : conf.getParameter<std::vector<std::string> >("ModelWeights")) {
0023       models_.push_back(createGBRForest(edm::FileInPath(weights)));
0024     }
0025     for (auto& thresh : conf.getParameter<std::vector<double> >("ModelThresholds")) {
0026       thresholds_.push_back(thresh);
0027     }
0028     if (names_.size() != models_.size()) {
0029       throw cms::Exception("Incorrect configuration")
0030           << "'ModelNames' size (" << names_.size() << ") != 'ModelWeights' size (" << models_.size() << ").\n";
0031     }
0032     if (models_.size() != thresholds_.size()) {
0033       throw cms::Exception("Incorrect configuration")
0034           << "'ModelWeights' size (" << models_.size() << ") != 'ModelThresholds' size (" << thresholds_.size()
0035           << ").\n";
0036     }
0037   }
0038 
0039   ////////////////////////////////////////////////////////////////////////////////
0040   //
0041   bool HeavyObjectCache::eval(const std::string& name,
0042                               reco::PreId& ecal,
0043                               reco::PreId& hcal,
0044                               double rho,
0045                               const reco::BeamSpot& spot,
0046                               noZS::EcalClusterLazyTools& ecalTools) const {
0047     std::vector<std::string>::const_iterator iter = std::find(names_.begin(), names_.end(), name);
0048     if (iter != names_.end()) {
0049       int index = std::distance(names_.begin(), iter);
0050       std::vector<float> inputs = features(ecal, hcal, rho, spot, ecalTools);
0051       float output = models_.at(index)->GetResponse(inputs.data());
0052       bool pass = output > thresholds_.at(index);
0053       ecal.setMVA(pass, output, index);
0054       return pass;
0055     } else {
0056       throw cms::Exception("Unknown model name")
0057           << "'Name given: '" << name << "'. Check against configuration file.\n";
0058     }
0059   }
0060 
0061 }  // namespace lowptgsfeleseed