File indexing completed on 2022-12-13 23:50:20
0001
0002
0003
0004
0005
0006
0007
0008 #ifndef PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
0009 #define PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
0010
#include <atomic>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/tensor_bundle/naming.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/tag_constants.h"

#include "PhysicsTools/TensorFlow/interface/NoThreadPool.h"
#include "PhysicsTools/TensorFlow/interface/TBBThreadPool.h"

#include "FWCore/Utilities/interface/Exception.h"
0025
0026 namespace tensorflow {
0027
0028 typedef std::pair<std::string, Tensor> NamedTensor;
0029 typedef std::vector<NamedTensor> NamedTensorList;
0030
0031
0032 void setLogging(const std::string& level = "3");
0033
0034
0035 void setThreading(SessionOptions& sessionOptions, int nThreads = 1);
0036
0037
0038
0039
0040 void setThreading(SessionOptions& sessionOptions, int nThreads, const std::string& singleThreadPool);
0041
0042
0043
0044
0045 MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions);
0046
0047
0048 MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions);
0049
0050
0051
0052
0053 MetaGraphDef* loadMetaGraphDef(const std::string& exportDir,
0054 const std::string& tag = kSavedModelTagServe,
0055 int nThreads = 1);
0056
0057
0058 MetaGraphDef* loadMetaGraph(const std::string& exportDir,
0059 const std::string& tag = kSavedModelTagServe,
0060 int nThreads = 1);
0061
0062
0063
0064 GraphDef* loadGraphDef(const std::string& pbFile);
0065
0066
0067
0068 Session* createSession(SessionOptions& sessionOptions);
0069
0070
0071
0072 Session* createSession(int nThreads = 1);
0073
0074
0075
0076
0077
0078 Session* createSession(const MetaGraphDef* metaGraphDef,
0079 const std::string& exportDir,
0080 SessionOptions& sessionOptions);
0081
0082
0083
0084
0085
0086 Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, int nThreads = 1);
0087
0088
0089
0090
0091 Session* createSession(const GraphDef* graphDef, SessionOptions& sessionOptions);
0092
0093
0094
0095
0096
0097 Session* createSession(const GraphDef* graphDef, int nThreads = 1);
0098
0099
0100 bool closeSession(Session*& session);
0101
0102
0103 bool closeSession(const Session*& session);
0104
0105
0106
0107
0108
0109 void run(Session* session,
0110 const NamedTensorList& inputs,
0111 const std::vector<std::string>& outputNames,
0112 std::vector<Tensor>* outputs,
0113 const thread::ThreadPoolOptions& threadPoolOptions);
0114
0115
0116 inline void run(const Session* session,
0117 const NamedTensorList& inputs,
0118 const std::vector<std::string>& outputNames,
0119 std::vector<Tensor>* outputs,
0120 const thread::ThreadPoolOptions& threadPoolOptions) {
0121
0122
0123 run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolOptions);
0124 }
0125
0126
0127
0128
0129 void run(Session* session,
0130 const NamedTensorList& inputs,
0131 const std::vector<std::string>& outputNames,
0132 std::vector<Tensor>* outputs,
0133 thread::ThreadPoolInterface* threadPool);
0134
0135
0136 inline void run(const Session* session,
0137 const NamedTensorList& inputs,
0138 const std::vector<std::string>& outputNames,
0139 std::vector<Tensor>* outputs,
0140 thread::ThreadPoolInterface* threadPool) {
0141
0142
0143 run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPool);
0144 }
0145
0146
0147
0148
0149 void run(Session* session,
0150 const NamedTensorList& inputs,
0151 const std::vector<std::string>& outputNames,
0152 std::vector<Tensor>* outputs,
0153 const std::string& threadPoolName = "no_threads");
0154
0155
0156 inline void run(const Session* session,
0157 const NamedTensorList& inputs,
0158 const std::vector<std::string>& outputNames,
0159 std::vector<Tensor>* outputs,
0160 const std::string& threadPoolName = "no_threads") {
0161
0162
0163 run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolName);
0164 }
0165
0166
0167
0168
0169 void run(Session* session,
0170 const std::vector<std::string>& outputNames,
0171 std::vector<Tensor>* outputs,
0172 const std::string& threadPoolName = "no_threads");
0173
0174
0175 inline void run(const Session* session,
0176 const std::vector<std::string>& outputNames,
0177 std::vector<Tensor>* outputs,
0178 const std::string& threadPoolName = "no_threads") {
0179
0180
0181 run(const_cast<Session*>(session), outputNames, outputs, threadPoolName);
0182 }
0183
0184
0185
0186 struct SessionCache {
0187 std::atomic<GraphDef*> graph;
0188 std::atomic<Session*> session;
0189
0190
0191 SessionCache() {}
0192
0193
0194 template <typename... Args>
0195 SessionCache(const std::string& graphPath, Args&&... sessionArgs) {
0196 createSession(graphPath, std::forward<Args>(sessionArgs)...);
0197 }
0198
0199
0200 ~SessionCache() { closeSession(); }
0201
0202
0203
0204 template <typename... Args>
0205 void createSession(const std::string& graphPath, Args&&... sessionArgs) {
0206 graph.store(loadGraphDef(graphPath));
0207 session.store(tensorflow::createSession(graph.load(), std::forward<Args>(sessionArgs)...));
0208 }
0209
0210
0211 inline const Session* getSession() const { return session.load(); }
0212
0213
0214 void closeSession();
0215 };
0216
0217 }
0218
0219 #endif