File indexing completed on 2024-04-06 12:24:15
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016 #include <memory>
0017
0018
0019 #include "FWCore/Framework/interface/ModuleFactory.h"
0020 #include "FWCore/Framework/interface/ESProducer.h"
0021
0022 #include "FWCore/Framework/interface/ESHandle.h"
0023 #include "PhysicsTools/TensorFlow/interface/TfGraphRecord.h"
0024 #include "PhysicsTools/TensorFlow/interface/TfGraphDefWrapper.h"
0025
0026
0027
0028 class TfGraphDefProducer : public edm::ESProducer {
0029 public:
0030 TfGraphDefProducer(const edm::ParameterSet&);
0031 using ReturnType = std::unique_ptr<TfGraphDefWrapper>;
0032
0033 ReturnType produce(const TfGraphRecord&);
0034
0035 static void fillDescriptions(edm::ConfigurationDescriptions& descriptions);
0036
0037 private:
0038 const std::string filename_;
0039
0040 };
0041
0042 TfGraphDefProducer::TfGraphDefProducer(const edm::ParameterSet& iConfig)
0043 : filename_(iConfig.getParameter<edm::FileInPath>("FileName").fullPath()) {
0044 auto componentName = iConfig.getParameter<std::string>("ComponentName");
0045 setWhatProduced(this, componentName);
0046 }
0047
0048
0049 TfGraphDefProducer::ReturnType TfGraphDefProducer::produce(const TfGraphRecord& iRecord) {
0050 auto* graph = tensorflow::loadGraphDef(filename_);
0051 return std::make_unique<TfGraphDefWrapper>(tensorflow::createSession(graph), graph);
0052 }
0053
0054 void TfGraphDefProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0055 edm::ParameterSetDescription desc;
0056 desc.add<std::string>("ComponentName", "tfGraphDef");
0057 desc.add<edm::FileInPath>("FileName", edm::FileInPath());
0058 descriptions.add("tfGraphDefProducer", desc);
0059 }
0060
0061
// Plugin registration: expose TfGraphDefProducer to the framework as an
// EventSetup module.
#include "FWCore/PluginManager/interface/ModuleDef.h"
#include "FWCore/Framework/interface/MakerMacros.h"
DEFINE_FWK_EVENTSETUP_MODULE(TfGraphDefProducer);