Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:15:46

0001 import FWCore.ParameterSet.Config as cms
0002 
0003 import os
0004 
0005 from HeterogeneousCore.Common.PlatformStatus import PlatformStatus
0006 
class ProcessAcceleratorROCm(cms.ProcessAccelerator):
    """ProcessAccelerator exposing AMD GPUs (via ROCm) under the "gpu-amd" label.

    When the label is among the selected accelerators, ``apply`` makes sure the
    ROCmService is part of the process and that its messages are propagated by
    the MessageLogger; otherwise it removes both.
    """

    def __init__(self):
        super().__init__()
        # The single accelerator label handled by this class.
        self._label = "gpu-amd"

    def labels(self):
        """Return every label this accelerator can provide."""
        return [self._label]

    def enabledLabels(self):
        """Return ``labels()`` if ROCm is usable on this machine, otherwise ``[]``.

        The check is delegated to the external ``rocmIsEnabled`` program, whose
        exit code is interpreted as a ``PlatformStatus``. It should be run on
        each worker node, because the answer depends both on the architecture
        and on the actual hardware present in the machine.
        """
        exitCode = os.waitstatus_to_exitcode(os.system("rocmIsEnabled"))
        try:
            status = PlatformStatus(exitCode)
        except ValueError:
            # Exit codes that PlatformStatus does not define mean ROCm cannot be
            # used here. Examples are 127 from the shell when rocmIsEnabled is not
            # on the PATH, or a negative value when the helper is killed by a
            # signal. Report "disabled" instead of aborting the configuration.
            # NOTE(review): assumes PlatformStatus is an enum (unknown values
            # raise ValueError) -- confirm.
            return []
        return self.labels() if status == PlatformStatus.Success else []

    def apply(self, process, accelerators):
        """Add or remove the ROCmService depending on whether "gpu-amd" is selected.

        Args:
            process: the process being configured (presumably a ``cms.Process``).
            accelerators: the accelerator labels selected for this job.
        """
        if self._label in accelerators:
            # Ensure that the ROCmService is loaded
            if not hasattr(process, "ROCmService"):
                from HeterogeneousCore.ROCmServices.ROCmService_cfi import ROCmService
                process.add_(ROCmService)

            # Propagate the ROCmService messages through the MessageLogger,
            # if the process still has one.
            messageLogger = getattr(process, "MessageLogger", None)
            if messageLogger is not None and not hasattr(messageLogger, "ROCmService"):
                messageLogger.ROCmService = cms.untracked.PSet()

        else:
            # Make sure the ROCmService is not loaded
            if hasattr(process, "ROCmService"):
                del process.ROCmService

            # Drop the ROCmService messages from the MessageLogger, if present
            messageLogger = getattr(process, "MessageLogger", None)
            if messageLogger is not None and hasattr(messageLogger, "ROCmService"):
                del messageLogger.ROCmService
0043 
# Register the import line for ProcessAcceleratorROCm, so that this module is
# kept in the configuration when it is dumped.
cms.specialImportRegistry.registerSpecialImportForType(
    ProcessAcceleratorROCm,
    "from HeterogeneousCore.ROCmCore.ProcessAcceleratorROCm import ProcessAcceleratorROCm",
)