Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2025-02-12 04:02:38

0001 /*
0002  * TensorFlow interface helpers.
0003  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
0004  *
0005  * Author: Marcel Rieger
0006  */
0007 
0008 #include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
0009 #include "FWCore/AbstractServices/interface/ResourceInformation.h"
0010 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0011 #include "FWCore/ServiceRegistry/interface/Service.h"
0012 
0013 namespace tensorflow {
0014 
0015   void Options::setThreading(int nThreads) {
0016     _nThreads = nThreads;
0017     // set number of threads used for intra and inter operation communication
0018     _options.config.set_intra_op_parallelism_threads(nThreads);
0019     _options.config.set_inter_op_parallelism_threads(nThreads);
0020   }
0021 
0022   void Options::setBackend(Backend backend) {
0023     /*
0024      * The TensorFlow backend configures the available devices using options provided in the sessionOptions proto.
0025      * // Options from https://github.com/tensorflow/tensorflow/blob/c53dab9fbc9de4ea8b1df59041a5ffd3987328c3/tensorflow/core/protobuf/config.proto
0026      *
0027      * If the device_count["GPU"] = 0 GPUs are not used. 
0028      * The visible_device_list configuration is used to map the `visible` devices (from CUDA_VISIBLE_DEVICES) to `virtual` devices.
0029      * If Backend::cpu is request, the GPU device is disallowed by device_count configuration.
0030      * If Backend::cuda is request:
0031      *  - if ResourceInformation shows an available Nvidia GPU device:
0032      *     the device is used with memory_growth configuration (not allocating all cuda memory at once).
0033      *  - if no device is present: an exception is raised.
0034      */
0035 
0036     edm::Service<edm::ResourceInformation> ri;
0037     if (backend == Backend::cpu) {
0038       // disable GPU usage
0039       (*_options.config.mutable_device_count())["GPU"] = 0;
0040       _options.config.mutable_gpu_options()->set_visible_device_list("");
0041     }
0042     // NVidia GPU
0043     else if (backend == Backend::cuda) {
0044       if (ri->hasGpuNvidia()) {
0045         // Check if one GPU device is visible to TF
0046         // If not, an exception is raised --> this can happen in case of driver version mismatch
0047         // or missing CUDA support in TF compilation
0048         if ((*_options.config.mutable_device_count())["GPU"] == 0) {
0049           edm::Exception ex(edm::errors::UnavailableAccelerator);
0050           ex << "Cuda backend requested, NVIDIA GPU visible to cmssw, but not visible to TensorFlow in the job";
0051           ex.addContext("Calling tensorflow::setBackend()");
0052           throw ex;
0053         }
0054         // Take only the first GPU in the CUDA_VISIBLE_DEVICE list
0055         (*_options.config.mutable_device_count())["GPU"] = 1;
0056         _options.config.mutable_gpu_options()->set_visible_device_list("0");
0057         // Do not allocate all the memory on the GPU at the beginning.
0058         _options.config.mutable_gpu_options()->set_allow_growth(true);
0059       } else {
0060         edm::Exception ex(edm::errors::UnavailableAccelerator);
0061         ex << "Cuda backend requested, but no NVIDIA GPU available in the job";
0062         ex.addContext("Calling tensorflow::setBackend()");
0063         throw ex;
0064       }
0065     }
0066     // ROCm and Intel GPU are still not supported
0067     else if ((backend == Backend::rocm) || (backend == Backend::intel)) {
0068       edm::Exception ex(edm::errors::UnavailableAccelerator);
0069       ex << "ROCm/Intel GPU backend requested, but TF is not compiled yet for this platform";
0070       ex.addContext("Calling tensorflow::setBackend()");
0071       throw ex;
0072     }
0073     // Get NVidia GPU if possible or fallback to CPU
0074     else if (backend == Backend::best) {
0075       // Check if a Nvidia GPU is availabl
0076       if (ri->hasGpuNvidia()) {
0077         // Take only the first GPU in the CUDA_VISIBLE_DEVICE list
0078         (*_options.config.mutable_device_count())["GPU"] = 1;
0079         _options.config.mutable_gpu_options()->set_visible_device_list("0");
0080         // Do not allocate all the memory on the GPU at the beginning.
0081         _options.config.mutable_gpu_options()->set_allow_growth(true);
0082       } else {
0083         // Just CPU support
0084         (*_options.config.mutable_device_count())["GPU"] = 0;
0085         _options.config.mutable_gpu_options()->set_visible_device_list("");
0086       }
0087     }
0088   }
0089 
0090   MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag) {
0091     Options default_options{};
0092     return loadMetaGraphDef(exportDir, tag, default_options);
0093   }
0094 
0095   MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, Options& options) {
0096     // objects to load the graph
0097     Status status;
0098     RunOptions runOptions;
0099     SavedModelBundle bundle;
0100 
0101     // load the model
0102     status = LoadSavedModel(options.getSessionOptions(), runOptions, exportDir, {tag}, &bundle);
0103     if (!status.ok()) {
0104       throw cms::Exception("InvalidMetaGraphDef")
0105           << "error while loading metaGraphDef from '" << exportDir << "': " << status.ToString();
0106     }
0107 
0108     // return a copy of the graph
0109     return new MetaGraphDef(bundle.meta_graph_def);
0110   }
0111 
0112   MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, Options& options) {
0113     edm::LogInfo("PhysicsTools/TensorFlow")
0114         << "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
0115 
0116     return loadMetaGraphDef(exportDir, tag, options);
0117   }
0118 
0119   GraphDef* loadGraphDef(const std::string& pbFile) {
0120     // objects to load the graph
0121     Status status;
0122 
0123     // load it
0124     GraphDef* graphDef = new GraphDef();
0125     status = ReadBinaryProto(Env::Default(), pbFile, graphDef);
0126 
0127     // check for success
0128     if (!status.ok()) {
0129       throw cms::Exception("InvalidGraphDef")
0130           << "error while loading graphDef from '" << pbFile << "': " << status.ToString();
0131     }
0132 
0133     return graphDef;
0134   }
0135 
0136   Session* createSession() {
0137     Options default_options{};
0138     return createSession(default_options);
0139   }
0140 
0141   Session* createSession(Options& options) {
0142     // objects to create the session
0143     Status status;
0144 
0145     // create a new, empty session
0146     Session* session = nullptr;
0147     status = NewSession(options.getSessionOptions(), &session);
0148     if (!status.ok()) {
0149       throw cms::Exception("InvalidSession") << "error while creating session: " << status.ToString();
0150     }
0151 
0152     return session;
0153   }
0154 
0155   Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, Options& options) {
0156     // check for valid pointer
0157     if (metaGraphDef == nullptr) {
0158       throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: metaGraphDef is nullptr";
0159     }
0160 
0161     // check that the graph has nodes
0162     if (metaGraphDef->graph_def().node_size() <= 0) {
0163       throw cms::Exception("InvalidMetaGraphDef") << "error while creating session: graphDef has no nodes";
0164     }
0165 
0166     Session* session = createSession(options);
0167 
0168     // add the graph def from the meta graph
0169     Status status;
0170     status = session->Create(metaGraphDef->graph_def());
0171     if (!status.ok()) {
0172       throw cms::Exception("InvalidMetaGraphDef")
0173           << "error while attaching metaGraphDef to session: " << status.ToString();
0174     }
0175 
0176     // restore variables using the variable and index files in the export directory
0177     // first, find names and paths
0178     std::string varFileTensorName = metaGraphDef->saver_def().filename_tensor_name();
0179     std::string restoreOpName = metaGraphDef->saver_def().restore_op_name();
0180     std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
0181     std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
0182     std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
0183 
0184     // when the index file is missing, there's nothing to do
0185     if (!Env::Default()->FileExists(indexFile).ok()) {
0186       return session;
0187     }
0188 
0189     // create a tensor to store the variable file
0190     Tensor varFileTensor(DT_STRING, TensorShape({}));
0191     varFileTensor.scalar<tensorflow::tstring>()() = varFile;
0192 
0193     // run the restore op
0194     status = session->Run({{varFileTensorName, varFileTensor}}, {}, {restoreOpName}, nullptr);
0195     if (!status.ok()) {
0196       throw cms::Exception("InvalidSession") << "error while restoring variables in session: " << status.ToString();
0197     }
0198 
0199     return session;
0200   }
0201 
0202   Session* createSession(const GraphDef* graphDef) {
0203     Options default_options{};
0204     return createSession(graphDef, default_options);
0205   }
0206 
0207   Session* createSession(const GraphDef* graphDef, Options& options) {
0208     // check for valid pointer
0209     if (graphDef == nullptr) {
0210       throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef is nullptr";
0211     }
0212 
0213     // check that the graph has nodes
0214     if (graphDef->node_size() <= 0) {
0215       throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef has no nodes";
0216     }
0217 
0218     // create a new, empty session
0219     Session* session = createSession(options);
0220 
0221     // add the graph def
0222     Status status;
0223     status = session->Create(*graphDef);
0224 
0225     // check for success
0226     if (!status.ok()) {
0227       throw cms::Exception("InvalidSession") << "error while attaching graphDef to session: " << status.ToString();
0228     }
0229 
0230     return session;
0231   }
0232 
0233   bool closeSession(Session*& session) {
0234     if (session == nullptr) {
0235       return true;
0236     }
0237 
0238     // close and delete the session
0239     Status status = session->Close();
0240     delete session;
0241 
0242     // reset the pointer
0243     session = nullptr;
0244 
0245     return status.ok();
0246   }
0247 
0248   bool closeSession(const Session*& session) {
0249     auto s = const_cast<Session*>(session);
0250     bool state = closeSession(s);
0251 
0252     // reset the pointer
0253     session = nullptr;
0254 
0255     return state;
0256   }
0257 
0258   bool checkEmptyInputs(const NamedTensorList& inputs) {
0259     // check for empty tensors in the inputs
0260     bool isEmpty = false;
0261     for (const auto& input : inputs) {
0262       // Checking using the shape
0263       if (input.second.shape().num_elements() == 0) {
0264         isEmpty = true;
0265         break;
0266       }
0267     }
0268     return isEmpty;
0269   }
0270 
0271   void run(Session* session,
0272            const NamedTensorList& inputs,
0273            const std::vector<std::string>& outputNames,
0274            std::vector<Tensor>* outputs,
0275            const thread::ThreadPoolOptions& threadPoolOptions) {
0276     if (session == nullptr) {
0277       throw cms::Exception("InvalidSession") << "cannot run empty session";
0278     }
0279 
0280     // create empty run options
0281     RunOptions runOptions;
0282 
0283     // Check if the inputs are empty
0284     if (checkEmptyInputs(inputs))
0285       return;
0286 
0287     // run and check the status
0288     Status status = session->Run(runOptions, inputs, outputNames, {}, outputs, nullptr, threadPoolOptions);
0289     if (!status.ok()) {
0290       throw cms::Exception("InvalidRun") << "error while running session: " << status.ToString();
0291     }
0292   }
0293 
0294   void run(Session* session,
0295            const NamedTensorList& inputs,
0296            const std::vector<std::string>& outputNames,
0297            std::vector<Tensor>* outputs,
0298            thread::ThreadPoolInterface* threadPool) {
0299     // create thread pool options
0300     thread::ThreadPoolOptions threadPoolOptions;
0301     threadPoolOptions.inter_op_threadpool = threadPool;
0302     threadPoolOptions.intra_op_threadpool = threadPool;
0303 
0304     // run
0305     run(session, inputs, outputNames, outputs, threadPoolOptions);
0306   }
0307 
0308   void run(Session* session,
0309            const NamedTensorList& inputs,
0310            const std::vector<std::string>& outputNames,
0311            std::vector<Tensor>* outputs,
0312            const std::string& threadPoolName) {
0313     // lookup the thread pool and forward the call accordingly
0314     if (threadPoolName == "no_threads") {
0315       run(session, inputs, outputNames, outputs, &NoThreadPool::instance());
0316     } else if (threadPoolName == "tbb") {
0317       // the TBBTreadPool singleton should be already initialized before with a number of threads
0318       run(session, inputs, outputNames, outputs, &TBBThreadPool::instance());
0319     } else if (threadPoolName == "tensorflow") {
0320       run(session, inputs, outputNames, outputs, nullptr);
0321     } else {
0322       throw cms::Exception("UnknownThreadPool")
0323           << "thread pool implementation'" << threadPoolName << "' unknown, use 'no_threads', 'tbb', or 'tensorflow'";
0324     }
0325   }
0326 
0327   void run(Session* session,
0328            const std::vector<std::string>& outputNames,
0329            std::vector<Tensor>* outputs,
0330            const std::string& threadPoolName) {
0331     run(session, {}, outputNames, outputs, threadPoolName);
0332   }
0333 
0334   void SessionCache::closeSession() {
0335     // delete the session if set
0336     Session* s = session.load();
0337     if (s != nullptr) {
0338       tensorflow::closeSession(s);
0339       session.store(nullptr);
0340     }
0341 
0342     // delete the graph if set
0343     if (graph.load() != nullptr) {
0344       delete graph.load();
0345       graph.store(nullptr);
0346     }
0347   }
0348 
0349 }  // namespace tensorflow