File indexing completed on 2024-04-06 12:15:43
0001 import FWCore.ParameterSet.Config as cms
0002
0003 def _switch_cuda(useAccelerators):
0004 have_gpu = ("gpu-nvidia" in useAccelerators)
0005 return (have_gpu, 2)
0006
class SwitchProducerCUDA(cms.SwitchProducer):
    """SwitchProducer offering a ``cpu`` case and a ``cuda`` case.

    The ``cpu`` case uses the function returned by
    ``cms.SwitchProducer.getCpu()``; the ``cuda`` case uses ``_switch_cuda``,
    which enables it when "gpu-nvidia" is among the requested accelerators.
    Keyword arguments supply the per-case modules (e.g. ``cpu=...``,
    ``cuda=...``) and are forwarded unchanged to the base class.
    """
    def __init__(self, **kargs):
        # Mapping from case name to the function deciding whether that case
        # can run; passed positionally to the base-class constructor.
        caseFunctions = {
            "cpu": cms.SwitchProducer.getCpu(),
            "cuda": _switch_cuda,
        }
        super(SwitchProducerCUDA, self).__init__(caseFunctions, **kargs)
# Associate SwitchProducerCUDA with the import statement that brings it into
# scope; presumably emitted when a configuration containing this type is
# dumped back to Python (NOTE(review): confirm in FWCore.ParameterSet.Config).
cms.specialImportRegistry.registerSpecialImportForType(
    SwitchProducerCUDA,
    "from HeterogeneousCore.CUDACore.SwitchProducerCUDA import SwitchProducerCUDA",
)
0015
if __name__ == "__main__":
    import unittest

    class TestSwitchProducerCUDA(unittest.TestCase):
        """Self-tests run when this module is executed directly."""

        def testPickle(self):
            """A pickle round-trip keeps the module type of each case."""
            import pickle
            original = SwitchProducerCUDA(
                cpu=cms.EDProducer("Foo"),
                cuda=cms.EDProducer("Bar"),
            )
            restored = pickle.loads(pickle.dumps(original))
            self.assertEqual(restored.cpu.type_(), "Foo")
            self.assertEqual(restored.cuda.type_(), "Bar")

    unittest.main()
0028 unittest.main()
0029