Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:05:00

0001 #ifndef DataFormats_Portable_interface_PortableCollection_h
0002 #define DataFormats_Portable_interface_PortableCollection_h
0003 
0004 #include <alpaka/alpaka.hpp>
0005 
0006 #include "DataFormats/Portable/interface/PortableHostCollection.h"
0007 #include "DataFormats/Portable/interface/PortableDeviceCollection.h"
0008 #include "HeterogeneousCore/AlpakaInterface/interface/CopyToDevice.h"
0009 #include "HeterogeneousCore/AlpakaInterface/interface/CopyToHost.h"
0010 
0011 namespace traits {
0012 
0013   // trait for a generic SoA-based product
0014   template <typename T, typename TDev, typename = std::enable_if_t<alpaka::isDevice<TDev>>>
0015   struct PortableCollectionTrait {
0016     using CollectionType = PortableDeviceCollection<T, TDev>;
0017   };
0018 
0019   // specialise for host device
0020   template <typename T>
0021   struct PortableCollectionTrait<T, alpaka_common::DevHost> {
0022     using CollectionType = PortableHostCollection<T>;
0023   };
0024 
0025   // trait for a generic multi-SoA-based product
0026   template <typename TDev, typename T0, typename... Args>
0027   struct PortableMultiCollectionTrait {
0028     using CollectionType = PortableDeviceMultiCollection<TDev, T0, Args...>;
0029   };
0030 
0031   // specialise for host device
0032   template <typename T0, typename... Args>
0033   struct PortableMultiCollectionTrait<alpaka_common::DevHost, T0, Args...> {
0034     using CollectionType = PortableHostMultiCollection<T0, Args...>;
0035   };
0036 
0037 }  // namespace traits
0038 
0039 // type alias for a generic SoA-based product
0040 template <typename T, typename TDev, typename = std::enable_if_t<alpaka::isDevice<TDev>>>
0041 using PortableCollection = typename traits::PortableCollectionTrait<T, TDev>::CollectionType;
0042 
0043 // type alias for a generic SoA-based product
0044 template <typename TDev, typename T0, typename... Args>
0045 using PortableMultiCollection = typename traits::PortableMultiCollectionTrait<TDev, T0, Args...>::CollectionType;
0046 
0047 // define how to copy PortableCollection between host and device
0048 namespace cms::alpakatools {
0049   template <typename TLayout, typename TDevice>
0050   struct CopyToHost<PortableDeviceCollection<TLayout, TDevice>> {
0051     template <typename TQueue>
0052     static auto copyAsync(TQueue& queue, PortableDeviceCollection<TLayout, TDevice> const& srcData) {
0053       PortableHostCollection<TLayout> dstData(srcData->metadata().size(), queue);
0054       alpaka::memcpy(queue, dstData.buffer(), srcData.buffer());
0055       return dstData;
0056     }
0057   };
0058 
0059   template <typename TDev, typename T0, typename... Args>
0060   struct CopyToHost<PortableDeviceMultiCollection<TDev, T0, Args...>> {
0061     template <typename TQueue>
0062     static auto copyAsync(TQueue& queue, PortableDeviceMultiCollection<TDev, T0, Args...> const& srcData) {
0063       PortableHostMultiCollection<T0, Args...> dstData(srcData.sizes(), queue);
0064       alpaka::memcpy(queue, dstData.buffer(), srcData.buffer());
0065       return dstData;
0066     }
0067   };
0068 
0069   template <typename TLayout>
0070   struct CopyToDevice<PortableHostCollection<TLayout>> {
0071     template <typename TQueue>
0072     static auto copyAsync(TQueue& queue, PortableHostCollection<TLayout> const& srcData) {
0073       using TDevice = typename alpaka::trait::DevType<TQueue>::type;
0074       PortableDeviceCollection<TLayout, TDevice> dstData(srcData->metadata().size(), queue);
0075       alpaka::memcpy(queue, dstData.buffer(), srcData.buffer());
0076       return dstData;
0077     }
0078   };
0079 
0080   template <typename TDev, typename T0, typename... Args>
0081   struct CopyToDevice<PortableHostMultiCollection<TDev, T0, Args...>> {
0082     template <typename TQueue>
0083     static auto copyAsync(TQueue& queue, PortableHostMultiCollection<TDev, T0, Args...> const& srcData) {
0084       using TDevice = typename alpaka::trait::DevType<TQueue>::type;
0085       PortableDeviceMultiCollection<TDev, T0, Args...> dstData(srcData.sizes(), queue);
0086       alpaka::memcpy(queue, dstData.buffer(), srcData.buffer());
0087       return dstData;
0088     }
0089   };
0090 }  // namespace cms::alpakatools
0091 
0092 #endif  // DataFormats_Portable_interface_PortableCollection_h