//
// TensorFlow interface functions.
//
0008 #ifndef PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
0009 #define PHYSICSTOOLS_TENSORFLOW_TENSORFLOW_H
0010
0011 #include "tensorflow/core/framework/tensor.h"
0012 #include "tensorflow/core/lib/core/threadpool.h"
0013 #include "tensorflow/core/lib/io/path.h"
0014 #include "tensorflow/core/public/session.h"
0015 #include "tensorflow/core/util/tensor_bundle/naming.h"
0016 #include "tensorflow/cc/client/client_session.h"
0017 #include "tensorflow/cc/saved_model/loader.h"
0018 #include "tensorflow/cc/saved_model/constants.h"
0019 #include "tensorflow/cc/saved_model/tag_constants.h"
0020
0021 #include "PhysicsTools/TensorFlow/interface/NoThreadPool.h"
0022 #include "PhysicsTools/TensorFlow/interface/TBBThreadPool.h"
0023
0024 #include "FWCore/Utilities/interface/Exception.h"
0025
0026 namespace tensorflow {
0027
// Device backend a session can be configured for.
// NOTE(review): semantics inferred from the enumerator names; "best"
// presumably auto-selects the most suitable available backend — confirm
// against Options::setBackend in the implementation.
enum class Backend {
  cpu,
  cuda,
  rocm,
  intel,
  best,
};
0029
0030 typedef std::pair<std::string, Tensor> NamedTensor;
0031 typedef std::vector<NamedTensor> NamedTensorList;
0032
0033 struct Options {
0034 int _nThreads;
0035 Backend _backend;
0036 SessionOptions _options;
0037
0038 Options(Backend backend) : _nThreads{1}, _backend{backend} {
0039 setThreading(_nThreads);
0040 setBackend(_backend);
0041 };
0042
0043 Options() : _nThreads{1}, _backend{Backend::cpu} {
0044 setThreading(_nThreads);
0045 setBackend(_backend);
0046 };
0047
0048
0049 void setThreading(int nThreads = 1);
0050
0051
0052
0053 void setBackend(Backend backend = Backend::cpu);
0054
0055 SessionOptions& getSessionOptions() { return _options; };
0056 int getNThreads() const { return _nThreads; };
0057 Backend getBackend() const { return _backend; };
0058 };
0059
0060
0061 void setLogging(const std::string& level = "3");
0062
0063
0064
0065
0066 MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag = kSavedModelTagServe);
0067
0068
0069
0070
0071 MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, Options& options);
0072
0073
0074 MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, Options& Options);
0075
0076
0077
0078 GraphDef* loadGraphDef(const std::string& pbFile);
0079
0080
// Creates a session without an attached graph. The caller owns the returned
// pointer — presumably released via closeSession(); confirm in the
// implementation.
Session* createSession();

// Creates a session without an attached graph, configured with the given
// options.
Session* createSession(Options& options);

// Creates a session backed by the saved model at exportDir, described by a
// previously loaded MetaGraphDef, using the given options.
Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, Options& options);

// Creates a session initialized with the given graph definition.
Session* createSession(const GraphDef* graphDef);

// Creates a session initialized with the given graph definition, configured
// with the given options.
Session* createSession(const GraphDef* graphDef, Options& options);

// Closes a session and returns whether that succeeded.
// NOTE(review): the by-reference pointer parameter suggests the pointer is
// reset after closing — confirm in the implementation.
bool closeSession(Session*& session);

// Const-pointer variant of closeSession.
bool closeSession(const Session*& session);

// Checks the given inputs for empty tensors.
// NOTE(review): the exact return convention (true when empty vs. true when
// valid) is not visible in this header — confirm in the implementation.
bool checkEmptyInputs(const NamedTensorList& inputs);
0110
0111
0112
0113
0114
// Runs the session: feeds the named input tensors, evaluates the nodes named
// in outputNames, and stores the resulting tensors in *outputs, using the
// given thread pool options.
void run(Session* session,
         const NamedTensorList& inputs,
         const std::vector<std::string>& outputNames,
         std::vector<Tensor>* outputs,
         const thread::ThreadPoolOptions& threadPoolOptions);
0120
0121
0122 inline void run(const Session* session,
0123 const NamedTensorList& inputs,
0124 const std::vector<std::string>& outputNames,
0125 std::vector<Tensor>* outputs,
0126 const thread::ThreadPoolOptions& threadPoolOptions) {
0127
0128
0129 run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolOptions);
0130 }
0131
0132
0133
0134
// Runs the session as above, but with an explicit thread-pool implementation
// supplied by the caller.
void run(Session* session,
         const NamedTensorList& inputs,
         const std::vector<std::string>& outputNames,
         std::vector<Tensor>* outputs,
         thread::ThreadPoolInterface* threadPool);
0140
0141
0142 inline void run(const Session* session,
0143 const NamedTensorList& inputs,
0144 const std::vector<std::string>& outputNames,
0145 std::vector<Tensor>* outputs,
0146 thread::ThreadPoolInterface* threadPool) {
0147
0148
0149 run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPool);
0150 }
0151
0152
0153
0154
// Runs the session, selecting a thread pool by name ("no_threads" by
// default). NOTE(review): the set of accepted pool names is defined in the
// implementation — confirm there.
void run(Session* session,
         const NamedTensorList& inputs,
         const std::vector<std::string>& outputNames,
         std::vector<Tensor>* outputs,
         const std::string& threadPoolName = "no_threads");
0160
0161
0162 inline void run(const Session* session,
0163 const NamedTensorList& inputs,
0164 const std::vector<std::string>& outputNames,
0165 std::vector<Tensor>* outputs,
0166 const std::string& threadPoolName = "no_threads") {
0167
0168
0169 run(const_cast<Session*>(session), inputs, outputNames, outputs, threadPoolName);
0170 }
0171
0172
0173
0174
// Runs the session without feeding any inputs — presumably for graphs whose
// evaluated nodes require no fed tensors (confirm in the implementation).
void run(Session* session,
         const std::vector<std::string>& outputNames,
         std::vector<Tensor>* outputs,
         const std::string& threadPoolName = "no_threads");
0179
0180
0181 inline void run(const Session* session,
0182 const std::vector<std::string>& outputNames,
0183 std::vector<Tensor>* outputs,
0184 const std::string& threadPoolName = "no_threads") {
0185
0186
0187 run(const_cast<Session*>(session), outputNames, outputs, threadPoolName);
0188 }
0189
0190
0191
0192 struct SessionCache {
0193 std::atomic<GraphDef*> graph;
0194 std::atomic<Session*> session;
0195
0196
0197 SessionCache() {}
0198
0199
0200 template <typename... Args>
0201 SessionCache(const std::string& graphPath, Args&&... sessionArgs) {
0202 createSession(graphPath, std::forward<Args>(sessionArgs)...);
0203 }
0204
0205
0206 ~SessionCache() { closeSession(); }
0207
0208
0209
0210 template <typename... Args>
0211 void createSession(const std::string& graphPath, Args&&... sessionArgs) {
0212 graph.store(loadGraphDef(graphPath));
0213 session.store(tensorflow::createSession(graph.load(), std::forward<Args>(sessionArgs)...));
0214 }
0215
0216
0217 inline const Session* getSession() const { return session.load(); }
0218
0219
0220 void closeSession();
0221 };
0222
0223 }
0224
0225 #endif