File indexing completed on 2024-04-06 12:15:40
0001 #ifndef HeterogeneousCore_AlpakaInterface_interface_HostOnlyTask_h
0002 #define HeterogeneousCore_AlpakaInterface_interface_HostOnlyTask_h
0003
0004 #include <functional>
0005 #include <memory>
0006
0007 #include <fmt/format.h>
0008
0009 #include <alpaka/alpaka.hpp>
0010
0011 namespace alpaka {
0012
0013
0014
0015
0016
/**
 * Wrapper around a callable that accepts a std::exception_ptr.
 *
 * The Enqueue trait specializations in this file run the wrapped callable from a stream
 * callback: they pass a null exception pointer when the callback reports success, and a
 * pointer to a std::runtime_error describing the failure otherwise.
 */
class HostOnlyTask {
  // Signature shared by the constructor argument and the stored callable.
  using Callback = std::function<void(std::exception_ptr)>;

public:
  // Stores `task` by moving it into the wrapper. Deliberately not `explicit`, so a
  // std::function with a matching signature still converts implicitly.
  HostOnlyTask(Callback task) : task_{std::move(task)} {}

  // Invokes the stored callable with `eptr`; the wrapper itself is not modified.
  void operator()(std::exception_ptr eptr) const {
    task_(eptr);
  }

private:
  Callback task_;
};
0026
0027 namespace trait {
0028
0029 #ifdef ALPAKA_ACC_GPU_CUDA_ENABLED
0030
0031 template <>
0032 struct Enqueue<QueueCudaRtNonBlocking, HostOnlyTask> {
0033 using TApi = ApiCudaRt;
0034
0035 static void CUDART_CB callback(cudaStream_t queue, cudaError_t status, void* arg) {
0036 std::unique_ptr<HostOnlyTask> pTask(static_cast<HostOnlyTask*>(arg));
0037 if (status == cudaSuccess) {
0038 (*pTask)(nullptr);
0039 } else {
0040
0041 try {
0042 throw std::runtime_error(fmt::format("CUDA error: callback of stream {} received error {}: {}.",
0043 fmt::ptr(queue),
0044 cudaGetErrorName(status),
0045 cudaGetErrorString(status)));
0046 } catch (std::exception&) {
0047
0048 (*pTask)(std::current_exception());
0049 }
0050 }
0051 }
0052
0053 ALPAKA_FN_HOST static auto enqueue(QueueCudaRtNonBlocking& queue, HostOnlyTask task) -> void {
0054 auto pTask = std::make_unique<HostOnlyTask>(std::move(task));
0055 ALPAKA_UNIFORM_CUDA_HIP_RT_CHECK(
0056 cudaStreamAddCallback(alpaka::getNativeHandle(queue), callback, static_cast<void*>(pTask.release()), 0u));
0057 }
0058 };
0059 #endif
0060
0061 #ifdef ALPAKA_ACC_GPU_HIP_ENABLED
0062
0063 template <>
0064 struct Enqueue<QueueHipRtNonBlocking, HostOnlyTask> {
0065 using TApi = ApiHipRt;
0066
0067 static void callback(hipStream_t queue, hipError_t status, void* arg) {
0068 std::unique_ptr<HostOnlyTask> pTask(static_cast<HostOnlyTask*>(arg));
0069 if (status == hipSuccess) {
0070 (*pTask)(nullptr);
0071 } else {
0072
0073 try {
0074 throw std::runtime_error(fmt::format("HIP error: callback of stream {} received error {}: {}.",
0075 fmt::ptr(queue),
0076 hipGetErrorName(status),
0077 hipGetErrorString(status)));
0078 } catch (std::exception&) {
0079
0080 (*pTask)(std::current_exception());
0081 }
0082 }
0083 }
0084
0085 ALPAKA_FN_HOST static auto enqueue(QueueHipRtNonBlocking& queue, HostOnlyTask task) -> void {
0086 auto pTask = std::make_unique<HostOnlyTask>(std::move(task));
0087 ALPAKA_UNIFORM_CUDA_HIP_RT_CHECK(
0088 hipStreamAddCallback(alpaka::getNativeHandle(queue), callback, static_cast<void*>(pTask.release()), 0u));
0089 }
0090 };
0091 #endif
0092
0093 }
0094
0095 }
0096
0097 #endif