Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-09-07 04:36:21

0001 #include "catch.hpp"
0002 
0003 #include <atomic>
0004 
0005 #include "oneapi/tbb/global_control.h"
0006 
0007 #include "FWCore/Concurrency/interface/chain_first.h"
0008 #include "FWCore/Concurrency/interface/FinalWaitingTask.h"
0009 #include "FWCore/Concurrency/interface/Async.h"
0010 
0011 namespace {
0012   constexpr char const* errorContext() { return "AsyncServiceTest"; }
0013 
0014   class AsyncServiceTest : public edm::Async {
0015   public:
0016     enum class State { kAllowed, kDisallowed, kShutdown };
0017 
0018     AsyncServiceTest() = default;
0019 
0020     void setAllowed(bool allowed) noexcept { allowed_ = allowed; }
0021 
0022   private:
0023     void ensureAllowed() const final {
0024       if (not allowed_) {
0025         throw std::runtime_error("Calling run in this context is not allowed");
0026       }
0027     }
0028 
0029     std::atomic<bool> allowed_ = true;
0030   };
0031 }  // namespace
0032 
0033 TEST_CASE("Test Async", "[edm::Async") {
0034   // Using parallelism 2 here because otherwise the
0035   // tbb::task_arena::enqueue() in WaitingTaskWithArenaHolder will
0036   // start a new TBB thread that "inherits" the name from the
0037   // WaitingThreadPool thread.
0038   oneapi::tbb::global_control control(oneapi::tbb::global_control::max_allowed_parallelism, 2);
0039 
0040   SECTION("Normal operation") {
0041     AsyncServiceTest service;
0042     std::atomic<int> count{0};
0043 
0044     oneapi::tbb::task_group group;
0045     edm::FinalWaitingTask waitTask{group};
0046 
0047     {
0048       using namespace edm::waiting_task::chain;
0049       auto h1 = first([&service, &count](edm::WaitingTaskHolder h) {
0050                   edm::WaitingTaskWithArenaHolder h2(std::move(h));
0051                   service.runAsync(h2, [&count]() { ++count; }, errorContext);
0052                 }) |
0053                 lastTask(edm::WaitingTaskHolder(group, &waitTask));
0054 
0055       auto h2 = first([&service, &count](edm::WaitingTaskHolder h) {
0056                   edm::WaitingTaskWithArenaHolder h2(std::move(h));
0057                   service.runAsync(h2, [&count]() { ++count; }, errorContext);
0058                 }) |
0059                 lastTask(edm::WaitingTaskHolder(group, &waitTask));
0060       h2.doneWaiting(std::exception_ptr());
0061       h1.doneWaiting(std::exception_ptr());
0062     }
0063     waitTask.waitNoThrow();
0064     REQUIRE(count.load() == 2);
0065     REQUIRE(waitTask.done());
0066     REQUIRE(not waitTask.exceptionPtr());
0067   }
0068 
0069   SECTION("Disallowed") {
0070     AsyncServiceTest service;
0071     std::atomic<int> count{0};
0072 
0073     oneapi::tbb::task_group group;
0074     edm::FinalWaitingTask waitTask{group};
0075 
0076     {
0077       using namespace edm::waiting_task::chain;
0078       auto h = first([&service, &count](edm::WaitingTaskHolder h) {
0079                  edm::WaitingTaskWithArenaHolder h2(std::move(h));
0080                  service.runAsync(h2, [&count]() { ++count; }, errorContext);
0081                  service.setAllowed(false);
0082                }) |
0083                then([&service, &count](edm::WaitingTaskHolder h) {
0084                  edm::WaitingTaskWithArenaHolder h2(std::move(h));
0085                  service.runAsync(h2, [&count]() { ++count; }, errorContext);
0086                }) |
0087                lastTask(edm::WaitingTaskHolder(group, &waitTask));
0088       h.doneWaiting(std::exception_ptr());
0089     }
0090     waitTask.waitNoThrow();
0091     REQUIRE(count.load() == 1);
0092     REQUIRE(waitTask.done());
0093     REQUIRE(waitTask.exceptionPtr());
0094     REQUIRE_THROWS_WITH(std::rethrow_exception(waitTask.exceptionPtr()),
0095                         Catch::Contains("Calling run in this context is not allowed"));
0096   }
0097 }