Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2021-02-14 13:33:24

0001 /*
0002  * TensorFlow interface helpers.
0003  * Based on TensorFlow C++ API 2.1.
0004  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
0005  *
0006  * Author: Marcel Rieger
0007  */
0008 
#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

#include "FWCore/MessageLogger/interface/MessageLogger.h"

#include <memory>
#include <utility>
0012 
0013 namespace tensorflow {
0014 
0015   void setLogging(const std::string& level) { setenv("TF_CPP_MIN_LOG_LEVEL", level.c_str(), 0); }
0016 
0017   void setThreading(SessionOptions& sessionOptions, int nThreads) {
0018     // set number of threads used for intra and inter operation communication
0019     sessionOptions.config.set_intra_op_parallelism_threads(nThreads);
0020     sessionOptions.config.set_inter_op_parallelism_threads(nThreads);
0021   }
0022 
0023   void setThreading(SessionOptions& sessionOptions, int nThreads, const std::string& singleThreadPool) {
0024     edm::LogInfo("PhysicsTools/TensorFlow") << "setting the thread pool via tensorflow::setThreading() is deprecated";
0025 
0026     setThreading(sessionOptions, nThreads);
0027   }
0028 
0029   MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions) {
0030     // objects to load the graph
0031     Status status;
0032     RunOptions runOptions;
0033     SavedModelBundle bundle;
0034 
0035     // load the model
0036     status = LoadSavedModel(sessionOptions, runOptions, exportDir, {tag}, &bundle);
0037     if (!status.ok()) {
0038       throw cms::Exception("InvalidMetaGraphDef")
0039           << "error while loading metaGraphDef from '" << exportDir << "': " << status.ToString();
0040     }
0041 
0042     // return a copy of the graph
0043     return new MetaGraphDef(bundle.meta_graph_def);
0044   }
0045 
0046   MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions) {
0047     edm::LogInfo("PhysicsTools/TensorFlow")
0048         << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
0049 
0050     return loadMetaGraphDef(exportDir, tag, sessionOptions);
0051   }
0052 
0053   MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, int nThreads) {
0054     // create session options and set thread options
0055     SessionOptions sessionOptions;
0056     setThreading(sessionOptions, nThreads);
0057 
0058     return loadMetaGraphDef(exportDir, tag, sessionOptions);
0059   }
0060 
0061   MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, int nThreads) {
0062     edm::LogInfo("PhysicsTools/TensorFlow")
0063         << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
0064 
0065     return loadMetaGraphDef(exportDir, tag, nThreads);
0066   }
0067 
0068   GraphDef* loadGraphDef(const std::string& pbFile) {
0069     // objects to load the graph
0070     Status status;
0071 
0072     // load it
0073     GraphDef* graphDef = new GraphDef();
0074     status = ReadBinaryProto(Env::Default(), pbFile, graphDef);
0075 
0076     // check for success
0077     if (!status.ok()) {
0078       throw cms::Exception("InvalidGraphDef")
0079           << "error while loading graphDef from '" << pbFile << "': " << status.ToString();
0080     }
0081 
0082     return graphDef;
0083   }
0084 
0085   Session* createSession(SessionOptions& sessionOptions) {
0086     // objects to create the session
0087     Status status;
0088 
0089     // create a new, empty session
0090     Session* session = nullptr;
0091     status = NewSession(sessionOptions, &session);
0092     if (!status.ok()) {
0093       throw cms::Exception("InvalidSession") << "error while creating session: " << status.ToString();
0094     }
0095 
0096     return session;
0097   }
0098 
0099   Session* createSession(int nThreads) {
0100     // create session options and set thread options
0101     SessionOptions sessionOptions;
0102     setThreading(sessionOptions, nThreads);
0103 
0104     return createSession(sessionOptions);
0105   }
0106 
0107   Session* createSession(const MetaGraphDef* metaGraphDef,
0108                          const std::string& exportDir,
0109                          SessionOptions& sessionOptions) {
0110     // check for valid pointer
0111     if (metaGraphDef == nullptr) {
0112       throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: metaGraphDef is nullptr";
0113     }
0114 
0115     // check that the graph has nodes
0116     if (metaGraphDef->graph_def().node_size() <= 0) {
0117       throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: graphDef has no nodes";
0118     }
0119 
0120     Session* session = createSession(sessionOptions);
0121 
0122     // add the graph def from the meta graph
0123     Status status;
0124     status = session->Create(metaGraphDef->graph_def());
0125     if (!status.ok()) {
0126       throw cms::Exception("InvalidMetaGraphDef")
0127           << "error while attaching metaGraphDef to session: " << status.ToString();
0128     }
0129 
0130     // restore variables using the variable and index files in the export directory
0131     // first, find names and paths
0132     std::string varFileTensorName = metaGraphDef->saver_def().filename_tensor_name();
0133     std::string restoreOpName = metaGraphDef->saver_def().restore_op_name();
0134     std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
0135     std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
0136     std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
0137 
0138     // when the index file is missing, there's nothing to do
0139     if (!Env::Default()->FileExists(indexFile).ok()) {
0140       return session;
0141     }
0142 
0143     // create a tensor to store the variable file
0144     Tensor varFileTensor(DT_STRING, TensorShape({}));
0145     varFileTensor.scalar<tensorflow::tstring>()() = varFile;
0146 
0147     // run the restore op
0148     status = session->Run({{varFileTensorName, varFileTensor}}, {}, {restoreOpName}, nullptr);
0149     if (!status.ok()) {
0150       throw cms::Exception("InvalidSession") << "error while restoring variables in session: " << status.ToString();
0151     }
0152 
0153     return session;
0154   }
0155 
0156   Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, int nThreads) {
0157     // create session options and set thread options
0158     SessionOptions sessionOptions;
0159     setThreading(sessionOptions, nThreads);
0160 
0161     return createSession(metaGraphDef, exportDir, sessionOptions);
0162   }
0163 
0164   Session* createSession(const GraphDef* graphDef, SessionOptions& sessionOptions) {
0165     // check for valid pointer
0166     if (graphDef == nullptr) {
0167       throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef is nullptr";
0168     }
0169 
0170     // check that the graph has nodes
0171     if (graphDef->node_size() <= 0) {
0172       throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef has no nodes";
0173     }
0174 
0175     // create a new, empty session
0176     Session* session = createSession(sessionOptions);
0177 
0178     // add the graph def
0179     Status status;
0180     status = session->Create(*graphDef);
0181 
0182     // check for success
0183     if (!status.ok()) {
0184       throw cms::Exception("InvalidSession") << "error while attaching graphDef to session: " << status.ToString();
0185     }
0186 
0187     return session;
0188   }
0189 
0190   Session* createSession(const GraphDef* graphDef, int nThreads) {
0191     // create session options and set thread options
0192     SessionOptions sessionOptions;
0193     setThreading(sessionOptions, nThreads);
0194 
0195     return createSession(graphDef, sessionOptions);
0196   }
0197 
0198   bool closeSession(Session*& session) {
0199     if (session == nullptr) {
0200       return true;
0201     }
0202 
0203     // close and delete the session
0204     Status status = session->Close();
0205     delete session;
0206 
0207     // reset the pointer
0208     session = nullptr;
0209 
0210     return status.ok();
0211   }
0212 
0213   void run(Session* session,
0214            const NamedTensorList& inputs,
0215            const std::vector<std::string>& outputNames,
0216            std::vector<Tensor>* outputs,
0217            const thread::ThreadPoolOptions& threadPoolOptions) {
0218     if (session == nullptr) {
0219       throw cms::Exception("InvalidSession") << "cannot run empty session";
0220     }
0221 
0222     // create empty run options
0223     RunOptions runOptions;
0224 
0225     // run and check the status
0226     Status status = session->Run(runOptions, inputs, outputNames, {}, outputs, nullptr, threadPoolOptions);
0227     if (!status.ok()) {
0228       throw cms::Exception("InvalidRun") << "error while running session: " << status.ToString();
0229     }
0230   }
0231 
0232   void run(Session* session,
0233            const NamedTensorList& inputs,
0234            const std::vector<std::string>& outputNames,
0235            std::vector<Tensor>* outputs,
0236            thread::ThreadPoolInterface* threadPool) {
0237     // create thread pool options
0238     thread::ThreadPoolOptions threadPoolOptions;
0239     threadPoolOptions.inter_op_threadpool = threadPool;
0240     threadPoolOptions.intra_op_threadpool = threadPool;
0241 
0242     // run
0243     run(session, inputs, outputNames, outputs, threadPoolOptions);
0244   }
0245 
0246   void run(Session* session,
0247            const NamedTensorList& inputs,
0248            const std::vector<std::string>& outputNames,
0249            std::vector<Tensor>* outputs,
0250            const std::string& threadPoolName) {
0251     // lookup the thread pool and forward the call accordingly
0252     if (threadPoolName == "no_threads") {
0253       run(session, inputs, outputNames, outputs, &NoThreadPool::instance());
0254     } else if (threadPoolName == "tbb") {
0255       // the TBBTreadPool singleton should be already initialized before with a number of threads
0256       run(session, inputs, outputNames, outputs, &TBBThreadPool::instance());
0257     } else if (threadPoolName == "tensorflow") {
0258       run(session, inputs, outputNames, outputs, nullptr);
0259     } else {
0260       throw cms::Exception("UnknownThreadPool")
0261           << "thread pool implementation'" << threadPoolName << "' unknown, use 'no_threads', 'tbb', or 'tensorflow'";
0262     }
0263   }
0264 
0265   void run(Session* session,
0266            const std::vector<std::string>& outputNames,
0267            std::vector<Tensor>* outputs,
0268            const std::string& threadPoolName) {
0269     run(session, {}, outputNames, outputs, threadPoolName);
0270   }
0271 
0272 }  // namespace tensorflow