File indexing completed on 2024-04-06 12:15:39
0001 #ifndef HeterogeneousCore_AlpakaCore_interface_alpaka_ESProducer_h
0002 #define HeterogeneousCore_AlpakaCore_interface_alpaka_ESProducer_h
0003
#include "FWCore/Framework/interface/ESProducer.h"
#include "FWCore/Framework/interface/MakeDataException.h"
#include "FWCore/Framework/interface/produce_helpers.h"
#include "HeterogeneousCore/AlpakaCore/interface/module_backend_config.h"
#include "HeterogeneousCore/AlpakaCore/interface/alpaka/ESDeviceProduct.h"
#include "HeterogeneousCore/AlpakaCore/interface/alpaka/ESDeviceProductType.h"
#include "HeterogeneousCore/AlpakaCore/interface/alpaka/Record.h"
#include "HeterogeneousCore/AlpakaInterface/interface/devices.h"
#include "HeterogeneousCore/AlpakaInterface/interface/CopyToDevice.h"

#include <cassert>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
0015
0016 namespace ALPAKA_ACCELERATOR_NAMESPACE {
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027 class ESProducer : public edm::ESProducer {
0028 using Base = edm::ESProducer;
0029
0030 public:
0031 static void prevalidate(edm::ConfigurationDescriptions& descriptions) {
0032 Base::prevalidate(descriptions);
0033 cms::alpakatools::module_backend_config(descriptions);
0034 }
0035
0036 protected:
0037 ESProducer(edm::ParameterSet const& iConfig);
0038
0039 template <typename T>
0040 auto setWhatProduced(T* iThis, edm::es::Label const& label = {}) {
0041 return setWhatProduced(iThis, &T::produce, label);
0042 }
0043
0044 template <typename T, typename TReturn, typename TRecord>
0045 auto setWhatProduced(T* iThis, TReturn (T ::*iMethod)(TRecord const&), edm::es::Label const& label = {}) {
0046 auto cc = Base::setWhatProduced(iThis, iMethod, label);
0047 using TProduct = typename edm::eventsetup::produce::smart_pointer_traits<TReturn>::type;
0048 if constexpr (not detail::useESProductDirectly<TProduct>) {
0049
0050 auto tokenPtr = std::make_shared<edm::ESGetToken<TProduct, TRecord>>();
0051 auto ccDev = setWhatProducedDevice<TRecord>(
0052 [tokenPtr](device::Record<TRecord> const& iRecord) {
0053 using CopyT = cms::alpakatools::CopyToDevice<TProduct>;
0054 try {
0055 auto handle = iRecord.getTransientHandle(*tokenPtr);
0056 return std::optional{CopyT::copyAsync(iRecord.queue(), *handle)};
0057 } catch (edm::eventsetup::MakeDataException& e) {
0058 return std::optional<decltype(CopyT::copyAsync(std::declval<Queue&>(), std::declval<TProduct>()))>();
0059 }
0060 },
0061 label);
0062 *tokenPtr = ccDev.consumes(edm::ESInputTag{moduleLabel_, label.default_ + appendToDataLabel_});
0063 }
0064 return cc;
0065 }
0066
0067 template <typename T, typename TReturn, typename TRecord>
0068 auto setWhatProduced(T* iThis,
0069 TReturn (T ::*iMethod)(device::Record<TRecord> const&),
0070 edm::es::Label const& label = {}) {
0071 using TProduct = typename edm::eventsetup::produce::smart_pointer_traits<TReturn>::type;
0072 if constexpr (detail::useESProductDirectly<TProduct>) {
0073 return Base::setWhatProduced(
0074 [iThis, iMethod](TRecord const& record) {
0075 auto const& devices = cms::alpakatools::devices<Platform>();
0076 assert(devices.size() == 1);
0077 device::Record<TRecord> const deviceRecord(record, devices.front());
0078 return std::invoke(iMethod, iThis, deviceRecord);
0079 },
0080 label);
0081 } else {
0082 return setWhatProducedDevice<TRecord>(
0083 [iThis, iMethod](device::Record<TRecord> const& record) { return std::invoke(iMethod, iThis, record); },
0084 label);
0085 }
0086 }
0087
0088 private:
0089 template <typename TRecord, typename TFunc>
0090 auto setWhatProducedDevice(TFunc&& func, const edm::es::Label& label) {
0091 using Types = edm::eventsetup::impl::ReturnArgumentTypes<TFunc>;
0092 using TReturn = typename Types::return_type;
0093 using TProduct = typename edm::eventsetup::produce::smart_pointer_traits<TReturn>::type;
0094 using ProductType = ESDeviceProduct<TProduct>;
0095 using ReturnType = detail::ESDeviceProductWithStorage<TProduct, TReturn>;
0096 return Base::setWhatProduced(
0097 [function = std::forward<TFunc>(func)](TRecord const& record) -> std::unique_ptr<ProductType> {
0098
0099 auto const& devices = cms::alpakatools::devices<Platform>();
0100 std::vector<std::shared_ptr<Queue>> queues;
0101 queues.reserve(devices.size());
0102 auto ret = std::make_unique<ReturnType>(devices.size());
0103 bool allnull = true;
0104 bool anynull = false;
0105 for (auto const& dev : devices) {
0106 device::Record<TRecord> const deviceRecord(record, dev);
0107 auto prod = function(deviceRecord);
0108 if (prod) {
0109 allnull = false;
0110 ret->insert(dev, std::move(prod));
0111 } else {
0112 anynull = true;
0113 }
0114 queues.push_back(deviceRecord.queuePtr());
0115 }
0116
0117 for (auto& queuePtr : queues) {
0118 alpaka::wait(*queuePtr);
0119 }
0120 if (allnull) {
0121 return nullptr;
0122 } else if (anynull) {
0123
0124
0125
0126
0127
0128
0129
0130
0131 ESProducer::throwSomeNullException();
0132 }
0133 return ret;
0134 },
0135 label);
0136 }
0137
0138 static void throwSomeNullException();
0139
0140 std::string const moduleLabel_;
0141 std::string const appendToDataLabel_;
0142 };
0143 }
0144
0145 #endif