Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:15:40

0001 #ifndef HeterogeneousCore_AlpakaInterface_interface_HostOnlyTask_h
0002 #define HeterogeneousCore_AlpakaInterface_interface_HostOnlyTask_h
0003 
#include <exception>
#include <functional>
#include <memory>
#include <stdexcept>

#include <fmt/format.h>

#include <alpaka/alpaka.hpp>
0010 
0011 namespace alpaka {
0012 
0013   //! A task that is guaranted not to call any GPU-ralated APIs
0014   //!
0015   //! These tasks can be enqueued directly to the native GPU queues, without the use of a
0016   //! dedicated host-side worker thread.
0017   class HostOnlyTask {
0018   public:
0019     HostOnlyTask(std::function<void(std::exception_ptr)> task) : task_(std::move(task)) {}
0020 
0021     void operator()(std::exception_ptr eptr) const { task_(eptr); }
0022 
0023   private:
0024     std::function<void(std::exception_ptr)> task_;
0025   };
0026 
0027   namespace trait {
0028 
#ifdef ALPAKA_ACC_GPU_CUDA_ENABLED
    //! The CUDA async queue enqueue trait specialization for HostOnlyTask ("safe tasks").
    //!
    //! enqueue() moves the task to the heap and registers callback() on the native stream with
    //! cudaStreamAddCallback(). callback() takes back ownership of the task, runs it, and destroys it.
    template <>
    struct Enqueue<QueueCudaRtNonBlocking, HostOnlyTask> {
      using TApi = ApiCudaRt;

      //! Stream callback registered by enqueue() and invoked by the CUDA runtime.
      //!
      //! @param queue  the stream the callback was registered on (used only in the error message)
      //! @param status the status reported by the runtime; on cudaSuccess the task is run with a null
      //!               exception_ptr, otherwise with a std::runtime_error describing the error
      //! @param arg    the HostOnlyTask released by enqueue(); it is deleted when this function returns
      //!
      //! noexcept: this function is called through a C function pointer by the CUDA runtime, so an
      //! exception thrown by the task must not unwind into the runtime. With noexcept, such an
      //! escape calls std::terminate instead of being undefined behaviour.
      static void CUDART_CB callback(cudaStream_t queue, cudaError_t status, void* arg) noexcept {
        // reclaim ownership of the task released by enqueue()
        std::unique_ptr<HostOnlyTask> pTask(static_cast<HostOnlyTask*>(arg));
        if (status == cudaSuccess) {
          (*pTask)(nullptr);
        } else {
          // wrap the exception in a try-catch block to let GDB "catch throw" break on it
          try {
            throw std::runtime_error(fmt::format("CUDA error: callback of stream {} received error {}: {}.",
                                                 fmt::ptr(queue),
                                                 cudaGetErrorName(status),
                                                 cudaGetErrorString(status)));
          } catch (std::exception const&) {
            // pass the exception to the task
            (*pTask)(std::current_exception());
          }
        }
      }

      //! Enqueues @p task on @p queue. The task runs on the host through callback().
      ALPAKA_FN_HOST static auto enqueue(QueueCudaRtNonBlocking& queue, HostOnlyTask task) -> void {
        auto pTask = std::make_unique<HostOnlyTask>(std::move(task));
        // Ownership of the task is handed over to callback(), which reclaims and deletes it.
        // NOTE(review): pTask.release() is evaluated before cudaStreamAddCallback() runs, so if the
        // call fails the task is leaked while ALPAKA_UNIFORM_CUDA_HIP_RT_CHECK handles the error.
        // Releasing only after a successful call would avoid the leak. That is safe only if a failed
        // call guarantees the callback was not registered; confirm before changing.
        ALPAKA_UNIFORM_CUDA_HIP_RT_CHECK(
            cudaStreamAddCallback(alpaka::getNativeHandle(queue), callback, static_cast<void*>(pTask.release()), 0u));
      }
    };
#endif  // ALPAKA_ACC_GPU_CUDA_ENABLED
0060 
#ifdef ALPAKA_ACC_GPU_HIP_ENABLED
    //! The HIP async queue enqueue trait specialization for HostOnlyTask ("safe tasks").
    //!
    //! enqueue() moves the task to the heap and registers callback() on the native stream with
    //! hipStreamAddCallback(). callback() takes back ownership of the task, runs it, and destroys it.
    template <>
    struct Enqueue<QueueHipRtNonBlocking, HostOnlyTask> {
      using TApi = ApiHipRt;

      //! Stream callback registered by enqueue() and invoked by the HIP runtime.
      //!
      //! @param queue  the stream the callback was registered on (used only in the error message)
      //! @param status the status reported by the runtime; on hipSuccess the task is run with a null
      //!               exception_ptr, otherwise with a std::runtime_error describing the error
      //! @param arg    the HostOnlyTask released by enqueue(); it is deleted when this function returns
      //!
      //! noexcept: this function is called through a C function pointer by the HIP runtime, so an
      //! exception thrown by the task must not unwind into the runtime. With noexcept, such an
      //! escape calls std::terminate instead of being undefined behaviour.
      static void callback(hipStream_t queue, hipError_t status, void* arg) noexcept {
        // reclaim ownership of the task released by enqueue()
        std::unique_ptr<HostOnlyTask> pTask(static_cast<HostOnlyTask*>(arg));
        if (status == hipSuccess) {
          (*pTask)(nullptr);
        } else {
          // wrap the exception in a try-catch block to let GDB "catch throw" break on it
          try {
            throw std::runtime_error(fmt::format("HIP error: callback of stream {} received error {}: {}.",
                                                 fmt::ptr(queue),
                                                 hipGetErrorName(status),
                                                 hipGetErrorString(status)));
          } catch (std::exception const&) {
            // pass the exception to the task
            (*pTask)(std::current_exception());
          }
        }
      }

      //! Enqueues @p task on @p queue. The task runs on the host through callback().
      ALPAKA_FN_HOST static auto enqueue(QueueHipRtNonBlocking& queue, HostOnlyTask task) -> void {
        auto pTask = std::make_unique<HostOnlyTask>(std::move(task));
        // Ownership of the task is handed over to callback(), which reclaims and deletes it.
        // NOTE(review): pTask.release() is evaluated before hipStreamAddCallback() runs, so if the
        // call fails the task is leaked while ALPAKA_UNIFORM_CUDA_HIP_RT_CHECK handles the error.
        // Releasing only after a successful call would avoid the leak. That is safe only if a failed
        // call guarantees the callback was not registered; confirm before changing.
        ALPAKA_UNIFORM_CUDA_HIP_RT_CHECK(
            hipStreamAddCallback(alpaka::getNativeHandle(queue), callback, static_cast<void*>(pTask.release()), 0u));
      }
    };
#endif  // ALPAKA_ACC_GPU_HIP_ENABLED
0092 
0093   }  // namespace trait
0094 
0095 }  // namespace alpaka
0096 
0097 #endif  // HeterogeneousCore_AlpakaInterface_interface_HostOnlyTask_h