Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:15:46

#include <cassert>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include <utility>

#include <hip/hip_runtime.h>

#include <fmt/core.h>

#include <catch.hpp>

#include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
#include "FWCore/ParameterSet/interface/ParameterSet.h"
#include "FWCore/ParameterSetReader/interface/ParameterSetReader.h"
#include "FWCore/ServiceRegistry/interface/Service.h"
#include "FWCore/ServiceRegistry/interface/ServiceRegistry.h"
#include "FWCore/ServiceRegistry/interface/ServiceToken.h"
#include "FWCore/Utilities/interface/Exception.h"
#include "FWCore/Utilities/interface/ResourceInformation.h"
#include "HeterogeneousCore/ROCmServices/interface/ROCmInterface.h"
0022 
0023 namespace {
0024   std::string makeProcess(std::string const& name) {
0025     return fmt::format(R"_(
0026 import FWCore.ParameterSet.Config as cms
0027 process = cms.Process('{}')
0028 )_",
0029                        name);
0030   }
0031 
0032   void addResourceInformationService(std::string& config) {
0033     config += R"_(
0034 process.add_(cms.Service('ResourceInformationService'))
0035   )_";
0036   }
0037 
0038   void addROCmService(std::string& config, bool enabled = true) {
0039     config += fmt::format(R"_(
0040 process.add_(cms.Service('ROCmService',
0041   enabled = cms.untracked.bool({}),
0042   verbose = cms.untracked.bool(True)
0043 ))
0044   )_",
0045                           enabled ? "True" : "False");
0046   }
0047 
0048   edm::ServiceToken getServiceToken(std::string const& config) {
0049     std::unique_ptr<edm::ParameterSet> params;
0050     edm::makeParameterSets(config, params);
0051     return edm::ServiceToken(edm::ServiceRegistry::createServicesFromConfig(std::move(params)));
0052   }
0053 }  // namespace
0054 
0055 TEST_CASE("Tests of ROCmService", "[ROCmService]") {
0056   // Test setup: check if a simple ROCm runtime API call fails:
0057   // if so, skip the test with the ROCmService enabled
0058   int deviceCount = 0;
0059   auto ret = hipGetDeviceCount(&deviceCount);
0060 
0061   if (ret != hipSuccess) {
0062     WARN("Unable to query the ROCm capable devices from the ROCm runtime API: ("
0063          << ret << ") " << hipGetErrorString(ret) << ". Running only tests not requiring devices.");
0064   }
0065 
0066   std::string config = makeProcess("Test");
0067   addROCmService(config);
0068   auto serviceToken = getServiceToken(config);
0069   edm::ServiceRegistry::Operate operate(serviceToken);
0070 
0071   SECTION("Enable the ROCmService only if there are ROCm capable GPUs") {
0072     edm::Service<ROCmInterface> service;
0073     if (deviceCount <= 0) {
0074       REQUIRE((not service or not service->enabled()));
0075       WARN("ROCmService is not present, or disabled because there are no ROCm GPU devices");
0076       return;
0077     } else {
0078       REQUIRE(service);
0079       REQUIRE(service->enabled());
0080       INFO("ROCmService is enabled");
0081     }
0082   }
0083 
0084   SECTION("ROCmService enabled") {
0085     int driverVersion = 0, runtimeVersion = 0;
0086     edm::Service<ROCmInterface> service;
0087     ret = hipDriverGetVersion(&driverVersion);
0088     if (ret != hipSuccess) {
0089       FAIL("Unable to query the ROCm driver version from the ROCm runtime API: (" << ret << ") "
0090                                                                                   << hipGetErrorString(ret));
0091     }
0092     ret = hipRuntimeGetVersion(&runtimeVersion);
0093     if (ret != hipSuccess) {
0094       FAIL("Unable to query the ROCm runtime API version: (" << ret << ") " << hipGetErrorString(ret));
0095     }
0096 
0097     SECTION("ROCm Queries") {
0098       WARN("ROCm Driver Version / Runtime Version: " << driverVersion / 1000 << "." << (driverVersion % 100) / 10
0099                                                      << " / " << runtimeVersion / 1000 << "."
0100                                                      << (runtimeVersion % 100) / 10);
0101 
0102       // Test that the number of devices found by the service
0103       // is the same as detected by the ROCm runtime API
0104       REQUIRE(service->numberOfDevices() == deviceCount);
0105       WARN("Detected " << service->numberOfDevices() << " ROCm Capable device(s)");
0106 
0107       // Test that the compute capabilities of each device
0108       // are the same as detected by the ROCm runtime API
0109       for (int i = 0; i < deviceCount; ++i) {
0110         hipDeviceProp_t deviceProp;
0111         ret = hipGetDeviceProperties(&deviceProp, i);
0112         if (ret != hipSuccess) {
0113           FAIL("Unable to query the ROCm properties for device " << i << " from the ROCm runtime API: (" << ret << ") "
0114                                                                  << hipGetErrorString(ret));
0115         }
0116 
0117         REQUIRE(deviceProp.major == service->computeCapability(i).first);
0118         REQUIRE(deviceProp.minor == service->computeCapability(i).second);
0119         INFO("Device " << i << ": " << deviceProp.name << "\n ROCm Capability Major/Minor version number: "
0120                        << deviceProp.major << "." << deviceProp.minor);
0121       }
0122     }
0123 
0124     SECTION("With ResourceInformationService available") {
0125       std::string config = makeProcess("Test");
0126       addResourceInformationService(config);
0127       addROCmService(config);
0128       auto serviceToken = getServiceToken(config);
0129       edm::ServiceRegistry::Operate operate(serviceToken);
0130 
0131       edm::Service<ROCmInterface> service;
0132       REQUIRE(service);
0133       REQUIRE(service->enabled());
0134       edm::Service<edm::ResourceInformation> ri;
0135       REQUIRE(ri->gpuModels().size() > 0);
0136       /*
0137       REQUIRE(ri->amdDriverVersion().size() > 0);
0138       REQUIRE(ri->rocmDriverVersion() == driverVersion);
0139       REQUIRE(ri->rocmRuntimeVersion() == runtimeVersion);
0140       */
0141     }
0142   }
0143 
0144   SECTION("Force to be disabled") {
0145     std::string config = makeProcess("Test");
0146     addROCmService(config, false);
0147     auto serviceToken = getServiceToken(config);
0148     edm::ServiceRegistry::Operate operate(serviceToken);
0149 
0150     edm::Service<ROCmInterface> service;
0151     REQUIRE(service->enabled() == false);
0152     REQUIRE(service->numberOfDevices() == 0);
0153   }
0154 }