File indexing completed on 2022-12-13 23:50:21
0001
0002
0003
0004
0005
0006
0007
0008 #include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
0009
0010 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0011
0012 namespace tensorflow {
0013
// Exports `level` as the TF_CPP_MIN_LOG_LEVEL environment variable.
// The overwrite flag handed to setenv() is 0, so a value that is already
// present in the environment is left untouched.
void setLogging(const std::string& level) {
  const int overwrite = 0;
  setenv("TF_CPP_MIN_LOG_LEVEL", level.c_str(), overwrite);
}
0015
0016 void setThreading(SessionOptions& sessionOptions, int nThreads) {
0017
0018 sessionOptions.config.set_intra_op_parallelism_threads(nThreads);
0019 sessionOptions.config.set_inter_op_parallelism_threads(nThreads);
0020 }
0021
0022 void setThreading(SessionOptions& sessionOptions, int nThreads, const std::string& singleThreadPool) {
0023 edm::LogInfo("PhysicsTools/TensorFlow") << "setting the thread pool via tensorflow::setThreading() is deprecated";
0024
0025 setThreading(sessionOptions, nThreads);
0026 }
0027
0028 MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions) {
0029
0030 Status status;
0031 RunOptions runOptions;
0032 SavedModelBundle bundle;
0033
0034
0035 status = LoadSavedModel(sessionOptions, runOptions, exportDir, {tag}, &bundle);
0036 if (!status.ok()) {
0037 throw cms::Exception("InvalidMetaGraphDef")
0038 << "error while loading metaGraphDef from '" << exportDir << "': " << status.ToString();
0039 }
0040
0041
0042 return new MetaGraphDef(bundle.meta_graph_def);
0043 }
0044
0045 MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions) {
0046 edm::LogInfo("PhysicsTools/TensorFlow")
0047 << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
0048
0049 return loadMetaGraphDef(exportDir, tag, sessionOptions);
0050 }
0051
0052 MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, int nThreads) {
0053
0054 SessionOptions sessionOptions;
0055 setThreading(sessionOptions, nThreads);
0056
0057 return loadMetaGraphDef(exportDir, tag, sessionOptions);
0058 }
0059
0060 MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, int nThreads) {
0061 edm::LogInfo("PhysicsTools/TensorFlow")
0062 << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
0063
0064 return loadMetaGraphDef(exportDir, tag, nThreads);
0065 }
0066
0067 GraphDef* loadGraphDef(const std::string& pbFile) {
0068
0069 Status status;
0070
0071
0072 GraphDef* graphDef = new GraphDef();
0073 status = ReadBinaryProto(Env::Default(), pbFile, graphDef);
0074
0075
0076 if (!status.ok()) {
0077 throw cms::Exception("InvalidGraphDef")
0078 << "error while loading graphDef from '" << pbFile << "': " << status.ToString();
0079 }
0080
0081 return graphDef;
0082 }
0083
0084 Session* createSession(SessionOptions& sessionOptions) {
0085
0086 Status status;
0087
0088
0089 Session* session = nullptr;
0090 status = NewSession(sessionOptions, &session);
0091 if (!status.ok()) {
0092 throw cms::Exception("InvalidSession") << "error while creating session: " << status.ToString();
0093 }
0094
0095 return session;
0096 }
0097
0098 Session* createSession(int nThreads) {
0099
0100 SessionOptions sessionOptions;
0101 setThreading(sessionOptions, nThreads);
0102
0103 return createSession(sessionOptions);
0104 }
0105
0106 Session* createSession(const MetaGraphDef* metaGraphDef,
0107 const std::string& exportDir,
0108 SessionOptions& sessionOptions) {
0109
0110 if (metaGraphDef == nullptr) {
0111 throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: metaGraphDef is nullptr";
0112 }
0113
0114
0115 if (metaGraphDef->graph_def().node_size() <= 0) {
0116 throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: graphDef has no nodes";
0117 }
0118
0119 Session* session = createSession(sessionOptions);
0120
0121
0122 Status status;
0123 status = session->Create(metaGraphDef->graph_def());
0124 if (!status.ok()) {
0125 throw cms::Exception("InvalidMetaGraphDef")
0126 << "error while attaching metaGraphDef to session: " << status.ToString();
0127 }
0128
0129
0130
0131 std::string varFileTensorName = metaGraphDef->saver_def().filename_tensor_name();
0132 std::string restoreOpName = metaGraphDef->saver_def().restore_op_name();
0133 std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
0134 std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
0135 std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
0136
0137
0138 if (!Env::Default()->FileExists(indexFile).ok()) {
0139 return session;
0140 }
0141
0142
0143 Tensor varFileTensor(DT_STRING, TensorShape({}));
0144 varFileTensor.scalar<tensorflow::tstring>()() = varFile;
0145
0146
0147 status = session->Run({{varFileTensorName, varFileTensor}}, {}, {restoreOpName}, nullptr);
0148 if (!status.ok()) {
0149 throw cms::Exception("InvalidSession") << "error while restoring variables in session: " << status.ToString();
0150 }
0151
0152 return session;
0153 }
0154
0155 Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, int nThreads) {
0156
0157 SessionOptions sessionOptions;
0158 setThreading(sessionOptions, nThreads);
0159
0160 return createSession(metaGraphDef, exportDir, sessionOptions);
0161 }
0162
0163 Session* createSession(const GraphDef* graphDef, SessionOptions& sessionOptions) {
0164
0165 if (graphDef == nullptr) {
0166 throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef is nullptr";
0167 }
0168
0169
0170 if (graphDef->node_size() <= 0) {
0171 throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef has no nodes";
0172 }
0173
0174
0175 Session* session = createSession(sessionOptions);
0176
0177
0178 Status status;
0179 status = session->Create(*graphDef);
0180
0181
0182 if (!status.ok()) {
0183 throw cms::Exception("InvalidSession") << "error while attaching graphDef to session: " << status.ToString();
0184 }
0185
0186 return session;
0187 }
0188
0189 Session* createSession(const GraphDef* graphDef, int nThreads) {
0190
0191 SessionOptions sessionOptions;
0192 setThreading(sessionOptions, nThreads);
0193
0194 return createSession(graphDef, sessionOptions);
0195 }
0196
0197 bool closeSession(Session*& session) {
0198 if (session == nullptr) {
0199 return true;
0200 }
0201
0202
0203 Status status = session->Close();
0204 delete session;
0205
0206
0207 session = nullptr;
0208
0209 return status.ok();
0210 }
0211
0212 bool closeSession(const Session*& session) {
0213 auto s = const_cast<Session*>(session);
0214 bool state = closeSession(s);
0215
0216
0217 session = nullptr;
0218
0219 return state;
0220 }
0221
0222 void run(Session* session,
0223 const NamedTensorList& inputs,
0224 const std::vector<std::string>& outputNames,
0225 std::vector<Tensor>* outputs,
0226 const thread::ThreadPoolOptions& threadPoolOptions) {
0227 if (session == nullptr) {
0228 throw cms::Exception("InvalidSession") << "cannot run empty session";
0229 }
0230
0231
0232 RunOptions runOptions;
0233
0234
0235 Status status = session->Run(runOptions, inputs, outputNames, {}, outputs, nullptr, threadPoolOptions);
0236 if (!status.ok()) {
0237 throw cms::Exception("InvalidRun") << "error while running session: " << status.ToString();
0238 }
0239 }
0240
0241 void run(Session* session,
0242 const NamedTensorList& inputs,
0243 const std::vector<std::string>& outputNames,
0244 std::vector<Tensor>* outputs,
0245 thread::ThreadPoolInterface* threadPool) {
0246
0247 thread::ThreadPoolOptions threadPoolOptions;
0248 threadPoolOptions.inter_op_threadpool = threadPool;
0249 threadPoolOptions.intra_op_threadpool = threadPool;
0250
0251
0252 run(session, inputs, outputNames, outputs, threadPoolOptions);
0253 }
0254
0255 void run(Session* session,
0256 const NamedTensorList& inputs,
0257 const std::vector<std::string>& outputNames,
0258 std::vector<Tensor>* outputs,
0259 const std::string& threadPoolName) {
0260
0261 if (threadPoolName == "no_threads") {
0262 run(session, inputs, outputNames, outputs, &NoThreadPool::instance());
0263 } else if (threadPoolName == "tbb") {
0264
0265 run(session, inputs, outputNames, outputs, &TBBThreadPool::instance());
0266 } else if (threadPoolName == "tensorflow") {
0267 run(session, inputs, outputNames, outputs, nullptr);
0268 } else {
0269 throw cms::Exception("UnknownThreadPool")
0270 << "thread pool implementation'" << threadPoolName << "' unknown, use 'no_threads', 'tbb', or 'tensorflow'";
0271 }
0272 }
0273
0274 void run(Session* session,
0275 const std::vector<std::string>& outputNames,
0276 std::vector<Tensor>* outputs,
0277 const std::string& threadPoolName) {
0278 run(session, {}, outputNames, outputs, threadPoolName);
0279 }
0280
0281 void SessionCache::closeSession() {
0282
0283 Session* s = session.load();
0284 if (s != nullptr) {
0285 tensorflow::closeSession(s);
0286 session.store(nullptr);
0287 }
0288
0289
0290 if (graph.load() != nullptr) {
0291 delete graph.load();
0292 graph.store(nullptr);
0293 }
0294 }
0295
0296 }