Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-12-19 04:04:41

0001 #ifndef HeterogeneousCore_AlpakaInterface_interface_CopyToDeviceCache_h
0002 #define HeterogeneousCore_AlpakaInterface_interface_CopyToDeviceCache_h
0003 
0004 #include <alpaka/alpaka.hpp>
0005 
0006 #include "HeterogeneousCore/AlpakaCore/interface/QueueCache.h"
0007 #include "HeterogeneousCore/AlpakaInterface/interface/CopyToDevice.h"
0008 #include "HeterogeneousCore/AlpakaInterface/interface/devices.h"
0009 
0010 namespace cms::alpakatools {
0011   namespace detail {
0012     // By default copy the host object with CopyToDevice<T>
0013     //
0014     // Doing with template specialization (rather than
0015     // std::conditional_t and if constexpr) because the
0016     // CopyToDevice<THostObject>::copyAsync() is ill-defined e.g. for
0017     // PortableCollection on host device
0018     template <typename TDevice, typename THostObject>
0019     class CopyToDeviceCacheImpl {
0020     public:
0021       using Device = TDevice;
0022       using Queue = alpaka::Queue<Device, alpaka::NonBlocking>;
0023       using HostObject = THostObject;
0024       using Copy = CopyToDevice<HostObject>;
0025       using DeviceObject = decltype(Copy::copyAsync(std::declval<Queue&>(), std::declval<HostObject const&>()));
0026 
0027       CopyToDeviceCacheImpl(HostObject const& srcObject) {
0028         using Platform = alpaka::Platform<Device>;
0029         auto const& devices = cms::alpakatools::devices<Platform>();
0030         std::vector<std::shared_ptr<Queue>> queues;
0031         queues.reserve(devices.size());
0032         data_.reserve(devices.size());
0033         for (auto const& dev : devices) {
0034           auto queue = getQueueCache<Queue>().get(dev);
0035           data_.emplace_back(Copy::copyAsync(*queue, srcObject));
0036           queues.emplace_back(std::move(queue));
0037         }
0038         for (auto& queuePtr : queues) {
0039           alpaka::wait(*queuePtr);
0040         }
0041       }
0042 
0043       DeviceObject const& get(size_t i) const { return data_[i]; }
0044 
0045     private:
0046       std::vector<DeviceObject> data_;
0047     };
0048 
0049     // For host device, copy the host object directly instead
0050     template <typename THostObject>
0051     class CopyToDeviceCacheImpl<alpaka_common::DevHost, THostObject> {
0052     public:
0053       using HostObject = THostObject;
0054       using DeviceObject = HostObject;
0055 
0056       CopyToDeviceCacheImpl(HostObject const& srcObject) : data_(srcObject) {}
0057 
0058       DeviceObject const& get(size_t i) const { return data_; }
0059 
0060     private:
0061       HostObject data_;
0062     };
0063   }  // namespace detail
0064 
0065   /**
0066    * This class template implements a cache for data that is moved
0067    * from the host (of type THostObject) to all the devices
0068    * corresponding to the TDevice device type.
0069    *
0070    * The host-side object to be copied is given as an argument to the
0071    * class constructor. The constructor uses the
0072    * CopyToDevice<THostObject> class template to perfom the copy, and
0073    * waits for the data copies to finish, i.e. the constructor is
0074    * synchronous wrt. the data copies.
0075    *
0076    * The device-side object corresponding to the THostObject (actual
0077    * type is the return type of CopyToDevice<THostObject>::copyAsync())
0078    * can be obtained with get() member function, that has either the
0079    * queue or device argument.
0080    */
0081   template <typename TDevice, typename THostObject>
0082     requires alpaka::isDevice<TDevice>
0083   class CopyToDeviceCache {
0084     using Device = TDevice;
0085     using HostObject = THostObject;
0086     using Impl = detail::CopyToDeviceCacheImpl<Device, HostObject>;
0087     using DeviceObject = typename Impl::DeviceObject;
0088 
0089   public:
0090     CopyToDeviceCache(THostObject const& srcData) : data_(srcData) {}
0091 
0092     DeviceObject const& get(Device const& dev) const { return data_.get(alpaka::getNativeHandle(dev)); }
0093 
0094     template <typename TQueue>
0095     DeviceObject const& get(TQueue const& queue) const {
0096       return get(alpaka::getDev(queue));
0097     }
0098 
0099   private:
0100     Impl data_;
0101   };
0102 }  // namespace cms::alpakatools
0103 
0104 #endif