File indexing completed on 2023-01-25 02:56:28
0001 #ifndef HeterogeneousCore_AlpakaInterface_interface_workdivision_h
0002 #define HeterogeneousCore_AlpakaInterface_interface_workdivision_h
0003
0004 #include <type_traits>
0005
0006 #include <alpaka/alpaka.hpp>
0007
0008 #include "HeterogeneousCore/AlpakaInterface/interface/config.h"
0009 #include "HeterogeneousCore/AlpakaInterface/interface/traits.h"
0010 #include "HeterogeneousCore/AlpakaInterface/interface/vec.h"
0011
0012 namespace cms::alpakatools {
0013
0014 using namespace alpaka_common;
0015
0016
0017 inline constexpr Idx round_up_by(Idx value, Idx divisor) { return (value + divisor - 1) / divisor * divisor; }
0018
0019
0020 inline constexpr Idx divide_up_by(Idx value, Idx divisor) { return (value + divisor - 1) / divisor; }
0021
0022
0023 template <typename TAcc, typename = std::enable_if_t<cms::alpakatools::is_accelerator_v<TAcc>>>
0024 struct requires_single_thread_per_block : public std::true_type {};
0025
0026 #ifdef ALPAKA_ACC_GPU_CUDA_ENABLED
0027 template <typename TDim>
0028 struct requires_single_thread_per_block<alpaka::AccGpuCudaRt<TDim, Idx>> : public std::false_type {};
0029 #endif
0030
0031 #ifdef ALPAKA_ACC_GPU_HIP_ENABLED
0032 template <typename TDim>
0033 struct requires_single_thread_per_block<alpaka::AccGpuHipRt<TDim, Idx>> : public std::false_type {};
0034 #endif
0035
0036
0037 template <typename TAcc, typename = std::enable_if_t<cms::alpakatools::is_accelerator_v<TAcc>>>
0038 inline constexpr bool requires_single_thread_per_block_v = requires_single_thread_per_block<TAcc>::value;
0039
0040
0041 template <typename TAcc,
0042 typename = std::enable_if_t<cms::alpakatools::is_accelerator_v<TAcc> and alpaka::Dim<TAcc>::value == 1>>
0043 inline WorkDiv<Dim1D> make_workdiv(Idx blocks, Idx elements) {
0044 if constexpr (not requires_single_thread_per_block_v<TAcc>) {
0045
0046
0047
0048 return WorkDiv<Dim1D>(blocks, elements, Idx{1});
0049 } else {
0050
0051
0052
0053 return WorkDiv<Dim1D>(blocks, Idx{1}, elements);
0054 }
0055 }
0056
0057
0058 template <typename TAcc, typename = std::enable_if_t<cms::alpakatools::is_accelerator_v<TAcc>>>
0059 inline WorkDiv<alpaka::Dim<TAcc>> make_workdiv(const Vec<alpaka::Dim<TAcc>>& blocks,
0060 const Vec<alpaka::Dim<TAcc>>& elements) {
0061 using Dim = alpaka::Dim<TAcc>;
0062 if constexpr (not requires_single_thread_per_block_v<TAcc>) {
0063
0064
0065
0066 return WorkDiv<Dim>(blocks, elements, Vec<Dim>::ones());
0067 } else {
0068
0069
0070
0071 return WorkDiv<Dim>(blocks, Vec<Dim>::ones(), elements);
0072 }
0073 }
0074
0075 template <typename TAcc,
0076 typename = std::enable_if_t<cms::alpakatools::is_accelerator_v<TAcc> and alpaka::Dim<TAcc>::value == 1>>
0077 class elements_with_stride {
0078 public:
0079 ALPAKA_FN_ACC inline elements_with_stride(TAcc const& acc)
0080 : elements_{alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)[0u]},
0081 first_{alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_},
0082 stride_{alpaka::getWorkDiv<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_},
0083 extent_{stride_} {}
0084
0085 ALPAKA_FN_ACC inline elements_with_stride(TAcc const& acc, Idx extent)
0086 : elements_{alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)[0u]},
0087 first_{alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_},
0088 stride_{alpaka::getWorkDiv<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_},
0089 extent_{extent} {}
0090
0091 class iterator {
0092 friend class elements_with_stride;
0093
0094 ALPAKA_FN_ACC inline iterator(Idx elements, Idx stride, Idx extent, Idx first)
0095 : elements_{elements},
0096 stride_{stride},
0097 extent_{extent},
0098 first_{std::min(first, extent)},
0099 index_{first_},
0100 last_{std::min(first + elements, extent)} {}
0101
0102 public:
0103 ALPAKA_FN_ACC inline Idx operator*() const { return index_; }
0104
0105
0106 ALPAKA_FN_ACC inline iterator& operator++() {
0107 if constexpr (requires_single_thread_per_block_v<TAcc>) {
0108
0109 ++index_;
0110 if (index_ < last_)
0111 return *this;
0112 }
0113
0114
0115 first_ += stride_;
0116 index_ = first_;
0117 last_ = std::min(first_ + elements_, extent_);
0118 if (index_ < extent_)
0119 return *this;
0120
0121
0122 first_ = extent_;
0123 index_ = extent_;
0124 last_ = extent_;
0125 return *this;
0126 }
0127
0128
0129 ALPAKA_FN_ACC inline iterator operator++(int) {
0130 iterator old = *this;
0131 ++(*this);
0132 return old;
0133 }
0134
0135 ALPAKA_FN_ACC inline bool operator==(iterator const& other) const {
0136 return (index_ == other.index_) and (first_ == other.first_);
0137 }
0138
0139 ALPAKA_FN_ACC inline bool operator!=(iterator const& other) const { return not(*this == other); }
0140
0141 private:
0142
0143 Idx elements_;
0144 Idx stride_;
0145 Idx extent_;
0146
0147 Idx first_;
0148 Idx index_;
0149 Idx last_;
0150 };
0151
0152 ALPAKA_FN_ACC inline iterator begin() const { return iterator(elements_, stride_, extent_, first_); }
0153
0154 ALPAKA_FN_ACC inline iterator end() const { return iterator(elements_, stride_, extent_, extent_); }
0155
0156 private:
0157 const Idx elements_;
0158 const Idx first_;
0159 const Idx stride_;
0160 const Idx extent_;
0161 };
0162
0163 template <typename TAcc,
0164 typename = std::enable_if_t<cms::alpakatools::is_accelerator_v<TAcc> and (alpaka::Dim<TAcc>::value > 0)>>
0165 class elements_with_stride_nd {
0166 public:
0167 using Dim = alpaka::Dim<TAcc>;
0168 using Vec = alpaka::Vec<Dim, Idx>;
0169
0170 ALPAKA_FN_ACC inline elements_with_stride_nd(TAcc const& acc)
0171 : elements_{alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)},
0172 first_{alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc) * elements_},
0173 stride_{alpaka::getWorkDiv<alpaka::Grid, alpaka::Threads>(acc) * elements_},
0174 extent_{stride_} {}
0175
0176 ALPAKA_FN_ACC inline elements_with_stride_nd(TAcc const& acc, Vec extent)
0177 : elements_{alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)},
0178 first_{alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc) * elements_},
0179 stride_{alpaka::getWorkDiv<alpaka::Grid, alpaka::Threads>(acc) * elements_},
0180 extent_{extent} {}
0181
0182 class iterator {
0183 friend class elements_with_stride_nd;
0184
0185 public:
0186 ALPAKA_FN_ACC inline Vec operator*() const { return index_; }
0187
0188
0189 ALPAKA_FN_ACC constexpr inline iterator operator++() {
0190 increment();
0191 return *this;
0192 }
0193
0194
0195 ALPAKA_FN_ACC constexpr inline iterator operator++(int) {
0196 iterator old = *this;
0197 increment();
0198 return old;
0199 }
0200
0201 ALPAKA_FN_ACC constexpr inline bool operator==(iterator const& other) const { return (index_ == other.index_); }
0202
0203 ALPAKA_FN_ACC constexpr inline bool operator!=(iterator const& other) const { return not(*this == other); }
0204
0205 private:
0206
0207 ALPAKA_FN_ACC inline iterator(elements_with_stride_nd const* loop, Vec first)
0208 : loop_{loop},
0209 thread_{alpaka::elementwise_min(first, loop->extent_)},
0210 range_{alpaka::elementwise_min(first + loop->elements_, loop->extent_)},
0211 index_{thread_} {}
0212
0213 template <size_t I>
0214 ALPAKA_FN_ACC inline constexpr bool nth_elements_loop() {
0215 bool overflow = false;
0216 ++index_[I];
0217 if (index_[I] >= range_[I]) {
0218 index_[I] = thread_[I];
0219 overflow = true;
0220 }
0221 return overflow;
0222 }
0223
0224 template <size_t N>
0225 ALPAKA_FN_ACC inline constexpr bool do_elements_loops() {
0226 if constexpr (N == 0) {
0227
0228 return true;
0229 } else {
0230 if (not nth_elements_loop<N - 1>()) {
0231 return false;
0232 } else {
0233 return do_elements_loops<N - 1>();
0234 }
0235 }
0236 }
0237
0238 template <size_t I>
0239 ALPAKA_FN_ACC inline constexpr bool nth_strided_loop() {
0240 bool overflow = false;
0241 thread_[I] += loop_->stride_[I];
0242 if (thread_[I] >= loop_->extent_[I]) {
0243 thread_[I] = loop_->first_[I];
0244 overflow = true;
0245 }
0246 index_[I] = thread_[I];
0247 range_[I] = std::min(thread_[I] + loop_->elements_[I], loop_->extent_[I]);
0248 return overflow;
0249 }
0250
0251 template <size_t N>
0252 ALPAKA_FN_ACC inline constexpr bool do_strided_loops() {
0253 if constexpr (N == 0) {
0254
0255 return true;
0256 } else {
0257 if (not nth_strided_loop<N - 1>()) {
0258 return false;
0259 } else {
0260 return do_strided_loops<N - 1>();
0261 }
0262 }
0263 }
0264
0265
0266 ALPAKA_FN_ACC inline constexpr void increment() {
0267 if constexpr (requires_single_thread_per_block_v<TAcc>) {
0268
0269
0270 if (not do_elements_loops<Dim::value>()) {
0271
0272 return;
0273 }
0274 }
0275
0276
0277
0278 if (not do_strided_loops<Dim::value>()) {
0279
0280 return;
0281 }
0282
0283
0284 thread_ = loop_->extent_;
0285 range_ = loop_->extent_;
0286 index_ = loop_->extent_;
0287 }
0288
0289
0290 const elements_with_stride_nd* loop_;
0291
0292
0293 Vec thread_;
0294 Vec range_;
0295 Vec index_;
0296 };
0297
0298 ALPAKA_FN_ACC inline iterator begin() const { return iterator{this, first_}; }
0299
0300 ALPAKA_FN_ACC inline iterator end() const { return iterator{this, extent_}; }
0301
0302 private:
0303 const Vec elements_;
0304 const Vec first_;
0305 const Vec stride_;
0306 const Vec extent_;
0307 };
0308
0309 }
0310
0311 #endif