// AsyncServiceTest: Catch2 test of edm::Async
#include "catch.hpp"

#include <atomic>
#include <exception>
#include <stdexcept>
#include <utility>

#include "oneapi/tbb/global_control.h"

#include "FWCore/Concurrency/interface/chain_first.h"
#include "FWCore/Concurrency/interface/FinalWaitingTask.h"
#include "FWCore/Concurrency/interface/Async.h"

namespace {
  /// Returns the constant label "AsyncServiceTest".
  ///
  /// The test cases hand this function itself (not its result) to runAsync()
  /// as the third argument; presumably it is invoked to describe the context
  /// when an error is reported -- confirm against edm::Async.
  constexpr auto errorContext() -> char const* {
    return "AsyncServiceTest";
  }

  /// Test double for edm::Async whose "is running allowed" check can be
  /// switched on and off by the test.
  ///
  /// The class starts out in the allowed state. After setAllowed(false),
  /// ensureAllowed() throws std::runtime_error. The class is file-local and
  /// nothing derives from it, hence `final`.
  class AsyncServiceTest final : public edm::Async {
  public:
    /// Not referenced anywhere in this file; retained so the public surface
    /// of the class stays as it was.
    enum class State { kAllowed, kDisallowed, kShutdown };

    AsyncServiceTest() = default;

    /// Selects whether later ensureAllowed() calls pass (`true`) or throw (`false`).
    void setAllowed(bool allowed) noexcept { allowed_.store(allowed); }

  private:
    /// Override of the edm::Async hook.
    ///
    /// @throws std::runtime_error with the message
    ///         "Calling run in this context is not allowed" when the flag is false.
    void ensureAllowed() const final {
      if (not allowed_.load()) {
        throw std::runtime_error("Calling run in this context is not allowed");
      }
    }

    /// Permission flag read by ensureAllowed(); initially true.
    std::atomic<bool> allowed_ = true;
  };
}  // namespace

// Exercises runAsync() on the AsyncServiceTest double in two situations:
// while running is allowed, and after the test has disallowed it.
// The tag is "[edm::Async]"; Catch2 tags must be enclosed in square brackets.
TEST_CASE("Test Async", "[edm::Async]") {
  // Using parallelism 2 here because otherwise the
  // tbb::task_arena::enqueue() in WaitingTaskWithArenaHolder will
  // start a new TBB thread that "inherits" the name from the
  // WaitingThreadPool thread.
  oneapi::tbb::global_control control(oneapi::tbb::global_control::max_allowed_parallelism, 2);

  SECTION("Normal operation") {
    AsyncServiceTest service;
    std::atomic<int> count{0};

    oneapi::tbb::task_group group;
    edm::FinalWaitingTask waitTask{group};

    {
      using namespace edm::waiting_task::chain;
      // Two independent chains that both end in waitTask; each one submits a
      // single asynchronous increment of `count`.
      auto h1 = first([&service, &count](edm::WaitingTaskHolder h) {
                  edm::WaitingTaskWithArenaHolder arenaHolder(std::move(h));
                  service.runAsync(arenaHolder, [&count]() { ++count; }, errorContext);
                }) |
                lastTask(edm::WaitingTaskHolder(group, &waitTask));

      auto h2 = first([&service, &count](edm::WaitingTaskHolder h) {
                  edm::WaitingTaskWithArenaHolder arenaHolder(std::move(h));
                  service.runAsync(arenaHolder, [&count]() { ++count; }, errorContext);
                }) |
                lastTask(edm::WaitingTaskHolder(group, &waitTask));
      // Start both chains without an incoming exception.
      h2.doneWaiting(std::exception_ptr());
      h1.doneWaiting(std::exception_ptr());
    }
    waitTask.waitNoThrow();
    // Both increments must have happened and no exception may have been recorded.
    REQUIRE(count.load() == 2);
    REQUIRE(waitTask.done());
    REQUIRE(not waitTask.exceptionPtr());
  }

  SECTION("Disallowed") {
    AsyncServiceTest service;
    std::atomic<int> count{0};

    oneapi::tbb::task_group group;
    edm::FinalWaitingTask waitTask{group};

    {
      using namespace edm::waiting_task::chain;
      // The first step submits one increment while running is still allowed
      // and then disallows it; the second step attempts another runAsync().
      auto h = first([&service, &count](edm::WaitingTaskHolder h) {
                 edm::WaitingTaskWithArenaHolder arenaHolder(std::move(h));
                 service.runAsync(arenaHolder, [&count]() { ++count; }, errorContext);
                 service.setAllowed(false);
               }) |
               then([&service, &count](edm::WaitingTaskHolder h) {
                 edm::WaitingTaskWithArenaHolder arenaHolder(std::move(h));
                 service.runAsync(arenaHolder, [&count]() { ++count; }, errorContext);
               }) |
               lastTask(edm::WaitingTaskHolder(group, &waitTask));
      h.doneWaiting(std::exception_ptr());
    }
    waitTask.waitNoThrow();
    // Only the first increment is expected, and waitTask must carry the
    // exception thrown by AsyncServiceTest::ensureAllowed().
    REQUIRE(count.load() == 1);
    REQUIRE(waitTask.done());
    REQUIRE(waitTask.exceptionPtr());
    REQUIRE_THROWS_WITH(std::rethrow_exception(waitTask.exceptionPtr()),
                        Catch::Contains("Calling run in this context is not allowed"));
  }
}