File indexing completed on 2021-02-14 13:33:24
0001
0002
0003
0004
0005
0006
0007
0008
#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

#include <memory>
#include <utility>

#include "FWCore/MessageLogger/interface/MessageLogger.h"
0012
0013 namespace tensorflow {
0014
/// Sets TensorFlow's C++ log level through the TF_CPP_MIN_LOG_LEVEL environment variable.
/// The setenv overwrite flag is 0, so a value that is already present in the environment
/// is kept and takes precedence over the one passed here.
/// @param level value to store in TF_CPP_MIN_LOG_LEVEL
void setLogging(const std::string& level) {
  const char* const envName = "TF_CPP_MIN_LOG_LEVEL";
  const int keepExisting = 0;  // setenv overwrite flag: 0 leaves an existing value untouched
  setenv(envName, level.c_str(), keepExisting);
}
0016
0017 void setThreading(SessionOptions& sessionOptions, int nThreads) {
0018
0019 sessionOptions.config.set_intra_op_parallelism_threads(nThreads);
0020 sessionOptions.config.set_inter_op_parallelism_threads(nThreads);
0021 }
0022
0023 void setThreading(SessionOptions& sessionOptions, int nThreads, const std::string& singleThreadPool) {
0024 edm::LogInfo("PhysicsTools/TensorFlow") << "setting the thread pool via tensorflow::setThreading() is deprecated";
0025
0026 setThreading(sessionOptions, nThreads);
0027 }
0028
0029 MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions) {
0030
0031 Status status;
0032 RunOptions runOptions;
0033 SavedModelBundle bundle;
0034
0035
0036 status = LoadSavedModel(sessionOptions, runOptions, exportDir, {tag}, &bundle);
0037 if (!status.ok()) {
0038 throw cms::Exception("InvalidMetaGraphDef")
0039 << "error while loading metaGraphDef from '" << exportDir << "': " << status.ToString();
0040 }
0041
0042
0043 return new MetaGraphDef(bundle.meta_graph_def);
0044 }
0045
0046 MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions) {
0047 edm::LogInfo("PhysicsTools/TensorFlow")
0048 << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
0049
0050 return loadMetaGraphDef(exportDir, tag, sessionOptions);
0051 }
0052
0053 MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, int nThreads) {
0054
0055 SessionOptions sessionOptions;
0056 setThreading(sessionOptions, nThreads);
0057
0058 return loadMetaGraphDef(exportDir, tag, sessionOptions);
0059 }
0060
0061 MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, int nThreads) {
0062 edm::LogInfo("PhysicsTools/TensorFlow")
0063 << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
0064
0065 return loadMetaGraphDef(exportDir, tag, nThreads);
0066 }
0067
0068 GraphDef* loadGraphDef(const std::string& pbFile) {
0069
0070 Status status;
0071
0072
0073 GraphDef* graphDef = new GraphDef();
0074 status = ReadBinaryProto(Env::Default(), pbFile, graphDef);
0075
0076
0077 if (!status.ok()) {
0078 throw cms::Exception("InvalidGraphDef")
0079 << "error while loading graphDef from '" << pbFile << "': " << status.ToString();
0080 }
0081
0082 return graphDef;
0083 }
0084
0085 Session* createSession(SessionOptions& sessionOptions) {
0086
0087 Status status;
0088
0089
0090 Session* session = nullptr;
0091 status = NewSession(sessionOptions, &session);
0092 if (!status.ok()) {
0093 throw cms::Exception("InvalidSession") << "error while creating session: " << status.ToString();
0094 }
0095
0096 return session;
0097 }
0098
0099 Session* createSession(int nThreads) {
0100
0101 SessionOptions sessionOptions;
0102 setThreading(sessionOptions, nThreads);
0103
0104 return createSession(sessionOptions);
0105 }
0106
0107 Session* createSession(const MetaGraphDef* metaGraphDef,
0108 const std::string& exportDir,
0109 SessionOptions& sessionOptions) {
0110
0111 if (metaGraphDef == nullptr) {
0112 throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: metaGraphDef is nullptr";
0113 }
0114
0115
0116 if (metaGraphDef->graph_def().node_size() <= 0) {
0117 throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: graphDef has no nodes";
0118 }
0119
0120 Session* session = createSession(sessionOptions);
0121
0122
0123 Status status;
0124 status = session->Create(metaGraphDef->graph_def());
0125 if (!status.ok()) {
0126 throw cms::Exception("InvalidMetaGraphDef")
0127 << "error while attaching metaGraphDef to session: " << status.ToString();
0128 }
0129
0130
0131
0132 std::string varFileTensorName = metaGraphDef->saver_def().filename_tensor_name();
0133 std::string restoreOpName = metaGraphDef->saver_def().restore_op_name();
0134 std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
0135 std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
0136 std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
0137
0138
0139 if (!Env::Default()->FileExists(indexFile).ok()) {
0140 return session;
0141 }
0142
0143
0144 Tensor varFileTensor(DT_STRING, TensorShape({}));
0145 varFileTensor.scalar<tensorflow::tstring>()() = varFile;
0146
0147
0148 status = session->Run({{varFileTensorName, varFileTensor}}, {}, {restoreOpName}, nullptr);
0149 if (!status.ok()) {
0150 throw cms::Exception("InvalidSession") << "error while restoring variables in session: " << status.ToString();
0151 }
0152
0153 return session;
0154 }
0155
0156 Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, int nThreads) {
0157
0158 SessionOptions sessionOptions;
0159 setThreading(sessionOptions, nThreads);
0160
0161 return createSession(metaGraphDef, exportDir, sessionOptions);
0162 }
0163
0164 Session* createSession(const GraphDef* graphDef, SessionOptions& sessionOptions) {
0165
0166 if (graphDef == nullptr) {
0167 throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef is nullptr";
0168 }
0169
0170
0171 if (graphDef->node_size() <= 0) {
0172 throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef has no nodes";
0173 }
0174
0175
0176 Session* session = createSession(sessionOptions);
0177
0178
0179 Status status;
0180 status = session->Create(*graphDef);
0181
0182
0183 if (!status.ok()) {
0184 throw cms::Exception("InvalidSession") << "error while attaching graphDef to session: " << status.ToString();
0185 }
0186
0187 return session;
0188 }
0189
0190 Session* createSession(const GraphDef* graphDef, int nThreads) {
0191
0192 SessionOptions sessionOptions;
0193 setThreading(sessionOptions, nThreads);
0194
0195 return createSession(graphDef, sessionOptions);
0196 }
0197
0198 bool closeSession(Session*& session) {
0199 if (session == nullptr) {
0200 return true;
0201 }
0202
0203
0204 Status status = session->Close();
0205 delete session;
0206
0207
0208 session = nullptr;
0209
0210 return status.ok();
0211 }
0212
0213 void run(Session* session,
0214 const NamedTensorList& inputs,
0215 const std::vector<std::string>& outputNames,
0216 std::vector<Tensor>* outputs,
0217 const thread::ThreadPoolOptions& threadPoolOptions) {
0218 if (session == nullptr) {
0219 throw cms::Exception("InvalidSession") << "cannot run empty session";
0220 }
0221
0222
0223 RunOptions runOptions;
0224
0225
0226 Status status = session->Run(runOptions, inputs, outputNames, {}, outputs, nullptr, threadPoolOptions);
0227 if (!status.ok()) {
0228 throw cms::Exception("InvalidRun") << "error while running session: " << status.ToString();
0229 }
0230 }
0231
0232 void run(Session* session,
0233 const NamedTensorList& inputs,
0234 const std::vector<std::string>& outputNames,
0235 std::vector<Tensor>* outputs,
0236 thread::ThreadPoolInterface* threadPool) {
0237
0238 thread::ThreadPoolOptions threadPoolOptions;
0239 threadPoolOptions.inter_op_threadpool = threadPool;
0240 threadPoolOptions.intra_op_threadpool = threadPool;
0241
0242
0243 run(session, inputs, outputNames, outputs, threadPoolOptions);
0244 }
0245
0246 void run(Session* session,
0247 const NamedTensorList& inputs,
0248 const std::vector<std::string>& outputNames,
0249 std::vector<Tensor>* outputs,
0250 const std::string& threadPoolName) {
0251
0252 if (threadPoolName == "no_threads") {
0253 run(session, inputs, outputNames, outputs, &NoThreadPool::instance());
0254 } else if (threadPoolName == "tbb") {
0255
0256 run(session, inputs, outputNames, outputs, &TBBThreadPool::instance());
0257 } else if (threadPoolName == "tensorflow") {
0258 run(session, inputs, outputNames, outputs, nullptr);
0259 } else {
0260 throw cms::Exception("UnknownThreadPool")
0261 << "thread pool implementation'" << threadPoolName << "' unknown, use 'no_threads', 'tbb', or 'tensorflow'";
0262 }
0263 }
0264
0265 void run(Session* session,
0266 const std::vector<std::string>& outputNames,
0267 std::vector<Tensor>* outputs,
0268 const std::string& threadPoolName) {
0269 run(session, {}, outputNames, outputs, threadPoolName);
0270 }
0271
0272 }