File indexing completed on 2024-09-07 04:36:35
0001 #ifndef HeterogeneousCore_CUDACore_ScopedContext_h
0002 #define HeterogeneousCore_CUDACore_ScopedContext_h
0003
0004 #include <optional>
0005
0006 #include "CUDADataFormats/Common/interface/Product.h"
0007 #include "FWCore/Concurrency/interface/WaitingTaskWithArenaHolder.h"
0008 #include "FWCore/Framework/interface/Event.h"
0009 #include "FWCore/Utilities/interface/EDGetToken.h"
0010 #include "FWCore/Utilities/interface/EDPutToken.h"
0011 #include "FWCore/Utilities/interface/StreamID.h"
0012 #include "HeterogeneousCore/CUDACore/interface/ContextState.h"
0013 #include "HeterogeneousCore/CUDAUtilities/interface/EventCache.h"
0014 #include "HeterogeneousCore/CUDAUtilities/interface/SharedEventPtr.h"
0015 #include "HeterogeneousCore/CUDAUtilities/interface/SharedStreamPtr.h"
0016
0017 namespace cms {
// Forward declaration only: ScopedContextProduce names this test helper as a
// friend so the test can use that class's private constructor.
namespace cudatest {
  class TestScopedContext;
}  // namespace cudatest
0021
0022 namespace cuda {
0023
0024 namespace impl {
0025
0026 class ScopedContextBase {
0027 public:
0028 int device() const { return currentDevice_; }
0029
0030
0031
0032
0033
0034 cudaStream_t stream() const { return stream_.get(); }
0035 const SharedStreamPtr& streamPtr() const { return stream_; }
0036
0037 protected:
0038
0039
0040
0041
0042
0043
0044 explicit ScopedContextBase(edm::StreamID streamID);
0045
0046 explicit ScopedContextBase(const ProductBase& data);
0047
0048 explicit ScopedContextBase(int device, SharedStreamPtr stream);
0049
0050 private:
0051 int currentDevice_;
0052 SharedStreamPtr stream_;
0053 };
0054
0055 class ScopedContextGetterBase : public ScopedContextBase {
0056 public:
0057 template <typename T>
0058 const T& get(const Product<T>& data) {
0059 synchronizeStreams(data.device(), data.stream(), data.isAvailable(), data.event());
0060 return data.data_;
0061 }
0062
0063 template <typename T>
0064 const T& get(const edm::Event& iEvent, edm::EDGetTokenT<Product<T>> token) {
0065 return get(iEvent.get(token));
0066 }
0067
0068 protected:
0069 template <typename... Args>
0070 ScopedContextGetterBase(Args&&... args) : ScopedContextBase(std::forward<Args>(args)...) {}
0071
0072 void synchronizeStreams(int dataDevice, cudaStream_t dataStream, bool available, cudaEvent_t dataEvent);
0073 };
0074
0075 class ScopedContextHolderHelper {
0076 public:
0077 ScopedContextHolderHelper(edm::WaitingTaskWithArenaHolder waitingTaskHolder)
0078 : waitingTaskHolder_{std::move(waitingTaskHolder)} {}
0079
0080 template <typename F>
0081 void pushNextTask(F&& f, ContextState const* state);
0082
0083 void replaceWaitingTaskHolder(edm::WaitingTaskWithArenaHolder waitingTaskHolder) {
0084 waitingTaskHolder_ = std::move(waitingTaskHolder);
0085 }
0086
0087 void enqueueCallback(int device, cudaStream_t stream);
0088
0089 private:
0090 edm::WaitingTaskWithArenaHolder waitingTaskHolder_;
0091 };
0092 }
0093
0094
0095
0096
0097
0098
0099
0100
0101 class ScopedContextAcquire : public impl::ScopedContextGetterBase {
0102 public:
0103
0104 explicit ScopedContextAcquire(edm::StreamID streamID, edm::WaitingTaskWithArenaHolder waitingTaskHolder)
0105 : ScopedContextGetterBase(streamID), holderHelper_{std::move(waitingTaskHolder)} {}
0106
0107
0108 explicit ScopedContextAcquire(edm::StreamID streamID,
0109 edm::WaitingTaskWithArenaHolder waitingTaskHolder,
0110 ContextState& state)
0111 : ScopedContextGetterBase(streamID), holderHelper_{std::move(waitingTaskHolder)}, contextState_{&state} {}
0112
0113
0114 explicit ScopedContextAcquire(const ProductBase& data, edm::WaitingTaskWithArenaHolder waitingTaskHolder)
0115 : ScopedContextGetterBase(data), holderHelper_{std::move(waitingTaskHolder)} {}
0116
0117
0118 explicit ScopedContextAcquire(const ProductBase& data,
0119 edm::WaitingTaskWithArenaHolder waitingTaskHolder,
0120 ContextState& state)
0121 : ScopedContextGetterBase(data), holderHelper_{std::move(waitingTaskHolder)}, contextState_{&state} {}
0122
0123 ~ScopedContextAcquire() noexcept(false);
0124
0125 template <typename F>
0126 void pushNextTask(F&& f) {
0127 if (contextState_ == nullptr)
0128 throwNoState();
0129 holderHelper_.pushNextTask(std::forward<F>(f), contextState_);
0130 }
0131
0132 void replaceWaitingTaskHolder(edm::WaitingTaskWithArenaHolder waitingTaskHolder) {
0133 holderHelper_.replaceWaitingTaskHolder(std::move(waitingTaskHolder));
0134 }
0135
0136 private:
0137 void throwNoState();
0138
0139 impl::ScopedContextHolderHelper holderHelper_;
0140 ContextState* contextState_ = nullptr;
0141 };
0142
0143
0144
0145
0146
0147
0148
0149 class ScopedContextProduce : public impl::ScopedContextGetterBase {
0150 public:
0151
0152 explicit ScopedContextProduce(edm::StreamID streamID) : ScopedContextGetterBase(streamID) {}
0153
0154
0155 explicit ScopedContextProduce(const ProductBase& data) : ScopedContextGetterBase(data) {}
0156
0157
0158 explicit ScopedContextProduce(ContextState& state)
0159 : ScopedContextGetterBase(state.device(), state.releaseStreamPtr()) {}
0160
0161
0162 ~ScopedContextProduce();
0163
0164 template <typename T>
0165 std::unique_ptr<Product<T>> wrap(T data) {
0166
0167 return std::unique_ptr<Product<T>>(new Product<T>(device(), streamPtr(), event_, std::move(data)));
0168 }
0169
0170 template <typename T, typename... Args>
0171 auto emplace(edm::Event& iEvent, edm::EDPutTokenT<T> token, Args&&... args) {
0172 return iEvent.emplace(token, device(), streamPtr(), event_, std::forward<Args>(args)...);
0173 }
0174
0175 private:
0176 friend class cudatest::TestScopedContext;
0177
0178
0179 explicit ScopedContextProduce(int device, SharedStreamPtr stream, SharedEventPtr event)
0180 : ScopedContextGetterBase(device, std::move(stream)), event_{std::move(event)} {}
0181
0182
0183 SharedEventPtr event_ = getEventCache().get();
0184 };
0185
0186
0187
0188
0189
0190
0191
0192 class ScopedContextTask : public impl::ScopedContextBase {
0193 public:
0194
0195 explicit ScopedContextTask(ContextState const* state, edm::WaitingTaskWithArenaHolder waitingTaskHolder)
0196 : ScopedContextBase(state->device(), state->streamPtr()),
0197 holderHelper_{std::move(waitingTaskHolder)},
0198 contextState_{state} {}
0199
0200 ~ScopedContextTask();
0201
0202 template <typename F>
0203 void pushNextTask(F&& f) {
0204 holderHelper_.pushNextTask(std::forward<F>(f), contextState_);
0205 }
0206
0207 void replaceWaitingTaskHolder(edm::WaitingTaskWithArenaHolder waitingTaskHolder) {
0208 holderHelper_.replaceWaitingTaskHolder(std::move(waitingTaskHolder));
0209 }
0210
0211 private:
0212 impl::ScopedContextHolderHelper holderHelper_;
0213 ContextState const* contextState_;
0214 };
0215
0216
0217
0218
0219
0220
0221
0222 class ScopedContextAnalyze : public impl::ScopedContextGetterBase {
0223 public:
0224
0225 explicit ScopedContextAnalyze(const ProductBase& data) : ScopedContextGetterBase(data) {}
0226 };
0227
0228 namespace impl {
0229 template <typename F>
0230 void ScopedContextHolderHelper::pushNextTask(F&& f, ContextState const* state) {
0231 auto group = waitingTaskHolder_.group();
0232 replaceWaitingTaskHolder(edm::WaitingTaskWithArenaHolder{
0233 *group,
0234 edm::make_waiting_task_with_holder(std::move(waitingTaskHolder_),
0235 [state, func = std::forward<F>(f)](edm::WaitingTaskWithArenaHolder h) {
0236 func(ScopedContextTask{state, std::move(h)});
0237 })});
0238 }
0239 }
0240 }
0241 }
0242
0243 #endif