Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2022-08-15 01:07:42

0001 #ifndef HeterogeneousCore_AlpakaCore_interface_ScopedContext_h
0002 #define HeterogeneousCore_AlpakaCore_interface_ScopedContext_h
0003 
0004 #include <memory>
0005 #include <stdexcept>
0006 #include <utility>
0007 
0008 #include "DataFormats/Portable/interface/Product.h"
0009 #include "FWCore/Concurrency/interface/WaitingTaskWithArenaHolder.h"
0010 #include "FWCore/Framework/interface/Event.h"
0011 #include "FWCore/Utilities/interface/EDGetToken.h"
0012 #include "FWCore/Utilities/interface/EDPutToken.h"
0013 #include "HeterogeneousCore/AlpakaCore/interface/ContextState.h"
0014 #include "HeterogeneousCore/AlpakaCore/interface/EventCache.h"
0015 #include "HeterogeneousCore/AlpakaCore/interface/QueueCache.h"
0016 #include "HeterogeneousCore/AlpakaCore/interface/chooseDevice.h"
0017 #include "HeterogeneousCore/AlpakaInterface/interface/HostOnlyTask.h"
0018 #include "HeterogeneousCore/AlpakaInterface/interface/ScopedContextFwd.h"
0019 #include "HeterogeneousCore/AlpakaInterface/interface/config.h"
0020 #include "HeterogeneousCore/AlpakaInterface/interface/traits.h"
0021 
// Forward declaration of the class befriended by ScopedContextProduce below, which gives it
// access to the private queue-taking constructor (presumably a unit-test helper).
namespace cms::alpakatest {
  class TestScopedContext;
}  // namespace cms::alpakatest
0025 
0026 namespace cms::alpakatools {
0027 
0028   namespace impl {
0029     // This class is intended to be derived by other ScopedContext*, not for general use
0030     template <typename TQueue, typename>
0031     class ScopedContextBase {
0032     public:
0033       using Queue = TQueue;
0034       using Device = alpaka::Dev<Queue>;
0035       using Platform = alpaka::Pltf<Device>;
0036 
0037       Device device() const { return alpaka::getDev(*queue_); }
0038 
0039       Queue& queue() { return *queue_; }
0040       const std::shared_ptr<Queue>& queuePtr() const { return queue_; }
0041 
0042     protected:
0043       ScopedContextBase(ProductBase<Queue> const& data)
0044           : queue_{data.mayReuseQueue() ? data.queuePtr() : getQueueCache<Queue>().get(data.device())} {}
0045 
0046       explicit ScopedContextBase(std::shared_ptr<Queue> queue) : queue_(std::move(queue)) {}
0047 
0048       explicit ScopedContextBase(edm::StreamID streamID)
0049           : queue_{getQueueCache<Queue>().get(cms::alpakatools::chooseDevice<Platform>(streamID))} {}
0050 
0051     private:
0052       std::shared_ptr<Queue> queue_;
0053     };
0054 
0055     template <typename TQueue, typename>
0056     class ScopedContextGetterBase : public ScopedContextBase<TQueue> {
0057     public:
0058       using Queue = TQueue;
0059 
0060       template <typename T>
0061       const T& get(Product<Queue, T> const& data) {
0062         synchronizeStreams(data);
0063         return data.data_;
0064       }
0065 
0066       template <typename T>
0067       const T& get(edm::Event const& event, edm::EDGetTokenT<Product<Queue, T>> token) {
0068         return get(event.get(token));
0069       }
0070 
0071     protected:
0072       template <typename... Args>
0073       ScopedContextGetterBase(Args&&... args) : ScopedContextBase<Queue>{std::forward<Args>(args)...} {}
0074 
0075       void synchronizeStreams(ProductBase<Queue> const& data) {
0076         // If the product has been enqueued to a different queue, make sure that it is available before accessing it
0077         if (data.queue() != this->queue()) {
0078           // Different queues, check if the underlying device is the same
0079           if (data.device() != this->device()) {
0080             // Eventually replace with prefetch to current device (assuming unified memory works)
0081             // If we won't go to unified memory, need to figure out something else...
0082             throw cms::Exception("LogicError") << "Handling data from multiple devices is not yet supported";
0083           }
0084           // If the data product is not yet available, synchronize the two queues
0085           if (not data.isAvailable()) {
0086             // Event not yet occurred, so need to add synchronization
0087             // here. Sychronization is done by making the current queue
0088             // wait for an event, so all subsequent work in the queue
0089             // will run only after the event has "occurred" (i.e. data
0090             // product became available).
0091             alpaka::wait(this->queue(), data.event());
0092           }
0093         }
0094       }
0095     };
0096 
0097     class ScopedContextHolderHelper {
0098     public:
0099       ScopedContextHolderHelper(edm::WaitingTaskWithArenaHolder waitingTaskHolder)
0100           : waitingTaskHolder_{std::move(waitingTaskHolder)} {}
0101 
0102       template <typename F, typename TQueue, typename = std::enable_if_t<cms::alpakatools::is_queue_v<TQueue>>>
0103       void pushNextTask(F&& f, ContextState<TQueue> const* state) {
0104         replaceWaitingTaskHolder(edm::WaitingTaskWithArenaHolder{edm::make_waiting_task_with_holder(
0105             std::move(waitingTaskHolder_), [state, func = std::forward<F>(f)](edm::WaitingTaskWithArenaHolder h) {
0106               func(ScopedContextTask{state, std::move(h)});
0107             })});
0108       }
0109 
0110       void replaceWaitingTaskHolder(edm::WaitingTaskWithArenaHolder waitingTaskHolder) {
0111         waitingTaskHolder_ = std::move(waitingTaskHolder);
0112       }
0113 
0114       template <typename TQueue, typename = std::enable_if_t<cms::alpakatools::is_queue_v<TQueue>>>
0115       void enqueueCallback(TQueue& queue) {
0116         alpaka::enqueue(queue, alpaka::HostOnlyTask([holder = std::move(waitingTaskHolder_)]() {
0117                           // The functor is required to be const, but the original waitingTaskHolder_
0118                           // needs to be notified...
0119                           const_cast<edm::WaitingTaskWithArenaHolder&>(holder).doneWaiting(nullptr);
0120                         }));
0121       }
0122 
0123     private:
0124       edm::WaitingTaskWithArenaHolder waitingTaskHolder_;
0125     };
0126   }  // namespace impl
0127 
0128   /**
0129    * The aim of this class is to do necessary per-event "initialization" in ExternalWork acquire():
0130    * - setting the current device
0131    * - calling edm::WaitingTaskWithArenaHolder::doneWaiting() when necessary
0132    * - synchronizing between queues if necessary
0133    * and enforce that those get done in a proper way in RAII fashion.
0134    */
0135   template <typename TQueue, typename>
0136   class ScopedContextAcquire : public impl::ScopedContextGetterBase<TQueue> {
0137   public:
0138     using Queue = TQueue;
0139     using ScopedContextGetterBase = impl::ScopedContextGetterBase<Queue>;
0140     using ScopedContextGetterBase::queue;
0141     using ScopedContextGetterBase::queuePtr;
0142 
0143     /// Constructor to create a new queue (no need for context beyond acquire())
0144     explicit ScopedContextAcquire(edm::StreamID streamID, edm::WaitingTaskWithArenaHolder waitingTaskHolder)
0145         : ScopedContextGetterBase(streamID), holderHelper_{std::move(waitingTaskHolder)} {}
0146 
0147     /// Constructor to create a new queue, and the context is needed after acquire()
0148     explicit ScopedContextAcquire(edm::StreamID streamID,
0149                                   edm::WaitingTaskWithArenaHolder waitingTaskHolder,
0150                                   ContextState<Queue>& state)
0151         : ScopedContextGetterBase(streamID), holderHelper_{std::move(waitingTaskHolder)}, contextState_{&state} {}
0152 
0153     /// Constructor to (possibly) re-use a queue (no need for context beyond acquire())
0154     explicit ScopedContextAcquire(ProductBase<Queue> const& data, edm::WaitingTaskWithArenaHolder waitingTaskHolder)
0155         : ScopedContextGetterBase(data), holderHelper_{std::move(waitingTaskHolder)} {}
0156 
0157     /// Constructor to (possibly) re-use a queue, and the context is needed after acquire()
0158     explicit ScopedContextAcquire(ProductBase<Queue> const& data,
0159                                   edm::WaitingTaskWithArenaHolder waitingTaskHolder,
0160                                   ContextState<Queue>& state)
0161         : ScopedContextGetterBase(data), holderHelper_{std::move(waitingTaskHolder)}, contextState_{&state} {}
0162 
0163     ~ScopedContextAcquire() {
0164       holderHelper_.enqueueCallback(queue());
0165       if (contextState_) {
0166         contextState_->set(queuePtr());
0167       }
0168     }
0169 
0170     template <typename F>
0171     void pushNextTask(F&& f) {
0172       if (contextState_ == nullptr)
0173         throwNoState();
0174       holderHelper_.pushNextTask(std::forward<F>(f), contextState_);
0175     }
0176 
0177     void replaceWaitingTaskHolder(edm::WaitingTaskWithArenaHolder waitingTaskHolder) {
0178       holderHelper_.replaceWaitingTaskHolder(std::move(waitingTaskHolder));
0179     }
0180 
0181   private:
0182     void throwNoState() {
0183       throw cms::Exception("LogicError")
0184           << "Calling ScopedContextAcquire::insertNextTask() requires ScopedContextAcquire to be constructed with "
0185              "ContextState, but that was not the case";
0186     }
0187 
0188     impl::ScopedContextHolderHelper holderHelper_;
0189     ContextState<Queue>* contextState_ = nullptr;
0190   };
0191 
0192   /**
0193    * The aim of this class is to do necessary per-event "initialization" in ExternalWork produce() or normal produce():
0194    * - setting the current device
0195    * - synchronizing between queues if necessary
0196    * and enforce that those get done in a proper way in RAII fashion.
0197    */
0198   template <typename TQueue, typename>
0199   class ScopedContextProduce : public impl::ScopedContextGetterBase<TQueue> {
0200   public:
0201     using Queue = TQueue;
0202     using Event = alpaka::Event<Queue>;
0203     using ScopedContextGetterBase = impl::ScopedContextGetterBase<Queue>;
0204     using ScopedContextGetterBase::device;
0205     using ScopedContextGetterBase::queue;
0206     using ScopedContextGetterBase::queuePtr;
0207 
0208     /// Constructor to re-use the queue of acquire() (ExternalWork module)
0209     explicit ScopedContextProduce(ContextState<Queue>& state)
0210         : ScopedContextGetterBase(state.releaseQueuePtr()), event_{getEventCache<Event>().get(device())} {}
0211 
0212     explicit ScopedContextProduce(ProductBase<Queue> const& data)
0213         : ScopedContextGetterBase(data), event_{getEventCache<Event>().get(device())} {}
0214 
0215     explicit ScopedContextProduce(edm::StreamID streamID)
0216         : ScopedContextGetterBase(streamID), event_{getEventCache<Event>().get(device())} {}
0217 
0218     /// Record the event, all asynchronous work must have been queued before the destructor
0219     ~ScopedContextProduce() {
0220       // FIXME: this may throw an execption if the underlaying call fails.
0221       alpaka::enqueue(queue(), *event_);
0222     }
0223 
0224     template <typename T>
0225     std::unique_ptr<Product<Queue, T>> wrap(T data) {
0226       // make_unique doesn't work because of private constructor
0227       return std::unique_ptr<Product<Queue, T>>(new Product<Queue, T>(queuePtr(), std::move(data)));
0228     }
0229 
0230     template <typename T, typename... Args>
0231     auto emplace(edm::Event& iEvent, edm::EDPutTokenT<Product<Queue, T>> token, Args&&... args) {
0232       return iEvent.emplace(token, queuePtr(), event_, std::forward<Args>(args)...);
0233     }
0234 
0235   private:
0236     friend class ::cms::alpakatest::TestScopedContext;
0237 
0238     explicit ScopedContextProduce(std::shared_ptr<Queue> queue)
0239         : ScopedContextGetterBase(std::move(queue)), event_{getEventCache<Event>().get(device())} {}
0240 
0241     std::shared_ptr<Event> event_;
0242   };
0243 
0244   /**
0245    * The aim of this class is to do necessary per-task "initialization" tasks created in ExternalWork acquire():
0246    * - setting the current device
0247    * - calling edm::WaitingTaskWithArenaHolder::doneWaiting() when necessary
0248    * and enforce that those get done in a proper way in RAII fashion.
0249    */
0250   template <typename TQueue, typename>
0251   class ScopedContextTask : public impl::ScopedContextBase<TQueue> {
0252   public:
0253     using Queue = TQueue;
0254     using ScopedContextBase = impl::ScopedContextBase<Queue>;
0255     using ScopedContextBase::queue;
0256     using ScopedContextBase::queuePtr;
0257 
0258     /// Constructor to re-use the queue of acquire() (ExternalWork module)
0259     explicit ScopedContextTask(ContextState<Queue> const* state, edm::WaitingTaskWithArenaHolder waitingTaskHolder)
0260         : ScopedContextBase(state->queuePtr()),  // don't move, state is re-used afterwards
0261           holderHelper_{std::move(waitingTaskHolder)},
0262           contextState_{state} {}
0263 
0264     ~ScopedContextTask() { holderHelper_.enqueueCallback(queue()); }
0265 
0266     template <typename F>
0267     void pushNextTask(F&& f) {
0268       holderHelper_.pushNextTask(std::forward<F>(f), contextState_);
0269     }
0270 
0271     void replaceWaitingTaskHolder(edm::WaitingTaskWithArenaHolder waitingTaskHolder) {
0272       holderHelper_.replaceWaitingTaskHolder(std::move(waitingTaskHolder));
0273     }
0274 
0275   private:
0276     impl::ScopedContextHolderHelper holderHelper_;
0277     ContextState<Queue> const* contextState_;
0278   };
0279 
0280   /**
0281    * The aim of this class is to do necessary per-event "initialization" in analyze()
0282    * - setting the current device
0283    * - synchronizing between queues if necessary
0284    * and enforce that those get done in a proper way in RAII fashion.
0285    */
0286   template <typename TQueue, typename>
0287   class ScopedContextAnalyze : public impl::ScopedContextGetterBase<TQueue> {
0288   public:
0289     using Queue = TQueue;
0290     using ScopedContextGetterBase = impl::ScopedContextGetterBase<Queue>;
0291     using ScopedContextGetterBase::queue;
0292     using ScopedContextGetterBase::queuePtr;
0293 
0294     /// Constructor to (possibly) re-use a queue
0295     explicit ScopedContextAnalyze(ProductBase<Queue> const& data) : ScopedContextGetterBase(data) {}
0296   };
0297 
0298 }  // namespace cms::alpakatools
0299 
0300 #endif  // HeterogeneousCore_AlpakaCore_interface_ScopedContext_h