File indexing completed on 2023-02-06 03:09:08
0001 #ifndef HeterogeneousCore_AlpakaCore_interface_ScopedContext_h
0002 #define HeterogeneousCore_AlpakaCore_interface_ScopedContext_h
0003
0004 #include <memory>
0005 #include <stdexcept>
0006 #include <utility>
0007
0008 #include <alpaka/alpaka.hpp>
0009
0010 #include "DataFormats/Portable/interface/Product.h"
0011 #include "FWCore/Concurrency/interface/WaitingTaskWithArenaHolder.h"
0012 #include "FWCore/Framework/interface/Event.h"
0013 #include "FWCore/Utilities/interface/EDGetToken.h"
0014 #include "FWCore/Utilities/interface/EDPutToken.h"
0015 #include "HeterogeneousCore/AlpakaCore/interface/ContextState.h"
0016 #include "HeterogeneousCore/AlpakaCore/interface/EventCache.h"
0017 #include "HeterogeneousCore/AlpakaCore/interface/QueueCache.h"
0018 #include "HeterogeneousCore/AlpakaCore/interface/chooseDevice.h"
0019 #include "HeterogeneousCore/AlpakaInterface/interface/HostOnlyTask.h"
0020 #include "HeterogeneousCore/AlpakaInterface/interface/ScopedContextFwd.h"
0021 #include "HeterogeneousCore/AlpakaInterface/interface/config.h"
0022 #include "HeterogeneousCore/AlpakaInterface/interface/traits.h"
0023
// Forward declaration only: ScopedContextProduce (declared later in this header) names this class as a
// friend so that it can reach the private queue-taking constructor. Presumably a unit-test helper —
// confirm against the AlpakaCore tests.
namespace cms::alpakatest {
  class TestScopedContext;
}  // namespace cms::alpakatest
0027
0028 namespace cms::alpakatools {
0029
0030 namespace impl {
0031
0032 template <typename TQueue, typename>
0033 class ScopedContextBase {
0034 public:
0035 using Queue = TQueue;
0036 using Device = alpaka::Dev<Queue>;
0037 using Platform = alpaka::Pltf<Device>;
0038
0039 Device device() const { return alpaka::getDev(*queue_); }
0040
0041 Queue& queue() { return *queue_; }
0042 const std::shared_ptr<Queue>& queuePtr() const { return queue_; }
0043
0044 protected:
0045 ScopedContextBase(ProductBase<Queue> const& data)
0046 : queue_{data.mayReuseQueue() ? data.queuePtr() : getQueueCache<Queue>().get(data.device())} {}
0047
0048 explicit ScopedContextBase(std::shared_ptr<Queue> queue) : queue_(std::move(queue)) {}
0049
0050 explicit ScopedContextBase(edm::StreamID streamID)
0051 : queue_{getQueueCache<Queue>().get(cms::alpakatools::chooseDevice<Platform>(streamID))} {}
0052
0053 private:
0054 std::shared_ptr<Queue> queue_;
0055 };
0056
0057 template <typename TQueue, typename>
0058 class ScopedContextGetterBase : public ScopedContextBase<TQueue> {
0059 public:
0060 using Queue = TQueue;
0061
0062 template <typename T>
0063 const T& get(Product<Queue, T> const& data) {
0064 synchronizeStreams(data);
0065 return data.data_;
0066 }
0067
0068 template <typename T>
0069 const T& get(edm::Event const& event, edm::EDGetTokenT<Product<Queue, T>> token) {
0070 return get(event.get(token));
0071 }
0072
0073 protected:
0074 template <typename... Args>
0075 ScopedContextGetterBase(Args&&... args) : ScopedContextBase<Queue>{std::forward<Args>(args)...} {}
0076
0077 void synchronizeStreams(ProductBase<Queue> const& data) {
0078
0079 if (data.queue() != this->queue()) {
0080
0081 if (data.device() != this->device()) {
0082
0083
0084 throw cms::Exception("LogicError") << "Handling data from multiple devices is not yet supported";
0085 }
0086
0087 if (not data.isAvailable()) {
0088
0089
0090
0091
0092
0093 alpaka::wait(this->queue(), data.event());
0094 }
0095 }
0096 }
0097 };
0098
0099 class ScopedContextHolderHelper {
0100 public:
0101 ScopedContextHolderHelper(edm::WaitingTaskWithArenaHolder waitingTaskHolder)
0102 : waitingTaskHolder_{std::move(waitingTaskHolder)} {}
0103
0104 template <typename F, typename TQueue, typename = std::enable_if_t<alpaka::isQueue<TQueue>>>
0105 void pushNextTask(F&& f, ContextState<TQueue> const* state) {
0106 replaceWaitingTaskHolder(edm::WaitingTaskWithArenaHolder{edm::make_waiting_task_with_holder(
0107 std::move(waitingTaskHolder_), [state, func = std::forward<F>(f)](edm::WaitingTaskWithArenaHolder h) {
0108 func(ScopedContextTask{state, std::move(h)});
0109 })});
0110 }
0111
0112 void replaceWaitingTaskHolder(edm::WaitingTaskWithArenaHolder waitingTaskHolder) {
0113 waitingTaskHolder_ = std::move(waitingTaskHolder);
0114 }
0115
0116 template <typename TQueue, typename = std::enable_if_t<alpaka::isQueue<TQueue>>>
0117 void enqueueCallback(TQueue& queue) {
0118 alpaka::enqueue(queue, alpaka::HostOnlyTask([holder = std::move(waitingTaskHolder_)]() {
0119
0120
0121 const_cast<edm::WaitingTaskWithArenaHolder&>(holder).doneWaiting(nullptr);
0122 }));
0123 }
0124
0125 private:
0126 edm::WaitingTaskWithArenaHolder waitingTaskHolder_;
0127 };
0128 }
0129
0130
0131
0132
0133
0134
0135
0136
0137 template <typename TQueue, typename>
0138 class ScopedContextAcquire : public impl::ScopedContextGetterBase<TQueue> {
0139 public:
0140 using Queue = TQueue;
0141 using ScopedContextGetterBase = impl::ScopedContextGetterBase<Queue>;
0142 using ScopedContextGetterBase::queue;
0143 using ScopedContextGetterBase::queuePtr;
0144
0145
0146 explicit ScopedContextAcquire(edm::StreamID streamID, edm::WaitingTaskWithArenaHolder waitingTaskHolder)
0147 : ScopedContextGetterBase(streamID), holderHelper_{std::move(waitingTaskHolder)} {}
0148
0149
0150 explicit ScopedContextAcquire(edm::StreamID streamID,
0151 edm::WaitingTaskWithArenaHolder waitingTaskHolder,
0152 ContextState<Queue>& state)
0153 : ScopedContextGetterBase(streamID), holderHelper_{std::move(waitingTaskHolder)}, contextState_{&state} {}
0154
0155
0156 explicit ScopedContextAcquire(ProductBase<Queue> const& data, edm::WaitingTaskWithArenaHolder waitingTaskHolder)
0157 : ScopedContextGetterBase(data), holderHelper_{std::move(waitingTaskHolder)} {}
0158
0159
0160 explicit ScopedContextAcquire(ProductBase<Queue> const& data,
0161 edm::WaitingTaskWithArenaHolder waitingTaskHolder,
0162 ContextState<Queue>& state)
0163 : ScopedContextGetterBase(data), holderHelper_{std::move(waitingTaskHolder)}, contextState_{&state} {}
0164
0165 ~ScopedContextAcquire() {
0166 holderHelper_.enqueueCallback(queue());
0167 if (contextState_) {
0168 contextState_->set(queuePtr());
0169 }
0170 }
0171
0172 template <typename F>
0173 void pushNextTask(F&& f) {
0174 if (contextState_ == nullptr)
0175 throwNoState();
0176 holderHelper_.pushNextTask(std::forward<F>(f), contextState_);
0177 }
0178
0179 void replaceWaitingTaskHolder(edm::WaitingTaskWithArenaHolder waitingTaskHolder) {
0180 holderHelper_.replaceWaitingTaskHolder(std::move(waitingTaskHolder));
0181 }
0182
0183 private:
0184 void throwNoState() {
0185 throw cms::Exception("LogicError")
0186 << "Calling ScopedContextAcquire::insertNextTask() requires ScopedContextAcquire to be constructed with "
0187 "ContextState, but that was not the case";
0188 }
0189
0190 impl::ScopedContextHolderHelper holderHelper_;
0191 ContextState<Queue>* contextState_ = nullptr;
0192 };
0193
0194
0195
0196
0197
0198
0199
0200 template <typename TQueue, typename>
0201 class ScopedContextProduce : public impl::ScopedContextGetterBase<TQueue> {
0202 public:
0203 using Queue = TQueue;
0204 using Event = alpaka::Event<Queue>;
0205 using ScopedContextGetterBase = impl::ScopedContextGetterBase<Queue>;
0206 using ScopedContextGetterBase::device;
0207 using ScopedContextGetterBase::queue;
0208 using ScopedContextGetterBase::queuePtr;
0209
0210
0211 explicit ScopedContextProduce(ContextState<Queue>& state)
0212 : ScopedContextGetterBase(state.releaseQueuePtr()), event_{getEventCache<Event>().get(device())} {}
0213
0214 explicit ScopedContextProduce(ProductBase<Queue> const& data)
0215 : ScopedContextGetterBase(data), event_{getEventCache<Event>().get(device())} {}
0216
0217 explicit ScopedContextProduce(edm::StreamID streamID)
0218 : ScopedContextGetterBase(streamID), event_{getEventCache<Event>().get(device())} {}
0219
0220
0221 ~ScopedContextProduce() {
0222
0223 alpaka::enqueue(queue(), *event_);
0224 }
0225
0226 template <typename T>
0227 std::unique_ptr<Product<Queue, T>> wrap(T data) {
0228
0229 return std::unique_ptr<Product<Queue, T>>(new Product<Queue, T>(queuePtr(), std::move(data)));
0230 }
0231
0232 template <typename T, typename... Args>
0233 auto emplace(edm::Event& iEvent, edm::EDPutTokenT<Product<Queue, T>> token, Args&&... args) {
0234 return iEvent.emplace(token, queuePtr(), event_, std::forward<Args>(args)...);
0235 }
0236
0237 private:
0238 friend class ::cms::alpakatest::TestScopedContext;
0239
0240 explicit ScopedContextProduce(std::shared_ptr<Queue> queue)
0241 : ScopedContextGetterBase(std::move(queue)), event_{getEventCache<Event>().get(device())} {}
0242
0243 std::shared_ptr<Event> event_;
0244 };
0245
0246
0247
0248
0249
0250
0251
0252 template <typename TQueue, typename>
0253 class ScopedContextTask : public impl::ScopedContextBase<TQueue> {
0254 public:
0255 using Queue = TQueue;
0256 using ScopedContextBase = impl::ScopedContextBase<Queue>;
0257 using ScopedContextBase::queue;
0258 using ScopedContextBase::queuePtr;
0259
0260
0261 explicit ScopedContextTask(ContextState<Queue> const* state, edm::WaitingTaskWithArenaHolder waitingTaskHolder)
0262 : ScopedContextBase(state->queuePtr()),
0263 holderHelper_{std::move(waitingTaskHolder)},
0264 contextState_{state} {}
0265
0266 ~ScopedContextTask() { holderHelper_.enqueueCallback(queue()); }
0267
0268 template <typename F>
0269 void pushNextTask(F&& f) {
0270 holderHelper_.pushNextTask(std::forward<F>(f), contextState_);
0271 }
0272
0273 void replaceWaitingTaskHolder(edm::WaitingTaskWithArenaHolder waitingTaskHolder) {
0274 holderHelper_.replaceWaitingTaskHolder(std::move(waitingTaskHolder));
0275 }
0276
0277 private:
0278 impl::ScopedContextHolderHelper holderHelper_;
0279 ContextState<Queue> const* contextState_;
0280 };
0281
0282
0283
0284
0285
0286
0287
0288 template <typename TQueue, typename>
0289 class ScopedContextAnalyze : public impl::ScopedContextGetterBase<TQueue> {
0290 public:
0291 using Queue = TQueue;
0292 using ScopedContextGetterBase = impl::ScopedContextGetterBase<Queue>;
0293 using ScopedContextGetterBase::queue;
0294 using ScopedContextGetterBase::queuePtr;
0295
0296
0297 explicit ScopedContextAnalyze(ProductBase<Queue> const& data) : ScopedContextGetterBase(data) {}
0298 };
0299
0300 }
0301
0302 #endif