Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:15:43

0001 import FWCore.ParameterSet.Config as cms
0002 
0003 import os
0004 
0005 from HeterogeneousCore.Common.PlatformStatus import PlatformStatus
0006 
class ProcessAcceleratorCUDA(cms.ProcessAccelerator):
    """ProcessAccelerator for NVIDIA GPUs, exposed under the "gpu-nvidia" label.

    Depending on whether that label is among the selected accelerators,
    ``apply`` either loads the CUDAService (and routes its MessageLogger
    category) or removes both from the process.
    """

    def __init__(self):
        super().__init__()
        # The single accelerator label handled by this class.
        self._label = "gpu-nvidia"

    def labels(self):
        """Return a one-element list holding this accelerator's label."""
        return [self._label]

    def enabledLabels(self):
        """Return this accelerator's labels if CUDA is usable here, else ``[]``.

        The external ``cudaIsEnabled`` command is run through ``os.system``.
        Its exit code, interpreted as a ``PlatformStatus``, must equal
        ``PlatformStatus.Success`` for the label to be reported as enabled.
        """
        # Probed at call time on purpose: the answer depends both on the
        # software architecture and on the hardware actually present in the
        # machine running the job, so each worker node must check for itself.
        wait_status = os.system("cudaIsEnabled")
        exit_code = os.waitstatus_to_exitcode(wait_status)
        status = PlatformStatus(exit_code)
        if status == PlatformStatus.Success:
            return self.labels()
        return []

    def apply(self, process, accelerators):
        """Configure ``process`` according to whether this accelerator was selected.

        If this accelerator's label is in ``accelerators``:
          - the CUDAService from ``HeterogeneousCore.CUDAServices.CUDAService_cfi``
            is added unless the process already has one;
          - an empty untracked ``CUDAService`` PSet is added to the
            MessageLogger, so the service's messages are propagated.

        Otherwise, both the service and that MessageLogger entry are removed
        if present.

        NOTE(review): ``process.MessageLogger`` is accessed unconditionally,
        so this assumes the process defines a MessageLogger.
        """
        if self._label not in accelerators:
            # Not selected: make sure neither the service nor its log routing linger.
            if hasattr(process, "CUDAService"):
                del process.CUDAService
            if hasattr(process.MessageLogger, "CUDAService"):
                del process.MessageLogger.CUDAService
            return

        # Selected: load the CUDAService, importing its cfi only when actually needed.
        if not hasattr(process, "CUDAService"):
            from HeterogeneousCore.CUDAServices.CUDAService_cfi import CUDAService
            process.add_(CUDAService)

        # Route the CUDAService message category through the MessageLogger.
        logger = process.MessageLogger
        if not hasattr(logger, "CUDAService"):
            logger.CUDAService = cms.untracked.PSet()
0042 
0043 
# Register the import line associated with ProcessAcceleratorCUDA, so that this
# module is kept in the configuration when the configuration is dumped.
cms.specialImportRegistry.registerSpecialImportForType(
    ProcessAcceleratorCUDA,
    "from HeterogeneousCore.CUDACore.ProcessAcceleratorCUDA import ProcessAcceleratorCUDA",
)