Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-09-07 04:36:21

#include "catch.hpp"

#include <atomic>
#include <chrono>
#include <exception>
#include <stdexcept>
#include <thread>

#include "oneapi/tbb/global_control.h"
#include "oneapi/tbb/task_group.h"

#include "FWCore/Concurrency/interface/chain_first.h"
#include "FWCore/Concurrency/interface/FinalWaitingTask.h"
#include "FWCore/Concurrency/interface/WaitingThreadPool.h"
#include "FWCore/Concurrency/interface/hardware_pause.h"
0009 
namespace {
  // Error-context provider handed to WaitingThreadPool::runAsync() by every
  // section of the test case. The "Exceptions" / "One async call" section
  // checks that this text shows up in the message of the propagated exception.
  constexpr auto errorContext() -> char const* {
    constexpr char const* kContext = "WaitingThreadPool test";
    return kContext;
  }
}  // namespace
0013 
0014 TEST_CASE("Test WaitingThreadPool", "[edm::WaitingThreadPool") {
0015   // Using parallelism 2 here because otherwise the
0016   // tbb::task_arena::enqueue() in WaitingTaskWithArenaHolder will
0017   // start a new TBB thread that "inherits" the name from the
0018   // WaitingThreadPool thread.
0019   oneapi::tbb::global_control control(oneapi::tbb::global_control::max_allowed_parallelism, 2);
0020   edm::WaitingThreadPool pool;
0021 
0022   SECTION("One async call") {
0023     std::atomic<int> count{0};
0024 
0025     oneapi::tbb::task_group group;
0026     edm::FinalWaitingTask waitTask{group};
0027     {
0028       using namespace edm::waiting_task::chain;
0029       auto h = first([&pool, &count](edm::WaitingTaskHolder h) {
0030                  edm::WaitingTaskWithArenaHolder h2(std::move(h));
0031                  pool.runAsync(std::move(h2), [&count]() { ++count; }, errorContext);
0032                }) |
0033                lastTask(edm::WaitingTaskHolder(group, &waitTask));
0034       h.doneWaiting(std::exception_ptr());
0035     }
0036     waitTask.waitNoThrow();
0037     REQUIRE(count.load() == 1);
0038     REQUIRE(waitTask.done());
0039     REQUIRE(not waitTask.exceptionPtr());
0040   }
0041 
0042   SECTION("Two async calls") {
0043     std::atomic<int> count{0};
0044     std::atomic<bool> mayContinue{false};
0045 
0046     oneapi::tbb::task_group group;
0047     edm::FinalWaitingTask waitTask{group};
0048 
0049     {
0050       using namespace edm::waiting_task::chain;
0051       auto h = first([&pool, &count, &mayContinue](edm::WaitingTaskHolder h) {
0052                  edm::WaitingTaskWithArenaHolder h2(std::move(h));
0053                  pool.runAsync(
0054                      h2,
0055                      [&count, &mayContinue]() {
0056                        while (not mayContinue) {
0057                          hardware_pause();
0058                        }
0059                        using namespace std::chrono_literals;
0060                        std::this_thread::sleep_for(10ms);
0061                        ++count;
0062                      },
0063                      errorContext);
0064                  pool.runAsync(h2, [&count]() { ++count; }, errorContext);
0065                }) |
0066                lastTask(edm::WaitingTaskHolder(group, &waitTask));
0067       h.doneWaiting(std::exception_ptr());
0068     }
0069     mayContinue = true;
0070     waitTask.waitNoThrow();
0071     REQUIRE(count.load() == 2);
0072     REQUIRE(waitTask.done());
0073     REQUIRE(not waitTask.exceptionPtr());
0074   }
0075 
0076   SECTION("Concurrent async calls") {
0077     std::atomic<int> count{0};
0078     std::atomic<int> mayContinue{0};
0079 
0080     oneapi::tbb::task_group group;
0081     edm::FinalWaitingTask waitTask{group};
0082 
0083     {
0084       using namespace edm::waiting_task::chain;
0085       auto h1 = first([&pool, &count, &mayContinue](edm::WaitingTaskHolder h) {
0086                   edm::WaitingTaskWithArenaHolder h2(std::move(h));
0087                   ++mayContinue;
0088                   while (mayContinue != 2) {
0089                     hardware_pause();
0090                   }
0091                   pool.runAsync(h2, [&count]() { ++count; }, errorContext);
0092                 }) |
0093                 lastTask(edm::WaitingTaskHolder(group, &waitTask));
0094 
0095       auto h2 = first([&pool, &count, &mayContinue](edm::WaitingTaskHolder h) {
0096                   edm::WaitingTaskWithArenaHolder h2(std::move(h));
0097                   ++mayContinue;
0098                   while (mayContinue != 2) {
0099                     hardware_pause();
0100                   }
0101                   pool.runAsync(h2, [&count]() { ++count; }, errorContext);
0102                 }) |
0103                 lastTask(edm::WaitingTaskHolder(group, &waitTask));
0104       h2.doneWaiting(std::exception_ptr());
0105       h1.doneWaiting(std::exception_ptr());
0106     }
0107     waitTask.waitNoThrow();
0108     REQUIRE(count.load() == 2);
0109     REQUIRE(waitTask.done());
0110     REQUIRE(not waitTask.exceptionPtr());
0111   }
0112 
0113   SECTION("Exceptions") {
0114     SECTION("One async call") {
0115       std::atomic<int> count{0};
0116 
0117       oneapi::tbb::task_group group;
0118       edm::FinalWaitingTask waitTask{group};
0119 
0120       {
0121         using namespace edm::waiting_task::chain;
0122         auto h = first([&pool](edm::WaitingTaskHolder h) {
0123                    edm::WaitingTaskWithArenaHolder h2(std::move(h));
0124                    pool.runAsync(std::move(h2), []() { throw std::runtime_error("error"); }, errorContext);
0125                  }) |
0126                  lastTask(edm::WaitingTaskHolder(group, &waitTask));
0127         h.doneWaiting(std::exception_ptr());
0128       }
0129       REQUIRE_THROWS_WITH(
0130           waitTask.wait(),
0131           Catch::Contains("error") and Catch::Contains("StdException") and Catch::Contains("WaitingThreadPool test"));
0132       REQUIRE(count.load() == 0);
0133     }
0134 
0135     SECTION("Two async calls") {
0136       std::atomic<int> count{0};
0137 
0138       oneapi::tbb::task_group group;
0139       edm::FinalWaitingTask waitTask{group};
0140 
0141       {
0142         using namespace edm::waiting_task::chain;
0143         auto h = first([&pool, &count](edm::WaitingTaskHolder h) {
0144                    edm::WaitingTaskWithArenaHolder h2(std::move(h));
0145                    pool.runAsync(
0146                        h2,
0147                        [&count]() {
0148                          if (count.fetch_add(1) == 0) {
0149                            throw cms::Exception("error 1");
0150                          }
0151                          ++count;
0152                        },
0153                        errorContext);
0154                    pool.runAsync(
0155                        h2,
0156                        [&count]() {
0157                          if (count.fetch_add(1) == 0) {
0158                            throw cms::Exception("error 2");
0159                          }
0160                          ++count;
0161                        },
0162                        errorContext);
0163                  }) |
0164                  lastTask(edm::WaitingTaskHolder(group, &waitTask));
0165         h.doneWaiting(std::exception_ptr());
0166       }
0167       REQUIRE_THROWS_AS(waitTask.wait(), cms::Exception);
0168       REQUIRE(count.load() == 3);
0169     }
0170 
0171     SECTION("Two exceptions") {
0172       std::atomic<int> count{0};
0173 
0174       oneapi::tbb::task_group group;
0175       edm::FinalWaitingTask waitTask{group};
0176 
0177       {
0178         using namespace edm::waiting_task::chain;
0179         auto h = first([&pool, &count](edm::WaitingTaskHolder h) {
0180                    edm::WaitingTaskWithArenaHolder h2(std::move(h));
0181                    pool.runAsync(
0182                        h2,
0183                        [&count]() {
0184                          ++count;
0185                          throw cms::Exception("error 1");
0186                        },
0187                        errorContext);
0188                    pool.runAsync(
0189                        h2,
0190                        [&count]() {
0191                          ++count;
0192                          throw cms::Exception("error 2");
0193                        },
0194                        errorContext);
0195                  }) |
0196                  lastTask(edm::WaitingTaskHolder(group, &waitTask));
0197         h.doneWaiting(std::exception_ptr());
0198       }
0199       REQUIRE_THROWS_AS(waitTask.wait(), cms::Exception);
0200       REQUIRE(count.load() == 2);
0201     }
0202 
0203     SECTION("Concurrent exceptions") {
0204       std::atomic<int> count{0};
0205 
0206       oneapi::tbb::task_group group;
0207       edm::FinalWaitingTask waitTask{group};
0208 
0209       {
0210         using namespace edm::waiting_task::chain;
0211         auto h1 = first([&pool, &count](edm::WaitingTaskHolder h) {
0212                     edm::WaitingTaskWithArenaHolder h2(std::move(h));
0213                     pool.runAsync(
0214                         h2,
0215                         [&count]() {
0216                           ++count;
0217                           throw cms::Exception("error 1");
0218                         },
0219                         errorContext);
0220                   }) |
0221                   lastTask(edm::WaitingTaskHolder(group, &waitTask));
0222 
0223         auto h2 = first([&pool, &count](edm::WaitingTaskHolder h) {
0224                     edm::WaitingTaskWithArenaHolder h2(std::move(h));
0225                     pool.runAsync(
0226                         h2,
0227                         [&count]() {
0228                           ++count;
0229                           throw cms::Exception("error 2");
0230                         },
0231                         errorContext);
0232                   }) |
0233                   lastTask(edm::WaitingTaskHolder(group, &waitTask));
0234         h2.doneWaiting(std::exception_ptr());
0235         h1.doneWaiting(std::exception_ptr());
0236       }
0237       REQUIRE_THROWS_AS(waitTask.wait(), cms::Exception);
0238       REQUIRE(count.load() == 2);
0239     }
0240 
0241     SECTION("Concurrent exception and success") {
0242       std::atomic<int> count{0};
0243 
0244       oneapi::tbb::task_group group;
0245       edm::FinalWaitingTask waitTask{group};
0246 
0247       {
0248         using namespace edm::waiting_task::chain;
0249         auto h1 = first([&pool, &count](edm::WaitingTaskHolder h) {
0250                     edm::WaitingTaskWithArenaHolder h2(std::move(h));
0251                     pool.runAsync(
0252                         h2,
0253                         [&count]() {
0254                           ++count;
0255                           throw cms::Exception("error 1");
0256                         },
0257                         errorContext);
0258                   }) |
0259                   lastTask(edm::WaitingTaskHolder(group, &waitTask));
0260 
0261         auto h2 = first([&pool, &count](edm::WaitingTaskHolder h) {
0262                     edm::WaitingTaskWithArenaHolder h2(std::move(h));
0263                     pool.runAsync(h2, [&count]() { ++count; }, errorContext);
0264                   }) |
0265                   lastTask(edm::WaitingTaskHolder(group, &waitTask));
0266         h2.doneWaiting(std::exception_ptr());
0267         h1.doneWaiting(std::exception_ptr());
0268       }
0269       REQUIRE_THROWS_AS(waitTask.wait(), cms::Exception);
0270       REQUIRE(count.load() == 2);
0271     }
0272   }
0273 }