File indexing completed on 2024-09-07 04:36:21
#include "catch.hpp"

#include <atomic>
#include <exception>
#include <stdexcept>
#include <utility>

#include "oneapi/tbb/global_control.h"
#include "oneapi/tbb/task_group.h"

#include "FWCore/Concurrency/interface/chain_first.h"
#include "FWCore/Concurrency/interface/FinalWaitingTask.h"
#include "FWCore/Concurrency/interface/Async.h"
0010
0011 namespace {
0012 constexpr char const* errorContext() { return "AsyncServiceTest"; }
0013
0014 class AsyncServiceTest : public edm::Async {
0015 public:
0016 enum class State { kAllowed, kDisallowed, kShutdown };
0017
0018 AsyncServiceTest() = default;
0019
0020 void setAllowed(bool allowed) noexcept { allowed_ = allowed; }
0021
0022 private:
0023 void ensureAllowed() const final {
0024 if (not allowed_) {
0025 throw std::runtime_error("Calling run in this context is not allowed");
0026 }
0027 }
0028
0029 std::atomic<bool> allowed_ = true;
0030 };
0031 }
0032
0033 TEST_CASE("Test Async", "[edm::Async") {
0034
0035
0036
0037
0038 oneapi::tbb::global_control control(oneapi::tbb::global_control::max_allowed_parallelism, 2);
0039
0040 SECTION("Normal operation") {
0041 AsyncServiceTest service;
0042 std::atomic<int> count{0};
0043
0044 oneapi::tbb::task_group group;
0045 edm::FinalWaitingTask waitTask{group};
0046
0047 {
0048 using namespace edm::waiting_task::chain;
0049 auto h1 = first([&service, &count](edm::WaitingTaskHolder h) {
0050 edm::WaitingTaskWithArenaHolder h2(std::move(h));
0051 service.runAsync(h2, [&count]() { ++count; }, errorContext);
0052 }) |
0053 lastTask(edm::WaitingTaskHolder(group, &waitTask));
0054
0055 auto h2 = first([&service, &count](edm::WaitingTaskHolder h) {
0056 edm::WaitingTaskWithArenaHolder h2(std::move(h));
0057 service.runAsync(h2, [&count]() { ++count; }, errorContext);
0058 }) |
0059 lastTask(edm::WaitingTaskHolder(group, &waitTask));
0060 h2.doneWaiting(std::exception_ptr());
0061 h1.doneWaiting(std::exception_ptr());
0062 }
0063 waitTask.waitNoThrow();
0064 REQUIRE(count.load() == 2);
0065 REQUIRE(waitTask.done());
0066 REQUIRE(not waitTask.exceptionPtr());
0067 }
0068
0069 SECTION("Disallowed") {
0070 AsyncServiceTest service;
0071 std::atomic<int> count{0};
0072
0073 oneapi::tbb::task_group group;
0074 edm::FinalWaitingTask waitTask{group};
0075
0076 {
0077 using namespace edm::waiting_task::chain;
0078 auto h = first([&service, &count](edm::WaitingTaskHolder h) {
0079 edm::WaitingTaskWithArenaHolder h2(std::move(h));
0080 service.runAsync(h2, [&count]() { ++count; }, errorContext);
0081 service.setAllowed(false);
0082 }) |
0083 then([&service, &count](edm::WaitingTaskHolder h) {
0084 edm::WaitingTaskWithArenaHolder h2(std::move(h));
0085 service.runAsync(h2, [&count]() { ++count; }, errorContext);
0086 }) |
0087 lastTask(edm::WaitingTaskHolder(group, &waitTask));
0088 h.doneWaiting(std::exception_ptr());
0089 }
0090 waitTask.waitNoThrow();
0091 REQUIRE(count.load() == 1);
0092 REQUIRE(waitTask.done());
0093 REQUIRE(waitTask.exceptionPtr());
0094 REQUIRE_THROWS_WITH(std::rethrow_exception(waitTask.exceptionPtr()),
0095 Catch::Contains("Calling run in this context is not allowed"));
0096 }
0097 }