File indexing completed on 2024-04-06 12:15:40
0001 #ifndef HeterogeneousCore_AlpakaInterface_interface_prefixScan_h
0002 #define HeterogeneousCore_AlpakaInterface_interface_prefixScan_h
0003
0004 #include <alpaka/alpaka.hpp>
0005
0006 #include "FWCore/Utilities/interface/CMSUnrollLoop.h"
0007 #include "HeterogeneousCore/AlpakaInterface/interface/config.h"
0008 #include "HeterogeneousCore/AlpakaInterface/interface/workdivision.h"
0009 namespace cms::alpakatools {
// Returns true if v is a strictly positive power of two (1, 2, 4, 8, ...).
// Zero and negative values yield false.
template <typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
constexpr bool isPowerOf2(T v) {
  // A power of two has exactly one bit set, so clearing its lowest set bit (v & (v - 1)) leaves zero.
  // The v > 0 test is evaluated first, so v - 1 is never computed for the most negative signed value.
  return v > 0 && (v & (v - 1)) == 0;
}
0021
0022 template <typename TAcc, typename T, typename = std::enable_if_t<alpaka::isAccelerator<TAcc>>>
0023 ALPAKA_FN_ACC ALPAKA_FN_INLINE void warpPrefixScan(
0024 const TAcc& acc, int32_t laneId, T const* ci, T* co, uint32_t i, bool active = true) {
0025
0026 T x = active ? ci[i] : 0;
0027 CMS_UNROLL_LOOP
0028 for (int32_t offset = 1; offset < alpaka::warp::getSize(acc); offset <<= 1) {
0029
0030 using dataType = std::conditional_t<std::is_floating_point_v<T>, T, std::int32_t>;
0031 T y = alpaka::warp::shfl(acc, static_cast<dataType>(x), laneId - offset);
0032 if (laneId >= offset)
0033 x += y;
0034 }
0035 if (active)
0036 co[i] = x;
0037 }
0038
0039 template <typename TAcc, typename T, typename = std::enable_if_t<alpaka::isAccelerator<TAcc>>>
0040 ALPAKA_FN_ACC ALPAKA_FN_INLINE void warpPrefixScan(
0041 const TAcc& acc, int32_t laneId, T* c, uint32_t i, bool active = true) {
0042 warpPrefixScan(acc, laneId, c, c, i, active);
0043 }
0044
0045
0046 template <typename TAcc, typename T>
0047 ALPAKA_FN_ACC ALPAKA_FN_INLINE void blockPrefixScan(
0048 const TAcc& acc, T const* ci, T* co, int32_t size, T* ws = nullptr) {
0049 if constexpr (!requires_single_thread_per_block_v<TAcc>) {
0050 const auto warpSize = alpaka::warp::getSize(acc);
0051 int32_t const blockDimension(alpaka::getWorkDiv<alpaka::Block, alpaka::Threads>(acc)[0u]);
0052 int32_t const blockThreadIdx(alpaka::getIdx<alpaka::Block, alpaka::Threads>(acc)[0u]);
0053 ALPAKA_ASSERT_ACC(ws);
0054 ALPAKA_ASSERT_ACC(size <= warpSize * warpSize);
0055 ALPAKA_ASSERT_ACC(0 == blockDimension % warpSize);
0056 auto first = blockThreadIdx;
0057 ALPAKA_ASSERT_ACC(isPowerOf2(warpSize));
0058 auto laneId = blockThreadIdx & (warpSize - 1);
0059 auto warpUpRoundedSize = (size + warpSize - 1) / warpSize * warpSize;
0060
0061 for (auto i = first; i < warpUpRoundedSize; i += blockDimension) {
0062
0063 warpPrefixScan(acc, laneId, ci, co, i, i < size);
0064 if (i < size) {
0065
0066 auto warpId = i / warpSize;
0067 ALPAKA_ASSERT_ACC(warpId < warpSize);
0068 if ((warpSize - 1) == laneId)
0069 ws[warpId] = co[i];
0070 }
0071 }
0072 alpaka::syncBlockThreads(acc);
0073 if (size <= warpSize)
0074 return;
0075 if (blockThreadIdx < warpSize) {
0076 warpPrefixScan(acc, laneId, ws, blockThreadIdx);
0077 }
0078 alpaka::syncBlockThreads(acc);
0079 for (auto i = first + warpSize; i < size; i += blockDimension) {
0080 int32_t warpId = i / warpSize;
0081 co[i] += ws[warpId - 1];
0082 }
0083 alpaka::syncBlockThreads(acc);
0084 } else {
0085 co[0] = ci[0];
0086 for (int32_t i = 1; i < size; ++i)
0087 co[i] = ci[i] + co[i - 1];
0088 }
0089 }
0090
0091 template <typename TAcc, typename T>
0092 ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE void blockPrefixScan(const TAcc& acc,
0093 T* __restrict__ c,
0094 int32_t size,
0095 T* __restrict__ ws = nullptr) {
0096 if constexpr (!requires_single_thread_per_block_v<TAcc>) {
0097 const auto warpSize = alpaka::warp::getSize(acc);
0098 int32_t const blockDimension(alpaka::getWorkDiv<alpaka::Block, alpaka::Threads>(acc)[0u]);
0099 int32_t const blockThreadIdx(alpaka::getIdx<alpaka::Block, alpaka::Threads>(acc)[0u]);
0100 ALPAKA_ASSERT_ACC(ws);
0101 ALPAKA_ASSERT_ACC(size <= warpSize * warpSize);
0102 ALPAKA_ASSERT_ACC(0 == blockDimension % warpSize);
0103 auto first = blockThreadIdx;
0104 auto laneId = blockThreadIdx & (warpSize - 1);
0105 auto warpUpRoundedSize = (size + warpSize - 1) / warpSize * warpSize;
0106
0107 for (auto i = first; i < warpUpRoundedSize; i += blockDimension) {
0108
0109 warpPrefixScan(acc, laneId, c, i, i < size);
0110 if (i < size) {
0111
0112 auto warpId = i / warpSize;
0113 ALPAKA_ASSERT_ACC(warpId < warpSize);
0114 if ((warpSize - 1) == laneId)
0115 ws[warpId] = c[i];
0116 }
0117 }
0118 alpaka::syncBlockThreads(acc);
0119 if (size <= warpSize)
0120 return;
0121 if (blockThreadIdx < warpSize) {
0122 warpPrefixScan(acc, laneId, ws, blockThreadIdx);
0123 }
0124 alpaka::syncBlockThreads(acc);
0125 for (auto i = first + warpSize; i < size; i += blockDimension) {
0126 auto warpId = i / warpSize;
0127 c[i] += ws[warpId - 1];
0128 }
0129 alpaka::syncBlockThreads(acc);
0130 } else {
0131 for (int32_t i = 1; i < size; ++i)
0132 c[i] += c[i - 1];
0133 }
0134 }
0135
0136
0137 template <typename T>
0138 struct multiBlockPrefixScan {
0139 template <typename TAcc>
0140 ALPAKA_FN_ACC void operator()(
0141 const TAcc& acc, T const* ci, T* co, uint32_t size, int32_t numBlocks, int32_t* pc, std::size_t warpSize) const {
0142
0143 T* ws = nullptr;
0144 if constexpr (!requires_single_thread_per_block_v<TAcc>) {
0145 ws = alpaka::getDynSharedMem<T>(acc);
0146 }
0147 ALPAKA_ASSERT_ACC(warpSize == static_cast<std::size_t>(alpaka::warp::getSize(acc)));
0148 [[maybe_unused]] const auto elementsPerGrid = alpaka::getWorkDiv<alpaka::Grid, alpaka::Elems>(acc)[0u];
0149 const auto elementsPerBlock = alpaka::getWorkDiv<alpaka::Block, alpaka::Elems>(acc)[0u];
0150 const auto threadsPerBlock = alpaka::getWorkDiv<alpaka::Block, alpaka::Threads>(acc)[0u];
0151 const auto blocksPerGrid = alpaka::getWorkDiv<alpaka::Grid, alpaka::Blocks>(acc)[0u];
0152 const auto blockIdx = alpaka::getIdx<alpaka::Grid, alpaka::Blocks>(acc)[0u];
0153 const auto threadIdx = alpaka::getIdx<alpaka::Block, alpaka::Threads>(acc)[0u];
0154 ALPAKA_ASSERT_ACC(elementsPerGrid >= size);
0155
0156 [[maybe_unused]] int off = elementsPerBlock * blockIdx;
0157 if (size - off > 0) {
0158 blockPrefixScan(acc, ci + off, co + off, std::min(elementsPerBlock, size - off), ws);
0159 }
0160
0161
0162 auto& isLastBlockDone = alpaka::declareSharedVar<bool, __COUNTER__>(acc);
0163
0164 if (0 == threadIdx) {
0165 alpaka::mem_fence(acc, alpaka::memory_scope::Device{});
0166 auto value = alpaka::atomicAdd(acc, pc, 1, alpaka::hierarchy::Blocks{});
0167 isLastBlockDone = (value == (int(blocksPerGrid) - 1));
0168 }
0169
0170 alpaka::syncBlockThreads(acc);
0171
0172 if (!isLastBlockDone)
0173 return;
0174
0175 ALPAKA_ASSERT_ACC(int(blocksPerGrid) == *pc);
0176
0177
0178
0179
0180 T* psum = nullptr;
0181 if constexpr (!requires_single_thread_per_block_v<TAcc>) {
0182 psum = ws + warpSize;
0183 } else {
0184 psum = alpaka::getDynSharedMem<T>(acc);
0185 }
0186 for (int32_t i = threadIdx, ni = blocksPerGrid; i < ni; i += threadsPerBlock) {
0187 auto j = elementsPerBlock * i + elementsPerBlock - 1;
0188 psum[i] = (j < size) ? co[j] : T(0);
0189 }
0190 alpaka::syncBlockThreads(acc);
0191 blockPrefixScan(acc, psum, psum, blocksPerGrid, ws);
0192
0193
0194
0195
0196 if constexpr (!requires_single_thread_per_block_v<TAcc>) {
0197
0198 for (uint32_t i = threadIdx + threadsPerBlock, k = 0; i < size; i += threadsPerBlock, ++k) {
0199 co[i] += psum[k];
0200 }
0201 } else {
0202
0203 for (uint32_t i = elementsPerBlock; i < size; i++) {
0204 co[i] += psum[i / elementsPerBlock - 1];
0205 }
0206 }
0207 }
0208 };
0209 }
0210
0211
0212 namespace alpaka::trait {
0213
0214 template <typename TAcc, typename T>
0215 struct BlockSharedMemDynSizeBytes<cms::alpakatools::multiBlockPrefixScan<T>, TAcc> {
0216 template <typename TVec>
0217 ALPAKA_FN_HOST_ACC static std::size_t getBlockSharedMemDynSizeBytes(
0218 cms::alpakatools::multiBlockPrefixScan<T> const& ,
0219 TVec const& ,
0220 TVec const& ,
0221 T const* ,
0222 T const* ,
0223 int32_t ,
0224 int32_t numBlocks,
0225 int32_t const* ,
0226
0227 std::size_t warpSize) {
0228
0229 if constexpr (cms::alpakatools::requires_single_thread_per_block_v<TAcc>) {
0230 return sizeof(T) * numBlocks;
0231 } else {
0232 return sizeof(T) * (warpSize + numBlocks);
0233 }
0234 }
0235 };
0236
0237 }
0238
0239 #endif