Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-07-28 22:48:35

0001 #include "HeterogeneousCore/CUDACore/interface/ScopedContext.h"
0002 
0003 #include "FWCore/Concurrency/interface/Async.h"
0004 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0005 #include "FWCore/ServiceRegistry/interface/Service.h"
0006 #include "FWCore/Utilities/interface/Exception.h"
0007 #include "HeterogeneousCore/CUDAUtilities/interface/StreamCache.h"
0008 #include "HeterogeneousCore/CUDAUtilities/interface/cudaCheck.h"
0009 
0010 #include "chooseDevice.h"
0011 
0012 namespace cms::cuda {
0013   namespace impl {
0014     ScopedContextBase::ScopedContextBase(edm::StreamID streamID) : currentDevice_(chooseDevice(streamID)) {
0015       cudaCheck(cudaSetDevice(currentDevice_));
0016       stream_ = getStreamCache().get();
0017     }
0018 
0019     ScopedContextBase::ScopedContextBase(const ProductBase& data) : currentDevice_(data.device()) {
0020       cudaCheck(cudaSetDevice(currentDevice_));
0021       if (data.mayReuseStream()) {
0022         stream_ = data.streamPtr();
0023       } else {
0024         stream_ = getStreamCache().get();
0025       }
0026     }
0027 
0028     ScopedContextBase::ScopedContextBase(int device, SharedStreamPtr stream)
0029         : currentDevice_(device), stream_(std::move(stream)) {
0030       cudaCheck(cudaSetDevice(currentDevice_));
0031     }
0032 
0033     ////////////////////
0034 
0035     void ScopedContextGetterBase::synchronizeStreams(int dataDevice,
0036                                                      cudaStream_t dataStream,
0037                                                      bool available,
0038                                                      cudaEvent_t dataEvent) {
0039       if (dataDevice != device()) {
0040         // Eventually replace with prefetch to current device (assuming unified memory works)
0041         // If we won't go to unified memory, need to figure out something else...
0042         throw cms::Exception("LogicError") << "Handling data from multiple devices is not yet supported";
0043       }
0044 
0045       if (dataStream != stream()) {
0046         // Different streams, need to synchronize
0047         if (not available) {
0048           // Event not yet occurred, so need to add synchronization
0049           // here. Sychronization is done by making the CUDA stream to
0050           // wait for an event, so all subsequent work in the stream
0051           // will run only after the event has "occurred" (i.e. data
0052           // product became available).
0053           cudaCheck(cudaStreamWaitEvent(stream(), dataEvent, 0), "Failed to make a stream to wait for an event");
0054         }
0055       }
0056     }
0057 
0058     void ScopedContextHolderHelper::enqueueCallback(int device, cudaStream_t stream) {
0059       edm::Service<edm::Async> async;
0060       SharedEventPtr event = getEventCache().get();
0061       cudaCheck(cudaEventRecord(event.get(), stream));
0062       async->runAsync(
0063           std::move(waitingTaskHolder_),
0064           [event = std::move(event)]() mutable { cudaCheck(cudaEventSynchronize(event.get())); },
0065           []() { return "Enqueued by cms::cuda::ScopedContextHolderHelper::enqueueCallback()"; });
0066     }
0067   }  // namespace impl
0068 
0069   ////////////////////
0070 
0071   ScopedContextAcquire::~ScopedContextAcquire() noexcept(false) {
0072     holderHelper_.enqueueCallback(device(), stream());
0073     if (contextState_) {
0074       contextState_->set(device(), streamPtr());
0075     }
0076   }
0077 
0078   void ScopedContextAcquire::throwNoState() {
0079     throw cms::Exception("LogicError")
0080         << "Calling ScopedContextAcquire::insertNextTask() requires ScopedContextAcquire to be constructed with "
0081            "ContextState, but that was not the case";
0082   }
0083 
0084   ////////////////////
0085 
0086   ScopedContextProduce::~ScopedContextProduce() {
0087     // Intentionally not checking the return value to avoid throwing
0088     // exceptions. If this call would fail, we should get failures
0089     // elsewhere as well.
0090     cudaEventRecord(event_.get(), stream());
0091   }
0092 
0093   ////////////////////
0094 
0095   ScopedContextTask::~ScopedContextTask() { holderHelper_.enqueueCallback(device(), stream()); }
0096 }  // namespace cms::cuda