Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-07-03 04:17:57

0001 #ifndef HeterogeneousCore_AlpakaCore_interface_atomicMaxPair_h
0002 #define HeterogeneousCore_AlpakaCore_interface_atomicMaxPair_h
0003 #include <alpaka/alpaka.hpp>
0004 
0005 #include "FWCore/Utilities/interface/bit_cast.h"
0006 #include "HeterogeneousCore/AlpakaInterface/interface/config.h"
0007 
0008 // Note: Does not compile with ALPAKA_FN_ACC on ROCm
0009 template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc>>, typename F>
0010 ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE void atomicMaxPair(const TAcc& acc,
0011                                                        unsigned long long int* address,
0012                                                        std::pair<unsigned int, float> value,
0013                                                        F comparator) {
0014 #if defined(__CUDA_ARCH__) or defined(__HIP_DEVICE_COMPILE__)
0015   unsigned long long int val = (static_cast<unsigned long long int>(value.first) << 32) + __float_as_uint(value.second);
0016   unsigned long long int ret = *address;
0017   while (comparator(value,
0018                     std::pair<unsigned int, float>{static_cast<unsigned int>(ret >> 32 & 0xffffffff),
0019                                                    __uint_as_float(ret & 0xffffffff)})) {
0020     unsigned long long int old = ret;
0021     if ((ret = atomicCAS(address, old, val)) == old)
0022       break;
0023   }
0024 #else
0025   unsigned long long int val =
0026       (static_cast<unsigned long long int>(value.first) << 32) + edm::bit_cast<unsigned int>(value.second);
0027   unsigned long long int ret = *address;
0028   while (comparator(value,
0029                     std::pair{static_cast<unsigned int>(ret >> 32 & 0xffffffff),
0030                               edm::bit_cast<float>(static_cast<unsigned int>(ret & 0xffffffff))})) {
0031     unsigned long long int old = ret;
0032     if ((ret = alpaka::atomicCas(acc, address, old, val)) == old)
0033       break;
0034   }
0035 #endif  // __CUDA_ARCH__ or __HIP_DEVICE_COMPILE__
0036 }
0037 
0038 #endif  // HeterogeneousCore_AlpakaCore_interface_atomicMaxPair_h