Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2023-02-06 03:08:57

0001 #ifndef DataFormats_Portable_interface_ProductBase_h
0002 #define DataFormats_Portable_interface_ProductBase_h
0003 
0004 #include <atomic>
0005 #include <memory>
0006 #include <utility>
0007 
0008 #include <alpaka/alpaka.hpp>
0009 
0010 #include "HeterogeneousCore/AlpakaInterface/interface/ScopedContextFwd.h"
0011 
0012 namespace cms::alpakatools {
0013 
0014   /**
0015    * Base class for all instantiations of Product<TQueue, T> to hold the
0016    * non-T-dependent members.
0017    */
0018   template <typename TQueue, typename = std::enable_if_t<alpaka::isQueue<TQueue>>>
0019   class ProductBase {
0020   public:
0021     using Queue = TQueue;
0022     using Event = alpaka::Event<Queue>;
0023     using Device = alpaka::Dev<Queue>;
0024 
0025     ProductBase() = default;  // Needed only for ROOT dictionary generation
0026 
0027     ~ProductBase() {
0028       // Make sure that the production of the product in the GPU is
0029       // complete before destructing the product. This is to make sure
0030       // that the EDM stream does not move to the next event before all
0031       // asynchronous processing of the current is complete.
0032 
0033       // TODO: a callback notifying a WaitingTaskHolder (or similar)
0034       // would avoid blocking the CPU, but would also require more work.
0035 
0036       // FIXME: this may throw an execption if the underlaying call fails.
0037       if (event_) {
0038         alpaka::wait(*event_);
0039       }
0040     }
0041 
0042     ProductBase(const ProductBase&) = delete;
0043     ProductBase& operator=(const ProductBase&) = delete;
0044     ProductBase(ProductBase&& other)
0045         : queue_{std::move(other.queue_)}, event_{std::move(other.event_)}, mayReuseQueue_{other.mayReuseQueue_.load()} {}
0046     ProductBase& operator=(ProductBase&& other) {
0047       queue_ = std::move(other.queue_);
0048       event_ = std::move(other.event_);
0049       mayReuseQueue_ = other.mayReuseQueue_.load();
0050       return *this;
0051     }
0052 
0053     bool isValid() const { return queue_.get() != nullptr; }
0054 
0055     bool isAvailable() const {
0056       // if default-constructed, the product is not available
0057       if (not event_) {
0058         return false;
0059       }
0060       return alpaka::isComplete(*event_);
0061     }
0062 
0063     // returning a const& requires changes in alpaka's getDev() implementations
0064     Device device() const { return alpaka::getDev(queue()); }
0065 
0066     Queue const& queue() const { return *queue_; }
0067 
0068     Event const& event() const { return *event_; }
0069 
0070   protected:
0071     explicit ProductBase(std::shared_ptr<Queue> queue, std::shared_ptr<Event> event)
0072         : queue_{std::move(queue)}, event_{std::move(event)} {}
0073 
0074   private:
0075     friend class impl::ScopedContextBase<Queue>;
0076     friend class ScopedContextProduce<Queue>;
0077 
0078     // The following function is intended to be used only from ScopedContext
0079     const std::shared_ptr<Queue>& queuePtr() const { return queue_; }
0080 
0081     bool mayReuseQueue() const {
0082       bool expected = true;
0083       bool changed = mayReuseQueue_.compare_exchange_strong(expected, false);
0084       // If the current thread is the one flipping the flag, it may
0085       // reuse the queue.
0086       return changed;
0087     }
0088 
0089     // shared_ptr because of caching in QueueCache, and sharing across edm::Event products
0090     std::shared_ptr<Queue> queue_;  //!
0091     // shared_ptr because of caching in EventCache
0092     std::shared_ptr<Event> event_;  //!
0093 
0094     // This flag tells whether the queue may be reused by a consumer or not.
0095     // The goal is to have a "chain" of modules to enqueue their work to the same queue.
0096     mutable std::atomic<bool> mayReuseQueue_ = true;  //!
0097   };
0098 
0099 }  // namespace cms::alpakatools
0100 
0101 #endif  // DataFormats_Portable_interface_ProductBase_h