Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:15:43

0001 #include "catch.hpp"
0002 
0003 #include "CUDADataFormats/Common/interface/Product.h"
0004 #include "FWCore/Concurrency/interface/FinalWaitingTask.h"
0005 #include "FWCore/Concurrency/interface/WaitingTask.h"
0006 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0007 #include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
0008 #include "HeterogeneousCore/CUDAUtilities/interface/device_unique_ptr.h"
0009 #include "HeterogeneousCore/CUDACore/interface/ScopedContext.h"
0010 #include "HeterogeneousCore/CUDAUtilities/interface/cudaCheck.h"
0011 #include "HeterogeneousCore/CUDAUtilities/interface/eventWorkHasCompleted.h"
0012 #include "HeterogeneousCore/CUDAUtilities/interface/requireDevices.h"
0013 #include "HeterogeneousCore/CUDAUtilities/interface/StreamCache.h"
0014 #include "HeterogeneousCore/CUDAUtilities/interface/EventCache.h"
0015 #include "HeterogeneousCore/CUDAUtilities/interface/currentDevice.h"
0016 #include "HeterogeneousCore/CUDAUtilities/interface/ScopedSetDevice.h"
0017 
0018 #include "test_ScopedContextKernels.h"
0019 
0020 #include "oneapi/tbb/task_arena.h"
0021 #include "oneapi/tbb/task_group.h"
0022 
0023 namespace cms::cudatest {
0024   class TestScopedContext {
0025   public:
0026     static cuda::ScopedContextProduce make(int dev, bool createEvent) {
0027       cms::cuda::SharedEventPtr event;
0028       if (createEvent) {
0029         event = cms::cuda::getEventCache().get();
0030       }
0031       return cuda::ScopedContextProduce(dev, cms::cuda::getStreamCache().get(), std::move(event));
0032     }
0033   };
0034 }  // namespace cms::cudatest
0035 
0036 namespace {
0037   std::unique_ptr<cms::cuda::Product<int*>> produce(int device, int* d, int* h) {
0038     auto ctx = cms::cudatest::TestScopedContext::make(device, true);
0039     cudaCheck(cudaMemcpyAsync(d, h, sizeof(int), cudaMemcpyHostToDevice, ctx.stream()));
0040     cms::cudatest::testScopedContextKernels_single(d, ctx.stream());
0041     return ctx.wrap(d);
0042   }
0043 }  // namespace
0044 
0045 TEST_CASE("Use of cms::cuda::ScopedContext", "[CUDACore]") {
0046   if (not cms::cudatest::testDevices()) {
0047     return;
0048   }
0049 
0050   constexpr int defaultDevice = 0;
0051   {
0052     auto ctx = cms::cudatest::TestScopedContext::make(defaultDevice, true);
0053 
0054     SECTION("Construct from device ID") { REQUIRE(cms::cuda::currentDevice() == defaultDevice); }
0055 
0056     SECTION("Wrap T to cms::cuda::Product<T>") {
0057       std::unique_ptr<cms::cuda::Product<int>> dataPtr = ctx.wrap(10);
0058       REQUIRE(dataPtr.get() != nullptr);
0059       REQUIRE(dataPtr->device() == ctx.device());
0060       REQUIRE(dataPtr->stream() == ctx.stream());
0061     }
0062 
0063     SECTION("Construct from from cms::cuda::Product<T>") {
0064       std::unique_ptr<cms::cuda::Product<int>> dataPtr = ctx.wrap(10);
0065       const auto& data = *dataPtr;
0066 
0067       cms::cuda::ScopedContextProduce ctx2{data};
0068       REQUIRE(cms::cuda::currentDevice() == data.device());
0069       REQUIRE(ctx2.stream() == data.stream());
0070 
0071       // Second use of a product should lead to new stream
0072       cms::cuda::ScopedContextProduce ctx3{data};
0073       REQUIRE(cms::cuda::currentDevice() == data.device());
0074       REQUIRE(ctx3.stream() != data.stream());
0075     }
0076 
0077     SECTION("Storing state in cms::cuda::ContextState") {
0078       oneapi::tbb::task_arena arena(1);
0079       arena.execute([&ctx]() {
0080         cms::cuda::ContextState ctxstate;
0081         {  // acquire
0082           std::unique_ptr<cms::cuda::Product<int>> dataPtr = ctx.wrap(10);
0083           const auto& data = *dataPtr;
0084           oneapi::tbb::task_group group;
0085           edm::FinalWaitingTask waitTask{group};
0086           {
0087             edm::WaitingTaskWithArenaHolder dummy{group, &waitTask};
0088             cms::cuda::ScopedContextAcquire ctx2{data, dummy, ctxstate};
0089           }
0090           waitTask.wait();
0091         }
0092 
0093         {  // produce
0094           cms::cuda::ScopedContextProduce ctx2{ctxstate};
0095           REQUIRE(cms::cuda::currentDevice() == ctx.device());
0096           REQUIRE(ctx2.stream() == ctx.stream());
0097         }
0098       });
0099     }
0100 
0101     SECTION("Joining multiple CUDA streams") {
0102       cms::cuda::ScopedSetDevice setDeviceForThisScope(defaultDevice);
0103 
0104       // Mimick a producer on the first CUDA stream
0105       int h_a1 = 1;
0106       auto d_a1 = cms::cuda::make_device_unique<int>(nullptr);
0107       auto wprod1 = produce(defaultDevice, d_a1.get(), &h_a1);
0108 
0109       // Mimick a producer on the second CUDA stream
0110       int h_a2 = 2;
0111       auto d_a2 = cms::cuda::make_device_unique<int>(nullptr);
0112       auto wprod2 = produce(defaultDevice, d_a2.get(), &h_a2);
0113 
0114       REQUIRE(wprod1->stream() != wprod2->stream());
0115 
0116       // Mimick a third producer "joining" the two streams
0117       cms::cuda::ScopedContextProduce ctx2{*wprod1};
0118 
0119       auto prod1 = ctx2.get(*wprod1);
0120       auto prod2 = ctx2.get(*wprod2);
0121 
0122       auto d_a3 = cms::cuda::make_device_unique<int>(nullptr);
0123       cms::cudatest::testScopedContextKernels_join(prod1, prod2, d_a3.get(), ctx2.stream());
0124       cudaCheck(cudaStreamSynchronize(ctx2.stream()));
0125       REQUIRE(wprod2->isAvailable());
0126       REQUIRE(cms::cuda::eventWorkHasCompleted(wprod2->event()));
0127 
0128       h_a1 = 0;
0129       h_a2 = 0;
0130       int h_a3 = 0;
0131 
0132       cudaCheck(cudaMemcpyAsync(&h_a1, d_a1.get(), sizeof(int), cudaMemcpyDeviceToHost, ctx.stream()));
0133       cudaCheck(cudaMemcpyAsync(&h_a2, d_a2.get(), sizeof(int), cudaMemcpyDeviceToHost, ctx.stream()));
0134       cudaCheck(cudaMemcpyAsync(&h_a3, d_a3.get(), sizeof(int), cudaMemcpyDeviceToHost, ctx.stream()));
0135 
0136       REQUIRE(h_a1 == 2);
0137       REQUIRE(h_a2 == 4);
0138       REQUIRE(h_a3 == 6);
0139     }
0140   }
0141 
0142   cudaCheck(cudaSetDevice(defaultDevice));
0143   cudaCheck(cudaDeviceSynchronize());
0144   // Note: CUDA resources are cleaned up by the destructors of the global cache objects
0145 }