Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:24:15

0001 /*
0002  * Custom TensorFlow thread pool implementation that schedules tasks in TBB.
0003  * Based on TensorFlow 2.1.
0004  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
0005  *
0006  * Author: Marcel Rieger
0007  */
0008 
0009 #ifndef PHYSICSTOOLS_TENSORFLOW_TBBTHREADPOOL_H
0010 #define PHYSICSTOOLS_TENSORFLOW_TBBTHREADPOOL_H
0011 
#include "FWCore/Utilities/interface/thread_safety_macros.h"

#include "tensorflow/core/lib/core/threadpool.h"

#include "oneapi/tbb/task_arena.h"
#include "oneapi/tbb/task_group.h"
#include "oneapi/tbb/global_control.h"

#include <atomic>
#include <functional>
#include <utility>
0019 
0020 namespace tensorflow {
0021 
0022   class TBBThreadPool : public tensorflow::thread::ThreadPoolInterface {
0023   public:
0024     static TBBThreadPool& instance(int nThreads = -1) {
0025       CMS_THREAD_SAFE static TBBThreadPool pool(nThreads);
0026       return pool;
0027     }
0028 
0029     explicit TBBThreadPool(int nThreads = -1)
0030         : nThreads_(nThreads > 0 ? nThreads
0031                                  : tbb::global_control::active_value(tbb::global_control::max_allowed_parallelism)),
0032           numScheduleCalled_(0) {
0033       // when nThreads is zero or smaller, use the default value determined by tbb
0034     }
0035 
0036     void Schedule(std::function<void()> fn) override {
0037       numScheduleCalled_ += 1;
0038 
0039       // use a task arena to avoid having unrelated tasks start
0040       // running on this thread, which could potentially start deadlocks
0041       tbb::task_arena taskArena;
0042       tbb::task_group taskGroup;
0043 
0044       // we are required to always call wait before destructor
0045       auto doneWithTaskGroup = [&taskArena, &taskGroup](void*) {
0046         taskArena.execute([&taskGroup]() { taskGroup.wait(); });
0047       };
0048       std::unique_ptr<tbb::task_group, decltype(doneWithTaskGroup)> taskGuard(&taskGroup, doneWithTaskGroup);
0049 
0050       // schedule the task
0051       taskArena.execute([&taskGroup, &fn] { taskGroup.run(fn); });
0052 
0053       // reset the task guard which will call wait
0054       taskGuard.reset();
0055     }
0056 
0057     void ScheduleWithHint(std::function<void()> fn, int start, int end) override { Schedule(fn); }
0058 
0059     void Cancel() override {}
0060 
0061     int NumThreads() const override { return nThreads_; }
0062 
0063     int CurrentThreadId() const override {
0064       static std::atomic<int> idCounter{0};
0065       thread_local const int id = idCounter++;
0066       return id;
0067     }
0068 
0069     int GetNumScheduleCalled() { return numScheduleCalled_; }
0070 
0071   private:
0072     const int nThreads_;
0073     std::atomic<int> numScheduleCalled_;
0074   };
0075 
0076 }  // namespace tensorflow
0077 
0078 #endif  // PHYSICSTOOLS_TENSORFLOW_TBBTHREADPOOL_H