Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2021-02-14 13:33:24

0001 /*
0002  * TensorFlow interface helpers.
0003  * Based on TensorFlow C++ API 2.1.
0004  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
0005  *
0006  * Author: Marcel Rieger
0007  */
0008 
0009 #ifndef PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
0010 #define PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
0011 
0012 #include "tensorflow/core/framework/tensor.h"
0013 #include "tensorflow/core/lib/core/threadpool.h"
0014 #include "tensorflow/core/lib/io/path.h"
0015 #include "tensorflow/core/public/session.h"
0016 #include "tensorflow/core/util/tensor_bundle/naming.h"
0017 #include "tensorflow/cc/client/client_session.h"
0018 #include "tensorflow/cc/saved_model/loader.h"
0019 #include "tensorflow/cc/saved_model/constants.h"
0020 #include "tensorflow/cc/saved_model/tag_constants.h"
0021 
0022 #include "PhysicsTools/TensorFlow/interface/NoThreadPool.h"
0023 #include "PhysicsTools/TensorFlow/interface/TBBThreadPool.h"
0024 
0025 #include "FWCore/Utilities/interface/Exception.h"
0026 
0027 namespace tensorflow {
0028 
0029   typedef std::pair<std::string, Tensor> NamedTensor;
0030   typedef std::vector<NamedTensor> NamedTensorList;
0031 
0032   // set the tensorflow log level
0033   void setLogging(const std::string& level = "3");
0034 
0035   // updates the config of sessionOptions so that it uses nThreads
0036   void setThreading(SessionOptions& sessionOptions, int nThreads = 1);
0037 
0038   // deprecated
0039   // updates the config of sessionOptions so that it uses nThreads, prints a deprecation warning
0040   // since the threading configuration is done per run() call as of 2.1
0041   void setThreading(SessionOptions& sessionOptions, int nThreads, const std::string& singleThreadPool);
0042 
0043   // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
0044   // predefined sessionOptions
0045   // transfers ownership
0046   MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions);
0047 
0048   // deprecated in favor of loadMetaGraphDef
0049   MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions);
0050 
0051   // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
0052   // nThreads
0053   // transfers ownership
0054   MetaGraphDef* loadMetaGraphDef(const std::string& exportDir,
0055                                  const std::string& tag = kSavedModelTagServe,
0056                                  int nThreads = 1);
0057 
0058   // deprecated in favor of loadMetaGraphDef
0059   MetaGraphDef* loadMetaGraph(const std::string& exportDir,
0060                               const std::string& tag = kSavedModelTagServe,
0061                               int nThreads = 1);
0062 
0063   // loads a graph definition saved as a protobuf file at pbFile
0064   // transfers ownership
0065   GraphDef* loadGraphDef(const std::string& pbFile);
0066 
0067   // return a new, empty session using predefined sessionOptions
0068   // transfers ownership
0069   Session* createSession(SessionOptions& sessionOptions);
0070 
0071   // return a new, empty session with nThreads
0072   // transfers ownership
0073   Session* createSession(int nThreads = 1);
0074 
0075   // return a new session that will contain an already loaded meta graph whose exportDir must be
0076   // given in order to load and initialize the variables, sessionOptions are predefined
0077   // an error is thrown when metaGraphDef is a nullptr or when the graph has no nodes
0078   // transfers ownership
0079   Session* createSession(const MetaGraphDef* metaGraphDef,
0080                          const std::string& exportDir,
0081                          SessionOptions& sessionOptions);
0082 
0083   // return a new session that will contain an already loaded meta graph whose exportDir must be given
0084   // in order to load and initialize the variables, threading options are inferred from nThreads
0085   // an error is thrown when metaGraphDef is a nullptr or when the graph has no nodes
0086   // transfers ownership
0087   Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, int nThreads = 1);
0088 
0089   // return a new session that will contain an already loaded graph def, sessionOptions are predefined
0090   // an error is thrown when graphDef is a nullptr or when the graph has no nodes
0091   // transfers ownership
0092   Session* createSession(const GraphDef* graphDef, SessionOptions& sessionOptions);
0093 
0094   // return a new session that will contain an already loaded graph def, threading options are
0095   // inferred from nThreads
0096   // an error is thrown when graphDef is a nullptr or when the graph has no nodes
0097   // transfers ownership
0098   Session* createSession(const GraphDef* graphDef, int nThreads = 1);
0099 
0100   // closes a session, calls its destructor, resets the pointer, and returns true on success
0101   bool closeSession(Session*& session);
0102 
0103   // run the session with inputs and outputNames, store output tensors, and control the underlying
0104   // thread pool using threadPoolOptions
0105   // used for thread scheduling with custom thread pool options
0106   // throws a cms exception when not successful
0107   void run(Session* session,
0108            const NamedTensorList& inputs,
0109            const std::vector<std::string>& outputNames,
0110            std::vector<Tensor>* outputs,
0111            const thread::ThreadPoolOptions& threadPoolOptions);
0112 
0113   // run the session with inputs and outputNames, store output tensors, and control the underlying
0114   // thread pool
0115   // throws a cms exception when not successful
0116   void run(Session* session,
0117            const NamedTensorList& inputs,
0118            const std::vector<std::string>& outputNames,
0119            std::vector<Tensor>* outputs,
0120            thread::ThreadPoolInterface* threadPool);
0121 
0122   // run the session with inputs and outputNames, store output tensors, and control the underlying
0123   // thread pool using a threadPoolName ("no_threads", "tbb", or "tensorflow")
0124   // throws a cms exception when not successful
0125   void run(Session* session,
0126            const NamedTensorList& inputs,
0127            const std::vector<std::string>& outputNames,
0128            std::vector<Tensor>* outputs,
0129            const std::string& threadPoolName = "no_threads");
0130 
0131   // run the session without inputs but only outputNames, store output tensors, and control the
0132   // underlying thread pool using a threadPoolName ("no_threads", "tbb", or "tensorflow")
0133   // throws a cms exception when not successful
0134   void run(Session* session,
0135            const std::vector<std::string>& outputNames,
0136            std::vector<Tensor>* outputs,
0137            const std::string& threadPoolName = "no_threads");
0138 
0139 }  // namespace tensorflow
0140 
0141 #endif  // PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H