Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2022-12-13 23:50:20

0001 /*
0002  * TensorFlow interface helpers.
0003  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
0004  *
0005  * Author: Marcel Rieger
0006  */
0007 
0008 #ifndef PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
0009 #define PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
0010 
0011 #include "tensorflow/core/framework/tensor.h"
0012 #include "tensorflow/core/lib/core/threadpool.h"
0013 #include "tensorflow/core/lib/io/path.h"
0014 #include "tensorflow/core/public/session.h"
0015 #include "tensorflow/core/util/tensor_bundle/naming.h"
0016 #include "tensorflow/cc/client/client_session.h"
0017 #include "tensorflow/cc/saved_model/loader.h"
0018 #include "tensorflow/cc/saved_model/constants.h"
0019 #include "tensorflow/cc/saved_model/tag_constants.h"
0020 
0021 #include "PhysicsTools/TensorFlow/interface/NoThreadPool.h"
0022 #include "PhysicsTools/TensorFlow/interface/TBBThreadPool.h"
0023 
0024 #include "FWCore/Utilities/interface/Exception.h"
0025 
0026 namespace tensorflow {
0027 
0028   typedef std::pair<std::string, Tensor> NamedTensor;
0029   typedef std::vector<NamedTensor> NamedTensorList;
0030 
0031   // set the tensorflow log level
0032   void setLogging(const std::string& level = "3");
0033 
0034   // updates the config of sessionOptions so that it uses nThreads
0035   void setThreading(SessionOptions& sessionOptions, int nThreads = 1);
0036 
0037   // deprecated
0038   // updates the config of sessionOptions so that it uses nThreads, prints a deprecation warning
0039   // since the threading configuration is done per run() call as of 2.1
0040   void setThreading(SessionOptions& sessionOptions, int nThreads, const std::string& singleThreadPool);
0041 
0042   // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
0043   // predefined sessionOptions
0044   // transfers ownership
0045   MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions);
0046 
0047   // deprecated in favor of loadMetaGraphDef
0048   MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions);
0049 
0050   // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
0051   // nThreads
0052   // transfers ownership
0053   MetaGraphDef* loadMetaGraphDef(const std::string& exportDir,
0054                                  const std::string& tag = kSavedModelTagServe,
0055                                  int nThreads = 1);
0056 
0057   // deprecated in favor of loadMetaGraphDef
0058   MetaGraphDef* loadMetaGraph(const std::string& exportDir,
0059                               const std::string& tag = kSavedModelTagServe,
0060                               int nThreads = 1);
0061 
0062   // loads a graph definition saved as a protobuf file at pbFile
0063   // transfers ownership
0064   GraphDef* loadGraphDef(const std::string& pbFile);
0065 
0066   // return a new, empty session using predefined sessionOptions
0067   // transfers ownership
0068   Session* createSession(SessionOptions& sessionOptions);
0069 
0070   // return a new, empty session with nThreads
0071   // transfers ownership
0072   Session* createSession(int nThreads = 1);
0073 
0074   // return a new session that will contain an already loaded meta graph whose exportDir must be
0075   // given in order to load and initialize the variables, sessionOptions are predefined
0076   // an error is thrown when metaGraphDef is a nullptr or when the graph has no nodes
0077   // transfers ownership
0078   Session* createSession(const MetaGraphDef* metaGraphDef,
0079                          const std::string& exportDir,
0080                          SessionOptions& sessionOptions);
0081 
0082   // return a new session that will contain an already loaded meta graph whose exportDir must be given
0083   // in order to load and initialize the variables, threading options are inferred from nThreads
0084   // an error is thrown when metaGraphDef is a nullptr or when the graph has no nodes
0085   // transfers ownership
0086   Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, int nThreads = 1);
0087 
0088   // return a new session that will contain an already loaded graph def, sessionOptions are predefined
0089   // an error is thrown when graphDef is a nullptr or when the graph has no nodes
0090   // transfers ownership
0091   Session* createSession(const GraphDef* graphDef, SessionOptions& sessionOptions);
0092 
0093   // return a new session that will contain an already loaded graph def, threading options are
0094   // inferred from nThreads
0095   // an error is thrown when graphDef is a nullptr or when the graph has no nodes
0096   // transfers ownership
0097   Session* createSession(const GraphDef* graphDef, int nThreads = 1);
0098 
0099   // closes a session, calls its destructor, resets the pointer, and returns true on success
0100   bool closeSession(Session*& session);
0101 
0102   // version of the function above that accepts a const session
0103   bool closeSession(const Session*& session);
0104 
0105   // run the session with inputs and outputNames, store output tensors, and control the underlying
0106   // thread pool using threadPoolOptions
0107   // used for thread scheduling with custom thread pool options
0108   // throws a cms exception when not successful
0109   void run(Session* session,
0110            const NamedTensorList& inputs,
0111            const std::vector<std::string>& outputNames,
0112            std::vector<Tensor>* outputs,
0113            const thread::ThreadPoolOptions& threadPoolOptions);
0114 
0115   // version of the function above that accepts a const session
0116   inline void run(const Session* session,
0117                   const NamedTensorList& inputs,
0118                   const std::vector<std::string>& outputNames,
0119                   std::vector<Tensor>* outputs,
0120                   const thread::ThreadPoolOptions& threadPoolOptions) {
0121     // TF takes a non-const session in the run call which is, however, thread-safe and logically
0122     // const, thus const_cast is consistent
0123     run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolOptions);
0124   }
0125 
0126   // run the session with inputs and outputNames, store output tensors, and control the underlying
0127   // thread pool
0128   // throws a cms exception when not successful
0129   void run(Session* session,
0130            const NamedTensorList& inputs,
0131            const std::vector<std::string>& outputNames,
0132            std::vector<Tensor>* outputs,
0133            thread::ThreadPoolInterface* threadPool);
0134 
0135   // version of the function above that accepts a const session
0136   inline void run(const Session* session,
0137                   const NamedTensorList& inputs,
0138                   const std::vector<std::string>& outputNames,
0139                   std::vector<Tensor>* outputs,
0140                   thread::ThreadPoolInterface* threadPool) {
0141     // TF takes a non-const session in the run call which is, however, thread-safe and logically
0142     // const, thus const_cast is consistent
0143     run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPool);
0144   }
0145 
0146   // run the session with inputs and outputNames, store output tensors, and control the underlying
0147   // thread pool using a threadPoolName ("no_threads", "tbb", or "tensorflow")
0148   // throws a cms exception when not successful
0149   void run(Session* session,
0150            const NamedTensorList& inputs,
0151            const std::vector<std::string>& outputNames,
0152            std::vector<Tensor>* outputs,
0153            const std::string& threadPoolName = "no_threads");
0154 
0155   // version of the function above that accepts a const session
0156   inline void run(const Session* session,
0157                   const NamedTensorList& inputs,
0158                   const std::vector<std::string>& outputNames,
0159                   std::vector<Tensor>* outputs,
0160                   const std::string& threadPoolName = "no_threads") {
0161     // TF takes a non-const session in the run call which is, however, thread-safe and logically
0162     // const, thus const_cast is consistent
0163     run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolName);
0164   }
0165 
0166   // run the session without inputs but only outputNames, store output tensors, and control the
0167   // underlying thread pool using a threadPoolName ("no_threads", "tbb", or "tensorflow")
0168   // throws a cms exception when not successful
0169   void run(Session* session,
0170            const std::vector<std::string>& outputNames,
0171            std::vector<Tensor>* outputs,
0172            const std::string& threadPoolName = "no_threads");
0173 
0174   // version of the function above that accepts a const session
0175   inline void run(const Session* session,
0176                   const std::vector<std::string>& outputNames,
0177                   std::vector<Tensor>* outputs,
0178                   const std::string& threadPoolName = "no_threads") {
0179     // TF takes a non-const session in the run call which is, however, thread-safe and logically
0180     // const, thus const_cast is consistent
0181     run(const_cast<Session*>(session), outputNames, outputs, threadPoolName);
0182   }
0183 
0184   // struct that can be used in edm::stream modules for caching a graph and a session instance,
0185   // both made atomic for cases where access is required from multiple threads
0186   struct SessionCache {
0187     std::atomic<GraphDef*> graph;
0188     std::atomic<Session*> session;
0189 
0190     // constructor
0191     SessionCache() {}
0192 
0193     // initializing constructor, forwarding all arguments to createSession
0194     template <typename... Args>
0195     SessionCache(const std::string& graphPath, Args&&... sessionArgs) {
0196       createSession(graphPath, std::forward<Args>(sessionArgs)...);
0197     }
0198 
0199     // destructor
0200     ~SessionCache() { closeSession(); }
0201 
0202     // create the internal graph representation from graphPath and the session object, forwarding
0203     // all additional arguments to the central tensorflow::createSession
0204     template <typename... Args>
0205     void createSession(const std::string& graphPath, Args&&... sessionArgs) {
0206       graph.store(loadGraphDef(graphPath));
0207       session.store(tensorflow::createSession(graph.load(), std::forward<Args>(sessionArgs)...));
0208     }
0209 
0210     // return a pointer to the const session
0211     inline const Session* getSession() const { return session.load(); }
0212 
0213     // closes and removes the session as well as the graph, and sets the atomic members to nullptr's
0214     void closeSession();
0215   };
0216 
0217 }  // namespace tensorflow
0218 
0219 #endif  // PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H