File indexing completed on 2023-03-17 11:16:48
0001
0002
0003
0004
0005
0006
0007
0008
0009 #ifndef PHYSICSTOOLS_TENSORFLOW_TBBTHREADPOOL_H
0010 #define PHYSICSTOOLS_TENSORFLOW_TBBTHREADPOOL_H
0011
#include <atomic>
#include <functional>
#include <memory>
#include <utility>

#include "FWCore/Utilities/interface/thread_safety_macros.h"

#include "tensorflow/core/lib/core/threadpool.h"

#include "oneapi/tbb/task_arena.h"
#include "oneapi/tbb/task_group.h"
#include "oneapi/tbb/global_control.h"
0019
0020 namespace tensorflow {
0021
0022 class TBBThreadPool : public tensorflow::thread::ThreadPoolInterface {
0023 public:
0024 static TBBThreadPool& instance(int nThreads = -1) {
0025 CMS_THREAD_SAFE static TBBThreadPool pool(nThreads);
0026 return pool;
0027 }
0028
0029 explicit TBBThreadPool(int nThreads = -1)
0030 : nThreads_(nThreads > 0 ? nThreads
0031 : tbb::global_control::active_value(tbb::global_control::max_allowed_parallelism)),
0032 numScheduleCalled_(0) {
0033
0034 }
0035
0036 void Schedule(std::function<void()> fn) override {
0037 numScheduleCalled_ += 1;
0038
0039
0040
0041 tbb::task_arena taskArena;
0042 tbb::task_group taskGroup;
0043
0044
0045 auto doneWithTaskGroup = [&taskArena, &taskGroup](void*) {
0046 taskArena.execute([&taskGroup]() { taskGroup.wait(); });
0047 };
0048 std::unique_ptr<tbb::task_group, decltype(doneWithTaskGroup)> taskGuard(&taskGroup, doneWithTaskGroup);
0049
0050
0051 taskArena.execute([&taskGroup, &fn] { taskGroup.run(fn); });
0052
0053
0054 taskGuard.reset();
0055 }
0056
0057 void ScheduleWithHint(std::function<void()> fn, int start, int end) override { Schedule(fn); }
0058
0059 void Cancel() override {}
0060
0061 int NumThreads() const override { return nThreads_; }
0062
0063 int CurrentThreadId() const override {
0064 static std::atomic<int> idCounter{0};
0065 thread_local const int id = idCounter++;
0066 return id;
0067 }
0068
0069 int GetNumScheduleCalled() { return numScheduleCalled_; }
0070
0071 private:
0072 const int nThreads_;
0073 std::atomic<int> numScheduleCalled_;
0074 };
0075
0076 }
0077
0078 #endif