Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:15:46

0001 #include "catch.hpp"
0002 
0003 #include "HeterogeneousCore/CUDAUtilities/interface/device_unique_ptr.h"
0004 #include "HeterogeneousCore/CUDAUtilities/interface/host_unique_ptr.h"
0005 #include "HeterogeneousCore/CUDAUtilities/interface/copyAsync.h"
0006 #include "HeterogeneousCore/CUDAUtilities/interface/cudaCheck.h"
0007 #include "HeterogeneousCore/CUDAUtilities/interface/memsetAsync.h"
0008 #include "HeterogeneousCore/CUDAUtilities/interface/requireDevices.h"
0009 
0010 TEST_CASE("memsetAsync", "[cudaMemTools]") {
0011   if (not cms::cudatest::testDevices()) {
0012     return;
0013   }
0014 
0015   cudaStream_t stream;
0016   cudaCheck(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
0017 
0018   SECTION("Single element") {
0019     auto host_orig = cms::cuda::make_host_unique<int>(stream);
0020     *host_orig = 42;
0021 
0022     auto device = cms::cuda::make_device_unique<int>(stream);
0023     auto host = cms::cuda::make_host_unique<int>(stream);
0024     cms::cuda::copyAsync(device, host_orig, stream);
0025     cms::cuda::memsetAsync(device, 0, stream);
0026     cms::cuda::copyAsync(host, device, stream);
0027     cudaCheck(cudaStreamSynchronize(stream));
0028 
0029     REQUIRE(*host == 0);
0030   }
0031 
0032   SECTION("Multiple elements") {
0033     constexpr int N = 100;
0034 
0035     auto host_orig = cms::cuda::make_host_unique<int[]>(N, stream);
0036     for (int i = 0; i < N; ++i) {
0037       host_orig[i] = i;
0038     }
0039 
0040     auto device = cms::cuda::make_device_unique<int[]>(N, stream);
0041     auto host = cms::cuda::make_host_unique<int[]>(N, stream);
0042     cms::cuda::copyAsync(device, host_orig, N, stream);
0043     cms::cuda::memsetAsync(device, 0, N, stream);
0044     cms::cuda::copyAsync(host, device, N, stream);
0045     cudaCheck(cudaStreamSynchronize(stream));
0046 
0047     for (int i = 0; i < N; ++i) {
0048       CHECK(host[i] == 0);
0049     }
0050   }
0051 
0052   cudaCheck(cudaStreamDestroy(stream));
0053 }