Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2022-12-13 23:50:21

0001 /*
0002  * TensorFlow interface helpers.
0003  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
0004  *
0005  * Author: Marcel Rieger
0006  */
0007 
0008 #include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
0009 
0010 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0011 
0012 namespace tensorflow {
0013 
0014   void setLogging(const std::string& level) { setenv("TF_CPP_MIN_LOG_LEVEL", level.c_str(), 0); }
0015 
0016   void setThreading(SessionOptions& sessionOptions, int nThreads) {
0017     // set number of threads used for intra and inter operation communication
0018     sessionOptions.config.set_intra_op_parallelism_threads(nThreads);
0019     sessionOptions.config.set_inter_op_parallelism_threads(nThreads);
0020   }
0021 
0022   void setThreading(SessionOptions& sessionOptions, int nThreads, const std::string& singleThreadPool) {
0023     edm::LogInfo("PhysicsTools/TensorFlow") << "setting the thread pool via tensorflow::setThreading() is deprecated";
0024 
0025     setThreading(sessionOptions, nThreads);
0026   }
0027 
0028   MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions) {
0029     // objects to load the graph
0030     Status status;
0031     RunOptions runOptions;
0032     SavedModelBundle bundle;
0033 
0034     // load the model
0035     status = LoadSavedModel(sessionOptions, runOptions, exportDir, {tag}, &bundle);
0036     if (!status.ok()) {
0037       throw cms::Exception("InvalidMetaGraphDef")
0038           << "error while loading metaGraphDef from '" << exportDir << "': " << status.ToString();
0039     }
0040 
0041     // return a copy of the graph
0042     return new MetaGraphDef(bundle.meta_graph_def);
0043   }
0044 
0045   MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions) {
0046     edm::LogInfo("PhysicsTools/TensorFlow")
0047         << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
0048 
0049     return loadMetaGraphDef(exportDir, tag, sessionOptions);
0050   }
0051 
0052   MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, int nThreads) {
0053     // create session options and set thread options
0054     SessionOptions sessionOptions;
0055     setThreading(sessionOptions, nThreads);
0056 
0057     return loadMetaGraphDef(exportDir, tag, sessionOptions);
0058   }
0059 
0060   MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, int nThreads) {
0061     edm::LogInfo("PhysicsTools/TensorFlow")
0062         << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
0063 
0064     return loadMetaGraphDef(exportDir, tag, nThreads);
0065   }
0066 
0067   GraphDef* loadGraphDef(const std::string& pbFile) {
0068     // objects to load the graph
0069     Status status;
0070 
0071     // load it
0072     GraphDef* graphDef = new GraphDef();
0073     status = ReadBinaryProto(Env::Default(), pbFile, graphDef);
0074 
0075     // check for success
0076     if (!status.ok()) {
0077       throw cms::Exception("InvalidGraphDef")
0078           << "error while loading graphDef from '" << pbFile << "': " << status.ToString();
0079     }
0080 
0081     return graphDef;
0082   }
0083 
0084   Session* createSession(SessionOptions& sessionOptions) {
0085     // objects to create the session
0086     Status status;
0087 
0088     // create a new, empty session
0089     Session* session = nullptr;
0090     status = NewSession(sessionOptions, &session);
0091     if (!status.ok()) {
0092       throw cms::Exception("InvalidSession") << "error while creating session: " << status.ToString();
0093     }
0094 
0095     return session;
0096   }
0097 
0098   Session* createSession(int nThreads) {
0099     // create session options and set thread options
0100     SessionOptions sessionOptions;
0101     setThreading(sessionOptions, nThreads);
0102 
0103     return createSession(sessionOptions);
0104   }
0105 
0106   Session* createSession(const MetaGraphDef* metaGraphDef,
0107                          const std::string& exportDir,
0108                          SessionOptions& sessionOptions) {
0109     // check for valid pointer
0110     if (metaGraphDef == nullptr) {
0111       throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: metaGraphDef is nullptr";
0112     }
0113 
0114     // check that the graph has nodes
0115     if (metaGraphDef->graph_def().node_size() <= 0) {
0116       throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: graphDef has no nodes";
0117     }
0118 
0119     Session* session = createSession(sessionOptions);
0120 
0121     // add the graph def from the meta graph
0122     Status status;
0123     status = session->Create(metaGraphDef->graph_def());
0124     if (!status.ok()) {
0125       throw cms::Exception("InvalidMetaGraphDef")
0126           << "error while attaching metaGraphDef to session: " << status.ToString();
0127     }
0128 
0129     // restore variables using the variable and index files in the export directory
0130     // first, find names and paths
0131     std::string varFileTensorName = metaGraphDef->saver_def().filename_tensor_name();
0132     std::string restoreOpName = metaGraphDef->saver_def().restore_op_name();
0133     std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
0134     std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
0135     std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
0136 
0137     // when the index file is missing, there's nothing to do
0138     if (!Env::Default()->FileExists(indexFile).ok()) {
0139       return session;
0140     }
0141 
0142     // create a tensor to store the variable file
0143     Tensor varFileTensor(DT_STRING, TensorShape({}));
0144     varFileTensor.scalar<tensorflow::tstring>()() = varFile;
0145 
0146     // run the restore op
0147     status = session->Run({{varFileTensorName, varFileTensor}}, {}, {restoreOpName}, nullptr);
0148     if (!status.ok()) {
0149       throw cms::Exception("InvalidSession") << "error while restoring variables in session: " << status.ToString();
0150     }
0151 
0152     return session;
0153   }
0154 
0155   Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, int nThreads) {
0156     // create session options and set thread options
0157     SessionOptions sessionOptions;
0158     setThreading(sessionOptions, nThreads);
0159 
0160     return createSession(metaGraphDef, exportDir, sessionOptions);
0161   }
0162 
0163   Session* createSession(const GraphDef* graphDef, SessionOptions& sessionOptions) {
0164     // check for valid pointer
0165     if (graphDef == nullptr) {
0166       throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef is nullptr";
0167     }
0168 
0169     // check that the graph has nodes
0170     if (graphDef->node_size() <= 0) {
0171       throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef has no nodes";
0172     }
0173 
0174     // create a new, empty session
0175     Session* session = createSession(sessionOptions);
0176 
0177     // add the graph def
0178     Status status;
0179     status = session->Create(*graphDef);
0180 
0181     // check for success
0182     if (!status.ok()) {
0183       throw cms::Exception("InvalidSession") << "error while attaching graphDef to session: " << status.ToString();
0184     }
0185 
0186     return session;
0187   }
0188 
0189   Session* createSession(const GraphDef* graphDef, int nThreads) {
0190     // create session options and set thread options
0191     SessionOptions sessionOptions;
0192     setThreading(sessionOptions, nThreads);
0193 
0194     return createSession(graphDef, sessionOptions);
0195   }
0196 
0197   bool closeSession(Session*& session) {
0198     if (session == nullptr) {
0199       return true;
0200     }
0201 
0202     // close and delete the session
0203     Status status = session->Close();
0204     delete session;
0205 
0206     // reset the pointer
0207     session = nullptr;
0208 
0209     return status.ok();
0210   }
0211 
0212   bool closeSession(const Session*& session) {
0213     auto s = const_cast<Session*>(session);
0214     bool state = closeSession(s);
0215 
0216     // reset the pointer
0217     session = nullptr;
0218 
0219     return state;
0220   }
0221 
0222   void run(Session* session,
0223            const NamedTensorList& inputs,
0224            const std::vector<std::string>& outputNames,
0225            std::vector<Tensor>* outputs,
0226            const thread::ThreadPoolOptions& threadPoolOptions) {
0227     if (session == nullptr) {
0228       throw cms::Exception("InvalidSession") << "cannot run empty session";
0229     }
0230 
0231     // create empty run options
0232     RunOptions runOptions;
0233 
0234     // run and check the status
0235     Status status = session->Run(runOptions, inputs, outputNames, {}, outputs, nullptr, threadPoolOptions);
0236     if (!status.ok()) {
0237       throw cms::Exception("InvalidRun") << "error while running session: " << status.ToString();
0238     }
0239   }
0240 
0241   void run(Session* session,
0242            const NamedTensorList& inputs,
0243            const std::vector<std::string>& outputNames,
0244            std::vector<Tensor>* outputs,
0245            thread::ThreadPoolInterface* threadPool) {
0246     // create thread pool options
0247     thread::ThreadPoolOptions threadPoolOptions;
0248     threadPoolOptions.inter_op_threadpool = threadPool;
0249     threadPoolOptions.intra_op_threadpool = threadPool;
0250 
0251     // run
0252     run(session, inputs, outputNames, outputs, threadPoolOptions);
0253   }
0254 
0255   void run(Session* session,
0256            const NamedTensorList& inputs,
0257            const std::vector<std::string>& outputNames,
0258            std::vector<Tensor>* outputs,
0259            const std::string& threadPoolName) {
0260     // lookup the thread pool and forward the call accordingly
0261     if (threadPoolName == "no_threads") {
0262       run(session, inputs, outputNames, outputs, &NoThreadPool::instance());
0263     } else if (threadPoolName == "tbb") {
0264       // the TBBTreadPool singleton should be already initialized before with a number of threads
0265       run(session, inputs, outputNames, outputs, &TBBThreadPool::instance());
0266     } else if (threadPoolName == "tensorflow") {
0267       run(session, inputs, outputNames, outputs, nullptr);
0268     } else {
0269       throw cms::Exception("UnknownThreadPool")
0270           << "thread pool implementation'" << threadPoolName << "' unknown, use 'no_threads', 'tbb', or 'tensorflow'";
0271     }
0272   }
0273 
0274   void run(Session* session,
0275            const std::vector<std::string>& outputNames,
0276            std::vector<Tensor>* outputs,
0277            const std::string& threadPoolName) {
0278     run(session, {}, outputNames, outputs, threadPoolName);
0279   }
0280 
0281   void SessionCache::closeSession() {
0282     // delete the session if set
0283     Session* s = session.load();
0284     if (s != nullptr) {
0285       tensorflow::closeSession(s);
0286       session.store(nullptr);
0287     }
0288 
0289     // delete the graph if set
0290     if (graph.load() != nullptr) {
0291       delete graph.load();
0292       graph.store(nullptr);
0293     }
0294   }
0295 
0296 }  // namespace tensorflow