Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-06-07 02:29:49

0001 /*
0002  * TensorFlow interface helpers.
0003  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
0004  *
0005  * Author: Marcel Rieger
0006  */
0007 
0008 #ifndef PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
0009 #define PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
0010 
0011 #include "tensorflow/core/framework/tensor.h"
0012 #include "tensorflow/core/lib/core/threadpool.h"
0013 #include "tensorflow/core/lib/io/path.h"
0014 #include "tensorflow/core/public/session.h"
0015 #include "tensorflow/core/util/tensor_bundle/naming.h"
0016 #include "tensorflow/cc/client/client_session.h"
0017 #include "tensorflow/cc/saved_model/loader.h"
0018 #include "tensorflow/cc/saved_model/constants.h"
0019 #include "tensorflow/cc/saved_model/tag_constants.h"
0020 
0021 #include "PhysicsTools/TensorFlow/interface/NoThreadPool.h"
0022 #include "PhysicsTools/TensorFlow/interface/TBBThreadPool.h"
0023 
0024 #include "FWCore/Utilities/interface/Exception.h"
0025 
0026 namespace tensorflow {
0027 
// Compute backends that can be requested through Options / Options::setBackend.
enum class Backend {
  cpu,
  cuda,
  rocm,
  intel,
  best,  // NOTE(review): presumably lets the implementation choose an available backend — confirm
};
0029 
0030   typedef std::pair<std::string, Tensor> NamedTensor;
0031   typedef std::vector<NamedTensor> NamedTensorList;
0032 
0033   struct Options {
0034     int _nThreads;
0035     Backend _backend;
0036     SessionOptions _options;
0037 
0038     Options(Backend backend) : _nThreads{1}, _backend{backend} {
0039       setThreading(_nThreads);
0040       setBackend(_backend);
0041     };
0042 
0043     Options() : _nThreads{1}, _backend{Backend::cpu} {
0044       setThreading(_nThreads);
0045       setBackend(_backend);
0046     };
0047 
0048     // updates the config of sessionOptions so that it uses nThreads
0049     void setThreading(int nThreads = 1);
0050 
0051     // Set the backend option cpu/cuda
0052     // The gpu memory is set to "allow_growth" to avoid TF getting all the CUDA memory at once.
0053     void setBackend(Backend backend = Backend::cpu);
0054 
0055     SessionOptions& getSessionOptions() { return _options; };
0056     int getNThreads() const { return _nThreads; };
0057     Backend getBackend() const { return _backend; };
0058   };
0059 
0060   // set the tensorflow log level
0061   void setLogging(const std::string& level = "3");
0062 
0063   // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
0064   // predefined options
0065   // transfers ownership
0066   MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag = kSavedModelTagServe);
0067 
0068   // loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
0069   // user provided options
0070   // transfers ownership
0071   MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, Options& options);
0072 
0073   // deprecated in favor of loadMetaGraphDef
0074   MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, Options& Options);
0075 
0076   // loads a graph definition saved as a protobuf file at pbFile
0077   // transfers ownership
0078   GraphDef* loadGraphDef(const std::string& pbFile);
0079 
0080   // return a new, empty session using the predefined options
0081   Session* createSession();
0082 
0083   // return a new, empty session using user provided options
0084   // transfers ownership
0085   Session* createSession(Options& options);
0086 
0087   // return a new session that will contain an already loaded meta graph whose exportDir must be
0088   // given in order to load and initialize the variables, sessionOptions are predefined
0089   // an error is thrown when metaGraphDef is a nullptr or when the graph has no nodes
0090   // transfers ownership
0091   Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, Options& options);
0092 
0093   // return a new session that will contain an already loaded graph def, sessionOptions are predefined
0094   // an error is thrown when graphDef is a nullptr or when the graph has no nodes
0095   // transfers ownership
0096   Session* createSession(const GraphDef* graphDef);
0097 
0098   // return a new session that will contain an already loaded graph def, sessionOptions are user defined
0099   // an error is thrown when graphDef is a nullptr or when the graph has no nodes
0100   // transfers ownership
0101   Session* createSession(const GraphDef* graphDef, Options& options);
0102 
0103   // closes a session, calls its destructor, resets the pointer, and returns true on success
0104   bool closeSession(Session*& session);
0105 
0106   // version of the function above that accepts a const session
0107   bool closeSession(const Session*& session);
0108 
0109   bool checkEmptyInputs(const NamedTensorList& inputs);
0110 
0111   // run the session with inputs and outputNames, store output tensors, and control the underlying
0112   // thread pool using threadPoolOptions
0113   // used for thread scheduling with custom thread pool options
0114   // throws a cms exception when not successful
0115   void run(Session* session,
0116            const NamedTensorList& inputs,
0117            const std::vector<std::string>& outputNames,
0118            std::vector<Tensor>* outputs,
0119            const thread::ThreadPoolOptions& threadPoolOptions);
0120 
0121   // version of the function above that accepts a const session
0122   inline void run(const Session* session,
0123                   const NamedTensorList& inputs,
0124                   const std::vector<std::string>& outputNames,
0125                   std::vector<Tensor>* outputs,
0126                   const thread::ThreadPoolOptions& threadPoolOptions) {
0127     // TF takes a non-const session in the run call which is, however, thread-safe and logically
0128     // const, thus const_cast is consistent
0129     run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolOptions);
0130   }
0131 
0132   // run the session with inputs and outputNames, store output tensors, and control the underlying
0133   // thread pool
0134   // throws a cms exception when not successful
0135   void run(Session* session,
0136            const NamedTensorList& inputs,
0137            const std::vector<std::string>& outputNames,
0138            std::vector<Tensor>* outputs,
0139            thread::ThreadPoolInterface* threadPool);
0140 
0141   // version of the function above that accepts a const session
0142   inline void run(const Session* session,
0143                   const NamedTensorList& inputs,
0144                   const std::vector<std::string>& outputNames,
0145                   std::vector<Tensor>* outputs,
0146                   thread::ThreadPoolInterface* threadPool) {
0147     // TF takes a non-const session in the run call which is, however, thread-safe and logically
0148     // const, thus const_cast is consistent
0149     run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPool);
0150   }
0151 
0152   // run the session with inputs and outputNames, store output tensors, and control the underlying
0153   // thread pool using a threadPoolName ("no_threads", "tbb", or "tensorflow")
0154   // throws a cms exception when not successful
0155   void run(Session* session,
0156            const NamedTensorList& inputs,
0157            const std::vector<std::string>& outputNames,
0158            std::vector<Tensor>* outputs,
0159            const std::string& threadPoolName = "no_threads");
0160 
0161   // version of the function above that accepts a const session
0162   inline void run(const Session* session,
0163                   const NamedTensorList& inputs,
0164                   const std::vector<std::string>& outputNames,
0165                   std::vector<Tensor>* outputs,
0166                   const std::string& threadPoolName = "no_threads") {
0167     // TF takes a non-const session in the run call which is, however, thread-safe and logically
0168     // const, thus const_cast is consistent
0169     run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolName);
0170   }
0171 
0172   // run the session without inputs but only outputNames, store output tensors, and control the
0173   // underlying thread pool using a threadPoolName ("no_threads", "tbb", or "tensorflow")
0174   // throws a cms exception when not successful
0175   void run(Session* session,
0176            const std::vector<std::string>& outputNames,
0177            std::vector<Tensor>* outputs,
0178            const std::string& threadPoolName = "no_threads");
0179 
0180   // version of the function above that accepts a const session
0181   inline void run(const Session* session,
0182                   const std::vector<std::string>& outputNames,
0183                   std::vector<Tensor>* outputs,
0184                   const std::string& threadPoolName = "no_threads") {
0185     // TF takes a non-const session in the run call which is, however, thread-safe and logically
0186     // const, thus const_cast is consistent
0187     run(const_cast<Session*>(session), outputNames, outputs, threadPoolName);
0188   }
0189 
0190   // struct that can be used in edm::stream modules for caching a graph and a session instance,
0191   // both made atomic for cases where access is required from multiple threads
0192   struct SessionCache {
0193     std::atomic<GraphDef*> graph;
0194     std::atomic<Session*> session;
0195 
0196     // constructor
0197     SessionCache() {}
0198 
0199     // initializing constructor, forwarding all arguments to createSession
0200     template <typename... Args>
0201     SessionCache(const std::string& graphPath, Args&&... sessionArgs) {
0202       createSession(graphPath, std::forward<Args>(sessionArgs)...);
0203     }
0204 
0205     // destructor
0206     ~SessionCache() { closeSession(); }
0207 
0208     // create the internal graph representation from graphPath and the session object, forwarding
0209     // all additional arguments to the central tensorflow::createSession
0210     template <typename... Args>
0211     void createSession(const std::string& graphPath, Args&&... sessionArgs) {
0212       graph.store(loadGraphDef(graphPath));
0213       session.store(tensorflow::createSession(graph.load(), std::forward<Args>(sessionArgs)...));
0214     }
0215 
0216     // return a pointer to the const session
0217     inline const Session* getSession() const { return session.load(); }
0218 
0219     // closes and removes the session as well as the graph, and sets the atomic members to nullptr's
0220     void closeSession();
0221   };
0222 
0223 }  // namespace tensorflow
0224 
0225 #endif  // PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H