Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:24:15

0001 // -*- C++ -*-
0002 //
0003 // Package:    PhysicsTools/TensorFlow
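// Class:      TfGraphDefProducer
//
/**\class TfGraphDefProducer
 Description: ESProducer that loads a TensorFlow GraphDef from the file given by the "FileName" parameter, creates a TensorFlow session for it, and provides both, wrapped in a TfGraphDefWrapper, through the TfGraphRecord of the EventSetup (under the label given by "ComponentName"). The wrapper can be used to run inference with a pretrained network.
*/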
0009 //
0010 // Original Author:  Joona Havukainen
0011 //         Created:  Fri, 24 Jul 2020 08:04:00 GMT
0012 //
0013 //
0014 
0015 // system include files
0016 #include <memory>
0017 
0018 // user include files
0019 #include "FWCore/Framework/interface/ModuleFactory.h"
0020 #include "FWCore/Framework/interface/ESProducer.h"
0021 
0022 #include "FWCore/Framework/interface/ESHandle.h"
0023 #include "PhysicsTools/TensorFlow/interface/TfGraphRecord.h"
0024 #include "PhysicsTools/TensorFlow/interface/TfGraphDefWrapper.h"
0025 
0026 // class declaration
0027 
0028 class TfGraphDefProducer : public edm::ESProducer {
0029 public:
0030   TfGraphDefProducer(const edm::ParameterSet&);
0031   using ReturnType = std::unique_ptr<TfGraphDefWrapper>;
0032 
0033   ReturnType produce(const TfGraphRecord&);
0034 
0035   static void fillDescriptions(edm::ConfigurationDescriptions& descriptions);
0036 
0037 private:
0038   const std::string filename_;
0039   // ----------member data ---------------------------
0040 };
0041 
0042 TfGraphDefProducer::TfGraphDefProducer(const edm::ParameterSet& iConfig)
0043     : filename_(iConfig.getParameter<edm::FileInPath>("FileName").fullPath()) {
0044   auto componentName = iConfig.getParameter<std::string>("ComponentName");
0045   setWhatProduced(this, componentName);
0046 }
0047 
0048 // ------------ method called to produce the data  ------------
0049 TfGraphDefProducer::ReturnType TfGraphDefProducer::produce(const TfGraphRecord& iRecord) {
0050   auto* graph = tensorflow::loadGraphDef(filename_);
0051   return std::make_unique<TfGraphDefWrapper>(tensorflow::createSession(graph), graph);
0052 }
0053 
0054 void TfGraphDefProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0055   edm::ParameterSetDescription desc;
0056   desc.add<std::string>("ComponentName", "tfGraphDef");
0057   desc.add<edm::FileInPath>("FileName", edm::FileInPath());
0058   descriptions.add("tfGraphDefProducer", desc);
0059 }
0060 
// Define TfGraphDefProducer as an EventSetup module plug-in so the framework's
// plugin manager can find it by name.
#include "FWCore/PluginManager/interface/ModuleDef.h"
#include "FWCore/Framework/interface/MakerMacros.h"
DEFINE_FWK_EVENTSETUP_MODULE(TfGraphDefProducer);