File indexing completed on 2025-01-08 03:36:18
0001 #ifndef HeterogeneousCore_AlpakaCore_interface_alpaka_ESProducer_h
0002 #define HeterogeneousCore_AlpakaCore_interface_alpaka_ESProducer_h
0003
0004 #include "FWCore/Framework/interface/ESProducerExternalWork.h"
0005 #include "FWCore/Framework/interface/MakeDataException.h"
0006 #include "FWCore/Framework/interface/produce_helpers.h"
0007 #include "HeterogeneousCore/AlpakaCore/interface/modulePrevalidate.h"
0008 #include "HeterogeneousCore/AlpakaCore/interface/alpaka/ESDeviceProduct.h"
0009 #include "HeterogeneousCore/AlpakaCore/interface/alpaka/ESDeviceProductType.h"
0010 #include "HeterogeneousCore/AlpakaCore/interface/alpaka/Record.h"
0011 #include "HeterogeneousCore/AlpakaInterface/interface/devices.h"
0012 #include "HeterogeneousCore/AlpakaInterface/interface/CopyToDevice.h"
0013
0014 #include <functional>
0015
0016 namespace ALPAKA_ACCELERATOR_NAMESPACE {
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027 class ESProducer : public edm::ESProducerExternalWork {
0028 using Base = edm::ESProducerExternalWork;
0029
0030 public:
0031 static void prevalidate(edm::ConfigurationDescriptions& descriptions) {
0032 Base::prevalidate(descriptions);
0033 cms::alpakatools::modulePrevalidate(descriptions);
0034 }
0035
0036 protected:
0037 ESProducer(edm::ParameterSet const& iConfig);
0038
0039 template <typename T>
0040 auto setWhatProduced(T* iThis, edm::es::Label const& label = {}) {
0041 return setWhatProduced(iThis, &T::produce, label);
0042 }
0043
0044 template <typename T, typename TReturn, typename TRecord>
0045 auto setWhatProduced(T* iThis, TReturn (T ::*iMethod)(TRecord const&), edm::es::Label const& label = {}) {
0046 auto cc = Base::setWhatProduced(iThis, iMethod, label);
0047 using TProduct = typename edm::eventsetup::produce::smart_pointer_traits<TReturn>::type;
0048 if constexpr (not detail::useESProductDirectly) {
0049
0050 auto tokenPtr = std::make_shared<edm::ESGetToken<TProduct, TRecord>>();
0051 auto ccDev = setWhatProducedDevice<TRecord>(
0052 [tokenPtr](device::Record<TRecord> const& iRecord) {
0053 using CopyT = cms::alpakatools::CopyToDevice<TProduct>;
0054 try {
0055 auto handle = iRecord.getTransientHandle(*tokenPtr);
0056 return std::optional{CopyT::copyAsync(iRecord.queue(), *handle)};
0057 } catch (edm::eventsetup::MakeDataException& e) {
0058 return std::optional<decltype(CopyT::copyAsync(std::declval<Queue&>(), std::declval<TProduct>()))>();
0059 }
0060 },
0061 label);
0062 *tokenPtr = ccDev.consumes(edm::ESInputTag{moduleLabel_, label.default_ + appendToDataLabel_});
0063 }
0064 return cc;
0065 }
0066
0067 template <typename T, typename TReturn, typename TRecord>
0068 auto setWhatProduced(T* iThis,
0069 TReturn (T ::*iMethod)(device::Record<TRecord> const&),
0070 edm::es::Label const& label = {}) {
0071 using TProduct = typename edm::eventsetup::produce::smart_pointer_traits<TReturn>::type;
0072 if constexpr (detail::useESProductDirectly) {
0073 return Base::setWhatProduced(
0074 [iThis, iMethod](TRecord const& record) {
0075 auto const& devices = cms::alpakatools::devices<Platform>();
0076 assert(devices.size() == 1);
0077 device::Record<TRecord> const deviceRecord(record, devices.front());
0078 static_assert(std::is_same_v<std::remove_cvref_t<decltype(deviceRecord.queue())>,
0079 alpaka::Queue<Device, alpaka::Blocking>>,
0080 "Non-blocking queue when trying to use ES data product directly. This might indicate a "
0081 "need to extend the Alpaka ESProducer base class.");
0082 return std::invoke(iMethod, iThis, deviceRecord);
0083 },
0084 label);
0085 } else {
0086 return setWhatProducedDevice<TRecord>(
0087 [iThis, iMethod](device::Record<TRecord> const& record) { return std::invoke(iMethod, iThis, record); },
0088 label);
0089 }
0090 }
0091
0092 private:
0093 template <typename TRecord, typename TFunc>
0094 auto setWhatProducedDevice(TFunc&& func, const edm::es::Label& label) {
0095 using Types = edm::eventsetup::impl::ReturnArgumentTypes<TFunc>;
0096 using TReturn = typename Types::return_type;
0097 using TProduct = typename edm::eventsetup::produce::smart_pointer_traits<TReturn>::type;
0098 using ProductType = ESDeviceProduct<TProduct>;
0099 using ReturnType = detail::ESDeviceProductWithStorage<TProduct, TReturn>;
0100 return Base::setWhatAcquiredProducedWithLambda(
0101
0102 [function = std::forward<TFunc>(func), synchronize = synchronize_](TRecord const& record,
0103 edm::WaitingTaskWithArenaHolder holder) {
0104
0105 auto const& devices = cms::alpakatools::devices<Platform>();
0106 auto ret = std::make_unique<ReturnType>(devices.size());
0107 bool allnull = true;
0108 bool anynull = false;
0109 for (auto const& dev : devices) {
0110 device::Record<TRecord> const deviceRecord(record, dev);
0111 auto prod = function(deviceRecord);
0112 if (prod) {
0113 allnull = false;
0114 ret->insert(dev, std::move(prod));
0115 } else {
0116 anynull = true;
0117 }
0118 if (synchronize) {
0119 alpaka::wait(deviceRecord.queue());
0120 } else {
0121 enqueueCallback(deviceRecord.queue(), std::move(holder));
0122 }
0123
0124
0125
0126
0127
0128
0129
0130
0131 }
0132 return std::tuple(std::move(ret), allnull, anynull);
0133 },
0134
0135 [](TRecord const& record, auto fromAcquire) -> std::unique_ptr<ProductType> {
0136 auto [ret, allnull, anynull] = std::move(fromAcquire);
0137
0138
0139
0140 if (allnull) {
0141 return nullptr;
0142 } else if (anynull) {
0143
0144
0145
0146
0147
0148
0149
0150
0151 throwSomeNullException();
0152 }
0153 return std::move(ret);
0154 },
0155 label);
0156 }
0157
0158 static void enqueueCallback(Queue& queue, edm::WaitingTaskWithArenaHolder holder);
0159 static void throwSomeNullException();
0160
0161 std::string const moduleLabel_;
0162 std::string const appendToDataLabel_;
0163 bool const synchronize_;
0164 };
0165 }
0166
0167 #endif