File indexing completed on 2024-04-06 12:24:49
#include "LowPtGsfElectronSeedHeavyObjectCache.h"
#include "CommonTools/MVAUtils/interface/GBRForestTools.h"
#include "DataFormats/BeamSpot/interface/BeamSpot.h"
#include "DataFormats/ParticleFlowReco/interface/PFCluster.h"
#include "DataFormats/ParticleFlowReco/interface/PFClusterFwd.h"
#include "DataFormats/ParticleFlowReco/interface/PreId.h"
#include "DataFormats/TrackReco/interface/Track.h"
#include "DataFormats/TrackReco/interface/TrackFwd.h"
#include "FWCore/ParameterSet/interface/FileInPath.h"
#include "RecoEgamma/EgammaElectronProducers/interface/LowPtGsfElectronFeatures.h"

#include <algorithm>
#include <iterator>
#include <string>
#include <vector>
0013
0014 namespace lowptgsfeleseed {
0015
0016
0017
0018 HeavyObjectCache::HeavyObjectCache(const edm::ParameterSet& conf) {
0019 for (auto& name : conf.getParameter<std::vector<std::string> >("ModelNames")) {
0020 names_.push_back(name);
0021 }
0022 for (auto& weights : conf.getParameter<std::vector<std::string> >("ModelWeights")) {
0023 models_.push_back(createGBRForest(edm::FileInPath(weights)));
0024 }
0025 for (auto& thresh : conf.getParameter<std::vector<double> >("ModelThresholds")) {
0026 thresholds_.push_back(thresh);
0027 }
0028 if (names_.size() != models_.size()) {
0029 throw cms::Exception("Incorrect configuration")
0030 << "'ModelNames' size (" << names_.size() << ") != 'ModelWeights' size (" << models_.size() << ").\n";
0031 }
0032 if (models_.size() != thresholds_.size()) {
0033 throw cms::Exception("Incorrect configuration")
0034 << "'ModelWeights' size (" << models_.size() << ") != 'ModelThresholds' size (" << thresholds_.size()
0035 << ").\n";
0036 }
0037 }
0038
0039
0040
0041 bool HeavyObjectCache::eval(const std::string& name,
0042 reco::PreId& ecal,
0043 reco::PreId& hcal,
0044 double rho,
0045 const reco::BeamSpot& spot,
0046 noZS::EcalClusterLazyTools& ecalTools) const {
0047 std::vector<std::string>::const_iterator iter = std::find(names_.begin(), names_.end(), name);
0048 if (iter != names_.end()) {
0049 int index = std::distance(names_.begin(), iter);
0050 std::vector<float> inputs = features(ecal, hcal, rho, spot, ecalTools);
0051 float output = models_.at(index)->GetResponse(inputs.data());
0052 bool pass = output > thresholds_.at(index);
0053 ecal.setMVA(pass, output, index);
0054 return pass;
0055 } else {
0056 throw cms::Exception("Unknown model name")
0057 << "'Name given: '" << name << "'. Check against configuration file.\n";
0058 }
0059 }
0060
0061 }