Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-06-07 02:29:34

#include "catch.hpp"

#include <atomic>
#include <exception>
#include <stdexcept>
#include <utility>

#include "oneapi/tbb/global_control.h"
#include "oneapi/tbb/task_group.h"

#include "FWCore/Concurrency/interface/chain_first.h"
#include "FWCore/Concurrency/interface/FinalWaitingTask.h"
#include "FWCore/Concurrency/interface/Async.h"
0010 
0011 namespace {
0012   constexpr char const* errorContext() { return "AsyncServiceTest"; }
0013 
0014   class AsyncServiceTest : public edm::Async {
0015   public:
0016     enum class State { kAllowed, kDisallowed, kShutdown };
0017 
0018     AsyncServiceTest() = default;
0019 
0020     void setAllowed(bool allowed) noexcept { allowed_ = allowed; }
0021 
0022   private:
0023     void ensureAllowed() const final {
0024       if (not allowed_) {
0025         throw std::runtime_error("Calling run in this context is not allowed");
0026       }
0027     }
0028 
0029     std::atomic<bool> allowed_ = true;
0030   };
0031 }  // namespace
0032 
0033 TEST_CASE("Test Async", "[edm::Async") {
0034   // Using parallelism 2 here because otherwise the
0035   // tbb::task_arena::enqueue() in WaitingTaskWithArenaHolder will
0036   // start a new TBB thread that "inherits" the name from the
0037   // WaitingThreadPool thread.
0038   oneapi::tbb::global_control control(oneapi::tbb::global_control::max_allowed_parallelism, 2);
0039 
0040   SECTION("Normal operation") {
0041     AsyncServiceTest service;
0042     std::atomic<int> count{0};
0043 
0044     oneapi::tbb::task_group group;
0045     edm::FinalWaitingTask waitTask{group};
0046 
0047     {
0048       using namespace edm::waiting_task::chain;
0049       auto h1 = first([&service, &count](edm::WaitingTaskHolder h) {
0050                   edm::WaitingTaskWithArenaHolder h2(std::move(h));
0051                   service.runAsync(
0052                       h2, [&count]() { ++count; }, errorContext);
0053                 }) |
0054                 lastTask(edm::WaitingTaskHolder(group, &waitTask));
0055 
0056       auto h2 = first([&service, &count](edm::WaitingTaskHolder h) {
0057                   edm::WaitingTaskWithArenaHolder h2(std::move(h));
0058                   service.runAsync(
0059                       h2, [&count]() { ++count; }, errorContext);
0060                 }) |
0061                 lastTask(edm::WaitingTaskHolder(group, &waitTask));
0062       h2.doneWaiting(std::exception_ptr());
0063       h1.doneWaiting(std::exception_ptr());
0064     }
0065     waitTask.waitNoThrow();
0066     REQUIRE(count.load() == 2);
0067     REQUIRE(waitTask.done());
0068     REQUIRE(not waitTask.exceptionPtr());
0069   }
0070 
0071   SECTION("Disallowed") {
0072     AsyncServiceTest service;
0073     std::atomic<int> count{0};
0074 
0075     oneapi::tbb::task_group group;
0076     edm::FinalWaitingTask waitTask{group};
0077 
0078     {
0079       using namespace edm::waiting_task::chain;
0080       auto h = first([&service, &count](edm::WaitingTaskHolder h) {
0081                  edm::WaitingTaskWithArenaHolder h2(std::move(h));
0082                  service.runAsync(
0083                      h2, [&count]() { ++count; }, errorContext);
0084                  service.setAllowed(false);
0085                }) |
0086                then([&service, &count](edm::WaitingTaskHolder h) {
0087                  edm::WaitingTaskWithArenaHolder h2(std::move(h));
0088                  service.runAsync(
0089                      h2, [&count]() { ++count; }, errorContext);
0090                }) |
0091                lastTask(edm::WaitingTaskHolder(group, &waitTask));
0092       h.doneWaiting(std::exception_ptr());
0093     }
0094     waitTask.waitNoThrow();
0095     REQUIRE(count.load() == 1);
0096     REQUIRE(waitTask.done());
0097     REQUIRE(waitTask.exceptionPtr());
0098     REQUIRE_THROWS_WITH(std::rethrow_exception(waitTask.exceptionPtr()),
0099                         Catch::Contains("Calling run in this context is not allowed"));
0100   }
0101 }