/*
 * TensorFlow interface helpers.
 * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
 *
 * Author: Marcel Rieger
 */

#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
#include "FWCore/MessageLogger/interface/MessageLogger.h"
#include "FWCore/ServiceRegistry/interface/Service.h"
#include "FWCore/Utilities/interface/ResourceInformation.h"

namespace tensorflow {

  void Options::setThreading(int nThreads) {
    _nThreads = nThreads;
    // set the number of threads used for intra- and inter-operation parallelism
    _options.config.set_intra_op_parallelism_threads(nThreads);
    _options.config.set_inter_op_parallelism_threads(nThreads);
  }

  void Options::setBackend(Backend backend) {
    /*
     * The TensorFlow backend configures the available devices using the options provided in the sessionOptions proto.
     * The options are documented at
     * https://github.com/tensorflow/tensorflow/blob/c53dab9fbc9de4ea8b1df59041a5ffd3987328c3/tensorflow/core/protobuf/config.proto
     *
     * If device_count["GPU"] is set to 0, GPUs are not used.
     * The visible_device_list configuration maps the "visible" devices (from CUDA_VISIBLE_DEVICES) to "virtual" devices.
     * If Backend::cpu is requested, the GPU devices are disabled through the device_count configuration.
     * If Backend::cuda is requested:
     *  - if ResourceInformation reports an available NVIDIA GPU device:
     *     the device is used with the memory_growth configuration (not allocating all CUDA memory at once);
     *  - if no device is present: an exception is raised.
     */

    edm::Service<edm::ResourceInformation> ri;
    if (backend == Backend::cpu) {
      // disable GPU usage
      (*_options.config.mutable_device_count())["GPU"] = 0;
      _options.config.mutable_gpu_options()->set_visible_device_list("");
    }
    // NVIDIA GPU
    else if (backend == Backend::cuda) {
      if (not ri->nvidiaDriverVersion().empty()) {
        // take only the first GPU in the CUDA_VISIBLE_DEVICES list
        (*_options.config.mutable_device_count())["GPU"] = 1;
        _options.config.mutable_gpu_options()->set_visible_device_list("0");
        // do not allocate all the memory on the GPU at the beginning
        _options.config.mutable_gpu_options()->set_allow_growth(true);
      } else {
        edm::Exception ex(edm::errors::UnavailableAccelerator);
        ex << "CUDA backend requested, but no NVIDIA GPU available in the job";
        ex.addContext("Calling tensorflow::setBackend()");
        throw ex;
      }
    }
    // ROCm and Intel GPUs are not yet supported
    else if ((backend == Backend::rocm) || (backend == Backend::intel)) {
      edm::Exception ex(edm::errors::UnavailableAccelerator);
      ex << "ROCm/Intel GPU backend requested, but TensorFlow is not yet compiled for this platform";
      ex.addContext("Calling tensorflow::setBackend()");
      throw ex;
    }
    // use an NVIDIA GPU if available, otherwise fall back to the CPU
    else if (backend == Backend::best) {
      // check if an NVIDIA GPU is available
      if (not ri->nvidiaDriverVersion().empty()) {
        // take only the first GPU in the CUDA_VISIBLE_DEVICES list
        (*_options.config.mutable_device_count())["GPU"] = 1;
        _options.config.mutable_gpu_options()->set_visible_device_list("0");
        // do not allocate all the memory on the GPU at the beginning
        _options.config.mutable_gpu_options()->set_allow_growth(true);
      } else {
        // CPU support only
        (*_options.config.mutable_device_count())["GPU"] = 0;
        _options.config.mutable_gpu_options()->set_visible_device_list("");
      }
    }
  }

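  // Typical configuration, as a sketch; it only uses the members defined
  // above, with the Backend enum values taken from the interface header:
  //
  //   Options opts{};
  //   opts.setBackend(Backend::best);  // NVIDIA GPU if available, else CPU
  //   opts.setThreading(1);            // one intra-op and one inter-op thread
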
  void setLogging(const std::string& level) {
    /*
     * 0 = all messages are logged (default behavior)
     * 1 = INFO messages are not printed
     * 2 = INFO and WARNING messages are not printed
     * 3 = INFO, WARNING, and ERROR messages are not printed
     */
    setenv("TF_CPP_MIN_LOG_LEVEL", level.c_str(), 0);
  }

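  // Usage sketch: suppress INFO and WARNING messages of the TensorFlow C++
  // backend. Since setenv() is called with overwrite = 0 above, a
  // TF_CPP_MIN_LOG_LEVEL value already present in the environment wins:
  //
  //   setLogging("2");
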
  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag) {
    Options default_options{};
    return loadMetaGraphDef(exportDir, tag, default_options);
  }

  MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, Options& options) {
    // objects to load the graph
    Status status;
    RunOptions runOptions;
    SavedModelBundle bundle;

    // load the model
    status = LoadSavedModel(options.getSessionOptions(), runOptions, exportDir, {tag}, &bundle);
    if (!status.ok()) {
      throw cms::Exception("InvalidMetaGraphDef")
          << "error while loading metaGraphDef from '" << exportDir << "': " << status.ToString();
    }

    // return a copy of the graph
    return new MetaGraphDef(bundle.meta_graph_def);
  }

  MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, Options& options) {
    edm::LogInfo("PhysicsTools/TensorFlow")
        << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";

    return loadMetaGraphDef(exportDir, tag, options);
  }

  GraphDef* loadGraphDef(const std::string& pbFile) {
    // objects to load the graph
    Status status;

    // load it
    GraphDef* graphDef = new GraphDef();
    status = ReadBinaryProto(Env::Default(), pbFile, graphDef);

    // check for success
    if (!status.ok()) {
      throw cms::Exception("InvalidGraphDef")
          << "error while loading graphDef from '" << pbFile << "': " << status.ToString();
    }

    return graphDef;
  }

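  // Usage sketch for the constant-graph workflow; the file path is a
  // placeholder. The caller owns both the returned GraphDef and the session:
  //
  //   GraphDef* graphDef = loadGraphDef("/path/to/graph.pb");
  //   Session* session = createSession(graphDef);
  //   // ... run inference ...
  //   closeSession(session);
  //   delete graphDef;
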
  Session* createSession() {
    Options default_options{};
    return createSession(default_options);
  }

  Session* createSession(Options& options) {
    // objects to create the session
    Status status;

    // create a new, empty session
    Session* session = nullptr;
    status = NewSession(options.getSessionOptions(), &session);
    if (!status.ok()) {
      throw cms::Exception("InvalidSession") << "error while creating session: " << status.ToString();
    }

    return session;
  }

  Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, Options& options) {
    // check for valid pointer
    if (metaGraphDef == nullptr) {
      throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: metaGraphDef is nullptr";
    }

    // check that the graph has nodes
    if (metaGraphDef->graph_def().node_size() <= 0) {
      throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: graphDef has no nodes";
    }

    Session* session = createSession(options);

    // add the graph def from the meta graph
    Status status;
    status = session->Create(metaGraphDef->graph_def());
    if (!status.ok()) {
      throw cms::Exception("InvalidMetaGraphDef")
          << "error while attaching metaGraphDef to session: " << status.ToString();
    }

    // restore variables using the variable and index files in the export directory
    // first, find names and paths
    std::string varFileTensorName = metaGraphDef->saver_def().filename_tensor_name();
    std::string restoreOpName = metaGraphDef->saver_def().restore_op_name();
    std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
    std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
    std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);

    // when the index file is missing, there's nothing to do
    if (!Env::Default()->FileExists(indexFile).ok()) {
      return session;
    }

    // create a tensor to store the variable file
    Tensor varFileTensor(DT_STRING, TensorShape({}));
    varFileTensor.scalar<tensorflow::tstring>()() = varFile;

    // run the restore op
    status = session->Run({{varFileTensorName, varFileTensor}}, {}, {restoreOpName}, nullptr);
    if (!status.ok()) {
      throw cms::Exception("InvalidSession") << "error while restoring variables in session: " << status.ToString();
    }

    return session;
  }

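  // Usage sketch for the SavedModel workflow; the export directory is a
  // placeholder and "serve" is assumed to be the tag the model was exported
  // with (the conventional SavedModel serving tag):
  //
  //   MetaGraphDef* metaGraph = loadMetaGraphDef("/path/to/saved_model", "serve");
  //   Options opts{};
  //   Session* session = createSession(metaGraph, "/path/to/saved_model", opts);
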
  Session* createSession(const GraphDef* graphDef) {
    Options default_options{};
    return createSession(graphDef, default_options);
  }

  Session* createSession(const GraphDef* graphDef, Options& options) {
    // check for valid pointer
    if (graphDef == nullptr) {
      throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef is nullptr";
    }

    // check that the graph has nodes
    if (graphDef->node_size() <= 0) {
      throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef has no nodes";
    }

    // create a new, empty session
    Session* session = createSession(options);

    // add the graph def
    Status status;
    status = session->Create(*graphDef);

    // check for success
    if (!status.ok()) {
      throw cms::Exception("InvalidSession") << "error while attaching graphDef to session: " << status.ToString();
    }

    return session;
  }

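  // Sketch: create a session that prefers an NVIDIA GPU and falls back to the
  // CPU, with graphDef obtained from loadGraphDef() as shown above:
  //
  //   Options opts{};
  //   opts.setBackend(Backend::best);
  //   Session* session = createSession(graphDef, opts);
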
  bool closeSession(Session*& session) {
    if (session == nullptr) {
      return true;
    }

    // close and delete the session
    Status status = session->Close();
    delete session;

    // reset the pointer
    session = nullptr;

    return status.ok();
  }

  bool closeSession(const Session*& session) {
    auto s = const_cast<Session*>(session);
    bool state = closeSession(s);

    // reset the pointer
    session = nullptr;

    return state;
  }

  void run(Session* session,
           const NamedTensorList& inputs,
           const std::vector<std::string>& outputNames,
           std::vector<Tensor>* outputs,
           const thread::ThreadPoolOptions& threadPoolOptions) {
    if (session == nullptr) {
      throw cms::Exception("InvalidSession") << "cannot run empty session";
    }

    // create empty run options
    RunOptions runOptions;

    // run and check the status
    Status status = session->Run(runOptions, inputs, outputNames, {}, outputs, nullptr, threadPoolOptions);
    if (!status.ok()) {
      throw cms::Exception("InvalidRun") << "error while running session: " << status.ToString();
    }
  }

  void run(Session* session,
           const NamedTensorList& inputs,
           const std::vector<std::string>& outputNames,
           std::vector<Tensor>* outputs,
           thread::ThreadPoolInterface* threadPool) {
    // create thread pool options
    thread::ThreadPoolOptions threadPoolOptions;
    threadPoolOptions.inter_op_threadpool = threadPool;
    threadPoolOptions.intra_op_threadpool = threadPool;

    // run
    run(session, inputs, outputNames, outputs, threadPoolOptions);
  }

  void run(Session* session,
           const NamedTensorList& inputs,
           const std::vector<std::string>& outputNames,
           std::vector<Tensor>* outputs,
           const std::string& threadPoolName) {
    // look up the thread pool and forward the call accordingly
    if (threadPoolName == "no_threads") {
      run(session, inputs, outputNames, outputs, &NoThreadPool::instance());
    } else if (threadPoolName == "tbb") {
      // the TBBThreadPool singleton should already have been initialized with a number of threads
      run(session, inputs, outputNames, outputs, &TBBThreadPool::instance());
    } else if (threadPoolName == "tensorflow") {
      run(session, inputs, outputNames, outputs, nullptr);
    } else {
      throw cms::Exception("UnknownThreadPool")
          << "thread pool implementation '" << threadPoolName << "' unknown, use 'no_threads', 'tbb', or 'tensorflow'";
    }
  }

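  // End-to-end inference sketch; the tensor shape and the "input"/"output"
  // node names are placeholders that depend on the graph being evaluated:
  //
  //   Tensor input(DT_FLOAT, TensorShape({1, 10}));
  //   input.flat<float>().setZero();
  //   std::vector<Tensor> outputs;
  //   run(session, {{"input", input}}, {"output"}, &outputs, "no_threads");
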
  void run(Session* session,
           const std::vector<std::string>& outputNames,
           std::vector<Tensor>* outputs,
           const std::string& threadPoolName) {
    run(session, {}, outputNames, outputs, threadPoolName);
  }

  void SessionCache::closeSession() {
    // delete the session if set
    Session* s = session.load();
    if (s != nullptr) {
      tensorflow::closeSession(s);
      session.store(nullptr);
    }

    // delete the graph if set
    if (graph.load() != nullptr) {
      delete graph.load();
      graph.store(nullptr);
    }
  }

}  // namespace tensorflow