File indexing completed on 2024-04-06 12:15:42
0001 #ifndef HeterogeneousCore_AlpakaTest_interface_AlpakaESTestData_h
0002 #define HeterogeneousCore_AlpakaTest_interface_AlpakaESTestData_h
0003
0004 #include "DataFormats/Portable/interface/PortableHostCollection.h"
0005 #include "DataFormats/Portable/interface/PortableCollection.h"
0006 #include "HeterogeneousCore/AlpakaInterface/interface/CopyToDevice.h"
0007 #include "HeterogeneousCore/AlpakaInterface/interface/config.h"
0008 #include "HeterogeneousCore/AlpakaInterface/interface/memory.h"
0009 #include "HeterogeneousCore/AlpakaTest/interface/AlpakaESTestSoA.h"
0010
0011 namespace cms::alpakatest {
0012
0013 using AlpakaESTestDataAHost = PortableHostCollection<AlpakaESTestSoAA>;
0014 using AlpakaESTestDataCHost = PortableHostCollection<AlpakaESTestSoAC>;
0015 using AlpakaESTestDataDHost = PortableHostCollection<AlpakaESTestSoAD>;
0016
0017
0018 template <typename TDev>
0019 class AlpakaESTestDataB {
0020 public:
0021 using Buffer = cms::alpakatools::device_buffer<TDev, int[]>;
0022 using ConstBuffer = cms::alpakatools::const_device_buffer<TDev, int[]>;
0023
0024 explicit AlpakaESTestDataB(Buffer buffer) : buffer_(std::move(buffer)) {}
0025
0026 Buffer buffer() { return buffer_; }
0027 ConstBuffer buffer() const { return buffer_; }
0028 ConstBuffer const_buffer() const { return buffer_; }
0029
0030 int const* data() const { return buffer_.data(); }
0031 auto size() const { return alpaka::getExtentProduct(buffer_); }
0032
0033 private:
0034 Buffer buffer_;
0035 };
0036
0037
0038
0039 template <typename TDev>
0040 class AlpakaESTestDataE {
0041 public:
0042 using ECollection = PortableCollection<AlpakaESTestSoAE, TDev>;
0043 using EDataCollection = PortableCollection<AlpakaESTestSoAEData, TDev>;
0044
0045 class ConstView {
0046 public:
0047 constexpr ConstView(typename ECollection::ConstView e, typename EDataCollection::ConstView data)
0048 : eView_(e), dataView_(data) {}
0049
0050 constexpr auto size() const { return eView_.metadata().size(); }
0051 constexpr int val(int i) const { return eView_.val(i); }
0052 constexpr int val2(int i) const { return dataView_.val2(eView_.ind(i)); }
0053
0054 private:
0055 typename ECollection::ConstView eView_;
0056 typename EDataCollection::ConstView dataView_;
0057 };
0058
0059 AlpakaESTestDataE(size_t size, size_t dataSize) : e_(size), data_(dataSize) {}
0060
0061 AlpakaESTestDataE(ECollection e, EDataCollection data) : e_(std::move(e)), data_(std::move(data)) {}
0062
0063 ECollection const& e() const { return e_; }
0064 EDataCollection const& data() const { return data_; }
0065
0066 ConstView view() const { return const_view(); }
0067 ConstView const_view() const { return ConstView(e_.const_view(), data_.const_view()); }
0068
0069 private:
0070 ECollection e_;
0071 EDataCollection data_;
0072 };
0073 using AlpakaESTestDataEHost = AlpakaESTestDataE<alpaka_common::DevHost>;
0074
0075 }
0076
0077 namespace cms::alpakatools {
0078
0079
0080
0081
0082 template <>
0083 struct CopyToDevice<cms::alpakatest::AlpakaESTestDataB<alpaka_common::DevHost>> {
0084 template <typename TQueue>
0085 static auto copyAsync(TQueue& queue, cms::alpakatest::AlpakaESTestDataB<alpaka_common::DevHost> const& srcData) {
0086
0087
0088
0089
0090
0091
0092
0093
0094
0095
0096 auto dstBuffer = cms::alpakatools::make_device_buffer<int[]>(queue, srcData.size());
0097 alpaka::memcpy(queue, dstBuffer, srcData.buffer());
0098 return cms::alpakatest::AlpakaESTestDataB<alpaka::Dev<TQueue>>(std::move(dstBuffer));
0099 }
0100 };
0101
0102 template <>
0103 struct CopyToDevice<cms::alpakatest::AlpakaESTestDataEHost> {
0104 template <typename TQueue>
0105 static auto copyAsync(TQueue& queue, cms::alpakatest::AlpakaESTestDataEHost const& srcData) {
0106 using ECopy = CopyToDevice<cms::alpakatest::AlpakaESTestDataEHost::ECollection>;
0107 using EDataCopy = CopyToDevice<cms::alpakatest::AlpakaESTestDataEHost::EDataCollection>;
0108 using TDevice = alpaka::Dev<TQueue>;
0109 return cms::alpakatest::AlpakaESTestDataE<TDevice>(ECopy::copyAsync(queue, srcData.e()),
0110 EDataCopy::copyAsync(queue, srcData.data()));
0111 }
0112 };
0113 }
0114
0115 #endif