File indexing completed on 2024-04-06 12:03:45
0001 #include "catch.hpp"
0002
0003 #include "CUDADataFormats/Common/interface/Product.h"
0004 #include "HeterogeneousCore/CUDACore/interface/ScopedContext.h"
0005 #include "HeterogeneousCore/CUDAUtilities/interface/cudaCheck.h"
0006 #include "HeterogeneousCore/CUDAUtilities/interface/requireDevices.h"
0007 #include "HeterogeneousCore/CUDAUtilities/interface/StreamCache.h"
0008 #include "HeterogeneousCore/CUDAUtilities/interface/EventCache.h"
0009
0010 #include <cuda_runtime_api.h>
0011
0012 namespace cms::cudatest {
0013 class TestScopedContext {
0014 public:
0015 static cuda::ScopedContextProduce make(int dev, bool createEvent) {
0016 cms::cuda::SharedEventPtr event;
0017 if (createEvent) {
0018 event = cms::cuda::getEventCache().get();
0019 }
0020 return cuda::ScopedContextProduce(dev, cms::cuda::getStreamCache().get(), std::move(event));
0021 }
0022 };
0023 }
0024
0025 TEST_CASE("Use of cms::cuda::Product template", "[CUDACore]") {
0026 SECTION("Default constructed") {
0027 auto foo = cms::cuda::Product<int>();
0028 REQUIRE(!foo.isValid());
0029
0030 auto bar = std::move(foo);
0031 }
0032
0033 if (not cms::cudatest::testDevices()) {
0034 return;
0035 }
0036
0037 constexpr int defaultDevice = 0;
0038 cudaCheck(cudaSetDevice(defaultDevice));
0039 {
0040 auto ctx = cms::cudatest::TestScopedContext::make(defaultDevice, true);
0041 std::unique_ptr<cms::cuda::Product<int>> dataPtr = ctx.wrap(10);
0042 auto& data = *dataPtr;
0043
0044 SECTION("Construct from cms::cuda::ScopedContext") {
0045 REQUIRE(data.isValid());
0046 REQUIRE(data.device() == defaultDevice);
0047 REQUIRE(data.stream() == ctx.stream());
0048 REQUIRE(data.event() != nullptr);
0049 }
0050
0051 SECTION("Move constructor") {
0052 auto data2 = cms::cuda::Product<int>(std::move(data));
0053 REQUIRE(data2.isValid());
0054 REQUIRE(!data.isValid());
0055 }
0056
0057 SECTION("Move assignment") {
0058 cms::cuda::Product<int> data2;
0059 data2 = std::move(data);
0060 REQUIRE(data2.isValid());
0061 REQUIRE(!data.isValid());
0062 }
0063 }
0064
0065 cudaCheck(cudaSetDevice(defaultDevice));
0066 cudaCheck(cudaDeviceSynchronize());
0067
0068 }