Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2023-02-06 03:09:08

0001 #ifndef HeterogeneousCore_AlpakaCore_interface_ScopedContext_h
0002 #define HeterogeneousCore_AlpakaCore_interface_ScopedContext_h
0003 
0004 #include <memory>
0005 #include <stdexcept>
0006 #include <utility>
0007 
0008 #include <alpaka/alpaka.hpp>
0009 
0010 #include "DataFormats/Portable/interface/Product.h"
0011 #include "FWCore/Concurrency/interface/WaitingTaskWithArenaHolder.h"
0012 #include "FWCore/Framework/interface/Event.h"
0013 #include "FWCore/Utilities/interface/EDGetToken.h"
0014 #include "FWCore/Utilities/interface/EDPutToken.h"
0015 #include "HeterogeneousCore/AlpakaCore/interface/ContextState.h"
0016 #include "HeterogeneousCore/AlpakaCore/interface/EventCache.h"
0017 #include "HeterogeneousCore/AlpakaCore/interface/QueueCache.h"
0018 #include "HeterogeneousCore/AlpakaCore/interface/chooseDevice.h"
0019 #include "HeterogeneousCore/AlpakaInterface/interface/HostOnlyTask.h"
0020 #include "HeterogeneousCore/AlpakaInterface/interface/ScopedContextFwd.h"
0021 #include "HeterogeneousCore/AlpakaInterface/interface/config.h"
0022 #include "HeterogeneousCore/AlpakaInterface/interface/traits.h"
0023 
0024 namespace cms::alpakatest {
0025   class TestScopedContext;
0026 }
0027 
0028 namespace cms::alpakatools {
0029 
0030   namespace impl {
0031     // This class is intended to be derived by other ScopedContext*, not for general use
0032     template <typename TQueue, typename>
0033     class ScopedContextBase {
0034     public:
0035       using Queue = TQueue;
0036       using Device = alpaka::Dev<Queue>;
0037       using Platform = alpaka::Pltf<Device>;
0038 
0039       Device device() const { return alpaka::getDev(*queue_); }
0040 
0041       Queue& queue() { return *queue_; }
0042       const std::shared_ptr<Queue>& queuePtr() const { return queue_; }
0043 
0044     protected:
0045       ScopedContextBase(ProductBase<Queue> const& data)
0046           : queue_{data.mayReuseQueue() ? data.queuePtr() : getQueueCache<Queue>().get(data.device())} {}
0047 
0048       explicit ScopedContextBase(std::shared_ptr<Queue> queue) : queue_(std::move(queue)) {}
0049 
0050       explicit ScopedContextBase(edm::StreamID streamID)
0051           : queue_{getQueueCache<Queue>().get(cms::alpakatools::chooseDevice<Platform>(streamID))} {}
0052 
0053     private:
0054       std::shared_ptr<Queue> queue_;
0055     };
0056 
0057     template <typename TQueue, typename>
0058     class ScopedContextGetterBase : public ScopedContextBase<TQueue> {
0059     public:
0060       using Queue = TQueue;
0061 
0062       template <typename T>
0063       const T& get(Product<Queue, T> const& data) {
0064         synchronizeStreams(data);
0065         return data.data_;
0066       }
0067 
0068       template <typename T>
0069       const T& get(edm::Event const& event, edm::EDGetTokenT<Product<Queue, T>> token) {
0070         return get(event.get(token));
0071       }
0072 
0073     protected:
0074       template <typename... Args>
0075       ScopedContextGetterBase(Args&&... args) : ScopedContextBase<Queue>{std::forward<Args>(args)...} {}
0076 
0077       void synchronizeStreams(ProductBase<Queue> const& data) {
0078         // If the product has been enqueued to a different queue, make sure that it is available before accessing it
0079         if (data.queue() != this->queue()) {
0080           // Different queues, check if the underlying device is the same
0081           if (data.device() != this->device()) {
0082             // Eventually replace with prefetch to current device (assuming unified memory works)
0083             // If we won't go to unified memory, need to figure out something else...
0084             throw cms::Exception("LogicError") << "Handling data from multiple devices is not yet supported";
0085           }
0086           // If the data product is not yet available, synchronize the two queues
0087           if (not data.isAvailable()) {
0088             // Event not yet occurred, so need to add synchronization
0089             // here. Sychronization is done by making the current queue
0090             // wait for an event, so all subsequent work in the queue
0091             // will run only after the event has "occurred" (i.e. data
0092             // product became available).
0093             alpaka::wait(this->queue(), data.event());
0094           }
0095         }
0096       }
0097     };
0098 
0099     class ScopedContextHolderHelper {
0100     public:
0101       ScopedContextHolderHelper(edm::WaitingTaskWithArenaHolder waitingTaskHolder)
0102           : waitingTaskHolder_{std::move(waitingTaskHolder)} {}
0103 
0104       template <typename F, typename TQueue, typename = std::enable_if_t<alpaka::isQueue<TQueue>>>
0105       void pushNextTask(F&& f, ContextState<TQueue> const* state) {
0106         replaceWaitingTaskHolder(edm::WaitingTaskWithArenaHolder{edm::make_waiting_task_with_holder(
0107             std::move(waitingTaskHolder_), [state, func = std::forward<F>(f)](edm::WaitingTaskWithArenaHolder h) {
0108               func(ScopedContextTask{state, std::move(h)});
0109             })});
0110       }
0111 
0112       void replaceWaitingTaskHolder(edm::WaitingTaskWithArenaHolder waitingTaskHolder) {
0113         waitingTaskHolder_ = std::move(waitingTaskHolder);
0114       }
0115 
0116       template <typename TQueue, typename = std::enable_if_t<alpaka::isQueue<TQueue>>>
0117       void enqueueCallback(TQueue& queue) {
0118         alpaka::enqueue(queue, alpaka::HostOnlyTask([holder = std::move(waitingTaskHolder_)]() {
0119                           // The functor is required to be const, but the original waitingTaskHolder_
0120                           // needs to be notified...
0121                           const_cast<edm::WaitingTaskWithArenaHolder&>(holder).doneWaiting(nullptr);
0122                         }));
0123       }
0124 
0125     private:
0126       edm::WaitingTaskWithArenaHolder waitingTaskHolder_;
0127     };
0128   }  // namespace impl
0129 
0130   /**
0131    * The aim of this class is to do necessary per-event "initialization" in ExternalWork acquire():
0132    * - setting the current device
0133    * - calling edm::WaitingTaskWithArenaHolder::doneWaiting() when necessary
0134    * - synchronizing between queues if necessary
0135    * and enforce that those get done in a proper way in RAII fashion.
0136    */
0137   template <typename TQueue, typename>
0138   class ScopedContextAcquire : public impl::ScopedContextGetterBase<TQueue> {
0139   public:
0140     using Queue = TQueue;
0141     using ScopedContextGetterBase = impl::ScopedContextGetterBase<Queue>;
0142     using ScopedContextGetterBase::queue;
0143     using ScopedContextGetterBase::queuePtr;
0144 
0145     /// Constructor to create a new queue (no need for context beyond acquire())
0146     explicit ScopedContextAcquire(edm::StreamID streamID, edm::WaitingTaskWithArenaHolder waitingTaskHolder)
0147         : ScopedContextGetterBase(streamID), holderHelper_{std::move(waitingTaskHolder)} {}
0148 
0149     /// Constructor to create a new queue, and the context is needed after acquire()
0150     explicit ScopedContextAcquire(edm::StreamID streamID,
0151                                   edm::WaitingTaskWithArenaHolder waitingTaskHolder,
0152                                   ContextState<Queue>& state)
0153         : ScopedContextGetterBase(streamID), holderHelper_{std::move(waitingTaskHolder)}, contextState_{&state} {}
0154 
0155     /// Constructor to (possibly) re-use a queue (no need for context beyond acquire())
0156     explicit ScopedContextAcquire(ProductBase<Queue> const& data, edm::WaitingTaskWithArenaHolder waitingTaskHolder)
0157         : ScopedContextGetterBase(data), holderHelper_{std::move(waitingTaskHolder)} {}
0158 
0159     /// Constructor to (possibly) re-use a queue, and the context is needed after acquire()
0160     explicit ScopedContextAcquire(ProductBase<Queue> const& data,
0161                                   edm::WaitingTaskWithArenaHolder waitingTaskHolder,
0162                                   ContextState<Queue>& state)
0163         : ScopedContextGetterBase(data), holderHelper_{std::move(waitingTaskHolder)}, contextState_{&state} {}
0164 
0165     ~ScopedContextAcquire() {
0166       holderHelper_.enqueueCallback(queue());
0167       if (contextState_) {
0168         contextState_->set(queuePtr());
0169       }
0170     }
0171 
0172     template <typename F>
0173     void pushNextTask(F&& f) {
0174       if (contextState_ == nullptr)
0175         throwNoState();
0176       holderHelper_.pushNextTask(std::forward<F>(f), contextState_);
0177     }
0178 
0179     void replaceWaitingTaskHolder(edm::WaitingTaskWithArenaHolder waitingTaskHolder) {
0180       holderHelper_.replaceWaitingTaskHolder(std::move(waitingTaskHolder));
0181     }
0182 
0183   private:
0184     void throwNoState() {
0185       throw cms::Exception("LogicError")
0186           << "Calling ScopedContextAcquire::insertNextTask() requires ScopedContextAcquire to be constructed with "
0187              "ContextState, but that was not the case";
0188     }
0189 
0190     impl::ScopedContextHolderHelper holderHelper_;
0191     ContextState<Queue>* contextState_ = nullptr;
0192   };
0193 
0194   /**
0195    * The aim of this class is to do necessary per-event "initialization" in ExternalWork produce() or normal produce():
0196    * - setting the current device
0197    * - synchronizing between queues if necessary
0198    * and enforce that those get done in a proper way in RAII fashion.
0199    */
0200   template <typename TQueue, typename>
0201   class ScopedContextProduce : public impl::ScopedContextGetterBase<TQueue> {
0202   public:
0203     using Queue = TQueue;
0204     using Event = alpaka::Event<Queue>;
0205     using ScopedContextGetterBase = impl::ScopedContextGetterBase<Queue>;
0206     using ScopedContextGetterBase::device;
0207     using ScopedContextGetterBase::queue;
0208     using ScopedContextGetterBase::queuePtr;
0209 
0210     /// Constructor to re-use the queue of acquire() (ExternalWork module)
0211     explicit ScopedContextProduce(ContextState<Queue>& state)
0212         : ScopedContextGetterBase(state.releaseQueuePtr()), event_{getEventCache<Event>().get(device())} {}
0213 
0214     explicit ScopedContextProduce(ProductBase<Queue> const& data)
0215         : ScopedContextGetterBase(data), event_{getEventCache<Event>().get(device())} {}
0216 
0217     explicit ScopedContextProduce(edm::StreamID streamID)
0218         : ScopedContextGetterBase(streamID), event_{getEventCache<Event>().get(device())} {}
0219 
0220     /// Record the event, all asynchronous work must have been queued before the destructor
0221     ~ScopedContextProduce() {
0222       // FIXME: this may throw an execption if the underlaying call fails.
0223       alpaka::enqueue(queue(), *event_);
0224     }
0225 
0226     template <typename T>
0227     std::unique_ptr<Product<Queue, T>> wrap(T data) {
0228       // make_unique doesn't work because of private constructor
0229       return std::unique_ptr<Product<Queue, T>>(new Product<Queue, T>(queuePtr(), std::move(data)));
0230     }
0231 
0232     template <typename T, typename... Args>
0233     auto emplace(edm::Event& iEvent, edm::EDPutTokenT<Product<Queue, T>> token, Args&&... args) {
0234       return iEvent.emplace(token, queuePtr(), event_, std::forward<Args>(args)...);
0235     }
0236 
0237   private:
0238     friend class ::cms::alpakatest::TestScopedContext;
0239 
0240     explicit ScopedContextProduce(std::shared_ptr<Queue> queue)
0241         : ScopedContextGetterBase(std::move(queue)), event_{getEventCache<Event>().get(device())} {}
0242 
0243     std::shared_ptr<Event> event_;
0244   };
0245 
0246   /**
0247    * The aim of this class is to do necessary per-task "initialization" tasks created in ExternalWork acquire():
0248    * - setting the current device
0249    * - calling edm::WaitingTaskWithArenaHolder::doneWaiting() when necessary
0250    * and enforce that those get done in a proper way in RAII fashion.
0251    */
0252   template <typename TQueue, typename>
0253   class ScopedContextTask : public impl::ScopedContextBase<TQueue> {
0254   public:
0255     using Queue = TQueue;
0256     using ScopedContextBase = impl::ScopedContextBase<Queue>;
0257     using ScopedContextBase::queue;
0258     using ScopedContextBase::queuePtr;
0259 
0260     /// Constructor to re-use the queue of acquire() (ExternalWork module)
0261     explicit ScopedContextTask(ContextState<Queue> const* state, edm::WaitingTaskWithArenaHolder waitingTaskHolder)
0262         : ScopedContextBase(state->queuePtr()),  // don't move, state is re-used afterwards
0263           holderHelper_{std::move(waitingTaskHolder)},
0264           contextState_{state} {}
0265 
0266     ~ScopedContextTask() { holderHelper_.enqueueCallback(queue()); }
0267 
0268     template <typename F>
0269     void pushNextTask(F&& f) {
0270       holderHelper_.pushNextTask(std::forward<F>(f), contextState_);
0271     }
0272 
0273     void replaceWaitingTaskHolder(edm::WaitingTaskWithArenaHolder waitingTaskHolder) {
0274       holderHelper_.replaceWaitingTaskHolder(std::move(waitingTaskHolder));
0275     }
0276 
0277   private:
0278     impl::ScopedContextHolderHelper holderHelper_;
0279     ContextState<Queue> const* contextState_;
0280   };
0281 
0282   /**
0283    * The aim of this class is to do necessary per-event "initialization" in analyze()
0284    * - setting the current device
0285    * - synchronizing between queues if necessary
0286    * and enforce that those get done in a proper way in RAII fashion.
0287    */
0288   template <typename TQueue, typename>
0289   class ScopedContextAnalyze : public impl::ScopedContextGetterBase<TQueue> {
0290   public:
0291     using Queue = TQueue;
0292     using ScopedContextGetterBase = impl::ScopedContextGetterBase<Queue>;
0293     using ScopedContextGetterBase::queue;
0294     using ScopedContextGetterBase::queuePtr;
0295 
0296     /// Constructor to (possibly) re-use a queue
0297     explicit ScopedContextAnalyze(ProductBase<Queue> const& data) : ScopedContextGetterBase(data) {}
0298   };
0299 
0300 }  // namespace cms::alpakatools
0301 
0302 #endif  // HeterogeneousCore_AlpakaCore_interface_ScopedContext_h