File indexing completed on 2024-04-06 12:15:43
0001 import FWCore.ParameterSet.Config as cms
0002 import sys
0003 import argparse
0004
0005
0006
0007
0008
0009
# Command-line options accepted by this configuration.
parser = argparse.ArgumentParser(
    prog=sys.argv[0],
    description='Test various Alpaka module types',
)

# Backend the caller expects the device producer to run on; used further down
# to choose the expected host/device difference.
parser.add_argument(
    "--expectBackend",
    type=str,
    help="Expect this backend to run",
)

# Run number assigned to the first run of the EmptySource.
parser.add_argument(
    "--run",
    type=int,
    default=1,
    help="Run number (default: 1)",
)

args = parser.parse_args()
0016
process = cms.Process('TEST')

# Event-less input; the first run number is taken from --run.
process.source = cms.Source(
    'EmptySource',
    firstRun=cms.untracked.uint32(args.run),
)

# Stop after a fixed number of events.
process.maxEvents.input = 10
0024
# Standard accelerator configuration and the Alpaka ProcessAccelerator.
process.load('Configuration.StandardSequences.Accelerators_cff')
process.load('HeterogeneousCore.AlpakaCore.ProcessAcceleratorAlpaka_cfi')

# Empty EventSetup source for AlpakaESTestRecordA with a run-based IOV
# starting at run 1.
process.alpakaESRecordASource = cms.ESSource(
    "EmptyESSource",
    recordName=cms.string('AlpakaESTestRecordA'),
    iovIsRunNotTime=cms.bool(True),
    firstValid=cms.vuint32(1),
)

# Host-side test ESProducer (configured with value 42) and the Alpaka
# ESProducer; presumably the latter consumes the former's product — TODO confirm.
process.esProducerA = cms.ESProducer(
    "cms::alpakatest::TestESProducerA",
    value=cms.int32(42),
)
process.alpakaESProducerA = cms.ESProducer("TestAlpakaESProducerA@alpaka")
0037
0038
# Producer under test. xvalue holds one number per backend; presumably the
# module picks the entry for the backend it actually runs on — TODO confirm
# against TestAlpakaGlobalProducerOffset.
process.producer = cms.EDProducer(
    "TestAlpakaGlobalProducerOffset@alpaka",
    xvalue=cms.PSet(
        alpaka_serial_sync=cms.double(1.0),
        alpaka_cuda_async=cms.double(2.0),
        alpaka_rocm_async=cms.double(3.0),
    ),
)

# Identical copy pinned to the serial_sync backend via the untracked
# per-module `alpaka.backend` setting; it provides the host reference.
process.producerHost = process.producer.clone(
    alpaka=cms.untracked.PSet(backend=cms.untracked.string("serial_sync")),
)
0051
# Compare the products of the host-pinned clone against the device producer.
# expectedXdiff stays 0.0 unless --expectBackend names a backend whose xvalue
# entry differs from the serial_sync one.
process.compare = cms.EDAnalyzer("TestAlpakaHostDeviceCompare",
    srcHost=cms.untracked.InputTag("producerHost"),
    srcDevice=cms.untracked.InputTag("producer"),
    expectedXdiff=cms.untracked.double(0.0),
)

# The expected difference is (host xvalue) - (device xvalue); e.g. serial_sync
# 1.0 vs cuda_async 2.0 gives -1.0.  Derive it from the producer's xvalue table
# instead of repeating the numbers here, so editing xvalue cannot silently
# desynchronise this check.  No --expectBackend, or a backend without an xvalue
# entry, keeps the 0.0 default as before.
_xvalue = process.producer.xvalue
_deviceParam = f"alpaka_{args.expectBackend}" if args.expectBackend else None
if _deviceParam is not None and hasattr(_xvalue, _deviceParam):
    process.compare.expectedXdiff = (
        _xvalue.alpaka_serial_sync.value() - getattr(_xvalue, _deviceParam).value()
    )

# The producers run unscheduled (via the Task); the Path drives the comparison.
process.t = cms.Task(process.producer, process.producerHost)
process.p = cms.Path(process.compare, process.t)