Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-06-22 02:23:39

#include "catch.hpp"

#include <atomic>
#include <chrono>
#include <exception>
#include <stdexcept>
#include <thread>
#include <utility>

#include "oneapi/tbb/global_control.h"
#include "oneapi/tbb/task_group.h"

#include "FWCore/Concurrency/interface/chain_first.h"
#include "FWCore/Concurrency/interface/FinalWaitingTask.h"
#include "FWCore/Concurrency/interface/WaitingTaskWithArenaHolder.h"
#include "FWCore/Concurrency/interface/WaitingThreadPool.h"
#include "FWCore/Concurrency/interface/hardware_pause.h"
#include "FWCore/Utilities/interface/Exception.h"
0009 
namespace {
  // Text reported by errorContext(); kept as a named constant so the literal
  // appears in exactly one place.
  constexpr char const kErrorContextText[] = "WaitingThreadPool test";

  // Callable handed to WaitingThreadPool::runAsync() as its error-context
  // argument. The "Exceptions" sections of the test expect this text to show
  // up in the message of the propagated exception.
  constexpr auto errorContext() -> char const* { return kErrorContextText; }
}  // namespace
0013 
0014 TEST_CASE("Test WaitingThreadPool", "[edm::WaitingThreadPool") {
0015   // Using parallelism 2 here because otherwise the
0016   // tbb::task_arena::enqueue() in WaitingTaskWithArenaHolder will
0017   // start a new TBB thread that "inherits" the name from the
0018   // WaitingThreadPool thread.
0019   oneapi::tbb::global_control control(oneapi::tbb::global_control::max_allowed_parallelism, 2);
0020   edm::WaitingThreadPool pool;
0021 
0022   SECTION("One async call") {
0023     std::atomic<int> count{0};
0024 
0025     oneapi::tbb::task_group group;
0026     edm::FinalWaitingTask waitTask{group};
0027     {
0028       using namespace edm::waiting_task::chain;
0029       auto h = first([&pool, &count](edm::WaitingTaskHolder h) {
0030                  edm::WaitingTaskWithArenaHolder h2(std::move(h));
0031                  pool.runAsync(
0032                      std::move(h2), [&count]() { ++count; }, errorContext);
0033                }) |
0034                lastTask(edm::WaitingTaskHolder(group, &waitTask));
0035       h.doneWaiting(std::exception_ptr());
0036     }
0037     waitTask.waitNoThrow();
0038     REQUIRE(count.load() == 1);
0039     REQUIRE(waitTask.done());
0040     REQUIRE(not waitTask.exceptionPtr());
0041   }
0042 
0043   SECTION("Two async calls") {
0044     std::atomic<int> count{0};
0045     std::atomic<bool> mayContinue{false};
0046 
0047     oneapi::tbb::task_group group;
0048     edm::FinalWaitingTask waitTask{group};
0049 
0050     {
0051       using namespace edm::waiting_task::chain;
0052       auto h = first([&pool, &count, &mayContinue](edm::WaitingTaskHolder h) {
0053                  edm::WaitingTaskWithArenaHolder h2(std::move(h));
0054                  pool.runAsync(
0055                      h2,
0056                      [&count, &mayContinue]() {
0057                        while (not mayContinue) {
0058                          hardware_pause();
0059                        }
0060                        using namespace std::chrono_literals;
0061                        std::this_thread::sleep_for(10ms);
0062                        ++count;
0063                      },
0064                      errorContext);
0065                  pool.runAsync(
0066                      h2, [&count]() { ++count; }, errorContext);
0067                }) |
0068                lastTask(edm::WaitingTaskHolder(group, &waitTask));
0069       h.doneWaiting(std::exception_ptr());
0070     }
0071     mayContinue = true;
0072     waitTask.waitNoThrow();
0073     REQUIRE(count.load() == 2);
0074     REQUIRE(waitTask.done());
0075     REQUIRE(not waitTask.exceptionPtr());
0076   }
0077 
0078   SECTION("Concurrent async calls") {
0079     std::atomic<int> count{0};
0080     std::atomic<int> mayContinue{0};
0081 
0082     oneapi::tbb::task_group group;
0083     edm::FinalWaitingTask waitTask{group};
0084 
0085     {
0086       using namespace edm::waiting_task::chain;
0087       auto h1 = first([&pool, &count, &mayContinue](edm::WaitingTaskHolder h) {
0088                   edm::WaitingTaskWithArenaHolder h2(std::move(h));
0089                   ++mayContinue;
0090                   while (mayContinue != 2) {
0091                     hardware_pause();
0092                   }
0093                   pool.runAsync(
0094                       h2, [&count]() { ++count; }, errorContext);
0095                 }) |
0096                 lastTask(edm::WaitingTaskHolder(group, &waitTask));
0097 
0098       auto h2 = first([&pool, &count, &mayContinue](edm::WaitingTaskHolder h) {
0099                   edm::WaitingTaskWithArenaHolder h2(std::move(h));
0100                   ++mayContinue;
0101                   while (mayContinue != 2) {
0102                     hardware_pause();
0103                   }
0104                   pool.runAsync(
0105                       h2, [&count]() { ++count; }, errorContext);
0106                 }) |
0107                 lastTask(edm::WaitingTaskHolder(group, &waitTask));
0108       h2.doneWaiting(std::exception_ptr());
0109       h1.doneWaiting(std::exception_ptr());
0110     }
0111     waitTask.waitNoThrow();
0112     REQUIRE(count.load() == 2);
0113     REQUIRE(waitTask.done());
0114     REQUIRE(not waitTask.exceptionPtr());
0115   }
0116 
0117   SECTION("Exceptions") {
0118     SECTION("One async call") {
0119       std::atomic<int> count{0};
0120 
0121       oneapi::tbb::task_group group;
0122       edm::FinalWaitingTask waitTask{group};
0123 
0124       {
0125         using namespace edm::waiting_task::chain;
0126         auto h = first([&pool](edm::WaitingTaskHolder h) {
0127                    edm::WaitingTaskWithArenaHolder h2(std::move(h));
0128                    pool.runAsync(
0129                        std::move(h2), []() { throw std::runtime_error("error"); }, errorContext);
0130                  }) |
0131                  lastTask(edm::WaitingTaskHolder(group, &waitTask));
0132         h.doneWaiting(std::exception_ptr());
0133       }
0134       REQUIRE_THROWS_WITH(
0135           waitTask.wait(),
0136           Catch::Contains("error") and Catch::Contains("StdException") and Catch::Contains("WaitingThreadPool test"));
0137       REQUIRE(count.load() == 0);
0138     }
0139 
0140     SECTION("Two async calls") {
0141       std::atomic<int> count{0};
0142 
0143       oneapi::tbb::task_group group;
0144       edm::FinalWaitingTask waitTask{group};
0145 
0146       {
0147         using namespace edm::waiting_task::chain;
0148         auto h = first([&pool, &count](edm::WaitingTaskHolder h) {
0149                    edm::WaitingTaskWithArenaHolder h2(std::move(h));
0150                    pool.runAsync(
0151                        h2,
0152                        [&count]() {
0153                          if (count.fetch_add(1) == 0) {
0154                            throw cms::Exception("error 1");
0155                          }
0156                          ++count;
0157                        },
0158                        errorContext);
0159                    pool.runAsync(
0160                        h2,
0161                        [&count]() {
0162                          if (count.fetch_add(1) == 0) {
0163                            throw cms::Exception("error 2");
0164                          }
0165                          ++count;
0166                        },
0167                        errorContext);
0168                  }) |
0169                  lastTask(edm::WaitingTaskHolder(group, &waitTask));
0170         h.doneWaiting(std::exception_ptr());
0171       }
0172       REQUIRE_THROWS_AS(waitTask.wait(), cms::Exception);
0173       REQUIRE(count.load() == 3);
0174     }
0175 
0176     SECTION("Two exceptions") {
0177       std::atomic<int> count{0};
0178 
0179       oneapi::tbb::task_group group;
0180       edm::FinalWaitingTask waitTask{group};
0181 
0182       {
0183         using namespace edm::waiting_task::chain;
0184         auto h = first([&pool, &count](edm::WaitingTaskHolder h) {
0185                    edm::WaitingTaskWithArenaHolder h2(std::move(h));
0186                    pool.runAsync(
0187                        h2,
0188                        [&count]() {
0189                          ++count;
0190                          throw cms::Exception("error 1");
0191                        },
0192                        errorContext);
0193                    pool.runAsync(
0194                        h2,
0195                        [&count]() {
0196                          ++count;
0197                          throw cms::Exception("error 2");
0198                        },
0199                        errorContext);
0200                  }) |
0201                  lastTask(edm::WaitingTaskHolder(group, &waitTask));
0202         h.doneWaiting(std::exception_ptr());
0203       }
0204       REQUIRE_THROWS_AS(waitTask.wait(), cms::Exception);
0205       REQUIRE(count.load() == 2);
0206     }
0207 
0208     SECTION("Concurrent exceptions") {
0209       std::atomic<int> count{0};
0210 
0211       oneapi::tbb::task_group group;
0212       edm::FinalWaitingTask waitTask{group};
0213 
0214       {
0215         using namespace edm::waiting_task::chain;
0216         auto h1 = first([&pool, &count](edm::WaitingTaskHolder h) {
0217                     edm::WaitingTaskWithArenaHolder h2(std::move(h));
0218                     pool.runAsync(
0219                         h2,
0220                         [&count]() {
0221                           ++count;
0222                           throw cms::Exception("error 1");
0223                         },
0224                         errorContext);
0225                   }) |
0226                   lastTask(edm::WaitingTaskHolder(group, &waitTask));
0227 
0228         auto h2 = first([&pool, &count](edm::WaitingTaskHolder h) {
0229                     edm::WaitingTaskWithArenaHolder h2(std::move(h));
0230                     pool.runAsync(
0231                         h2,
0232                         [&count]() {
0233                           ++count;
0234                           throw cms::Exception("error 2");
0235                         },
0236                         errorContext);
0237                   }) |
0238                   lastTask(edm::WaitingTaskHolder(group, &waitTask));
0239         h2.doneWaiting(std::exception_ptr());
0240         h1.doneWaiting(std::exception_ptr());
0241       }
0242       REQUIRE_THROWS_AS(waitTask.wait(), cms::Exception);
0243       REQUIRE(count.load() == 2);
0244     }
0245 
0246     SECTION("Concurrent exception and success") {
0247       std::atomic<int> count{0};
0248 
0249       oneapi::tbb::task_group group;
0250       edm::FinalWaitingTask waitTask{group};
0251 
0252       {
0253         using namespace edm::waiting_task::chain;
0254         auto h1 = first([&pool, &count](edm::WaitingTaskHolder h) {
0255                     edm::WaitingTaskWithArenaHolder h2(std::move(h));
0256                     pool.runAsync(
0257                         h2,
0258                         [&count]() {
0259                           ++count;
0260                           throw cms::Exception("error 1");
0261                         },
0262                         errorContext);
0263                   }) |
0264                   lastTask(edm::WaitingTaskHolder(group, &waitTask));
0265 
0266         auto h2 = first([&pool, &count](edm::WaitingTaskHolder h) {
0267                     edm::WaitingTaskWithArenaHolder h2(std::move(h));
0268                     pool.runAsync(
0269                         h2, [&count]() { ++count; }, errorContext);
0270                   }) |
0271                   lastTask(edm::WaitingTaskHolder(group, &waitTask));
0272         h2.doneWaiting(std::exception_ptr());
0273         h1.doneWaiting(std::exception_ptr());
0274       }
0275       REQUIRE_THROWS_AS(waitTask.wait(), cms::Exception);
0276       REQUIRE(count.load() == 2);
0277     }
0278   }
0279 }