File indexing completed on 2025-04-11 03:31:18
0001 import FWCore.ParameterSet.Config as cms
0002 import os, sys, json
0003 from HeterogeneousCore.SonicTriton.customize import getDefaultClientPSet, getParser, getOptions, applyOptions, applyClientOptions
0004
0005
# Models that each test module can be run with. The keys of this table are
# also offered as the allowed choices of the --modules command-line option.
models = dict(
    TritonImageProducer=["inception_graphdef", "densenet_onnx"],
    TritonGraphProducer=["gat_test"],
    TritonGraphFilter=["gat_test"],
    TritonGraphAnalyzer=["gat_test"],
    TritonIdentityProducer=["ragged_io"],
)

# Accepted values of the --mode command-line option.
allowed_modes = [
    "Async",
    "PseudoAsync",
    "Sync",
]
0016
# Start from the parser supplied by the SonicTriton customize module and
# register the options that are specific to this test configuration.
parser = getParser()
parser.add_argument(
    "--modules",
    metavar="MODULES",
    nargs='+',
    type=str,
    choices=list(models.keys()),
    default=["TritonGraphProducer"],
    help="list of modules to run (choices: %(choices)s)",
)
parser.add_argument(
    "--models",
    nargs='+',
    type=str,
    default=["gat_test"],
    help="list of models (same length as modules, or just 1 entry if all modules use same model)",
)
parser.add_argument(
    "--mode",
    type=str,
    choices=allowed_modes,
    default="Async",
    help="mode for client",
)
parser.add_argument(
    "--brief",
    action="store_true",
    default=False,
    help="briefer output for graph modules",
)
parser.add_argument(
    "--unittest",
    action="store_true",
    default=False,
    help="unit test mode: reduce input sizes",
)
parser.add_argument(
    "--testother",
    action="store_true",
    default=False,
    help="also test gRPC communication if shared memory enabled, or vice versa",
)

# Parse the command line through the shared helper.
options = getOptions(parser, verbose=True)
0026
0027
# The module and model lists must pair up one-to-one. As a convenience, a
# single model is repeated so that it applies to every requested module.
if len(options.models) != len(options.modules):
    if len(options.models) != 1:
        raise ValueError("Arguments for modules and models must have same length")
    options.models = options.models * len(options.modules)

# Reject any pairing that is not listed in the `models` table.
for module, model in zip(options.modules, options.models):
    if model in models[module]:
        continue
    raise ValueError("Unsupported model {} for module {}".format(model, module))
0036
# The enableSonicTriton modifier is handed to the process at construction.
from Configuration.ProcessModifiers.enableSonicTriton_cff import enableSonicTriton

process = cms.Process("tritonTest", enableSonicTriton)
process.load("HeterogeneousCore.SonicTriton.TritonService_cff")
process.source = cms.Source("EmptySource")

# The path starts out empty; modules are appended to it as they are created.
process.p = cms.Path()

# Maps a substring of the module name to the cms constructor used to build it.
modules = dict(
    Producer=cms.EDProducer,
    Filter=cms.EDFilter,
    Analyzer=cms.EDAnalyzer,
)

# Client parameter set that every module clones: a copy of the default
# client PSet passed through applyClientOptions with the parsed options.
defaultClient = getDefaultClientPSet().clone()
defaultClient = applyClientOptions(defaultClient, options)
0053
# Create one framework module per requested (module, model) pair, configure
# it, and append it to the path. The two lists have equal length at this
# point, so they can be walked in lockstep.
for module, model in zip(options.modules, options.models):
    # Pick the cms constructor whose key ("Producer", "Filter", "Analyzer")
    # occurs in the module name.
    Module = [obj for name, obj in modules.items() if name in module][0]
    setattr(process, module,
        Module(module,
            Client = defaultClient.clone(
                mode = cms.string(options.mode),
                preferredServer = cms.untracked.string(""),
                modelName = cms.string(model),
                modelVersion = cms.string(""),
                modelConfigPath = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/config.pbtxt".format(model)),
            )
        )
    )
    processModule = getattr(process, module)

    if module == "TritonImageProducer":
        processModule.batchSize = cms.int32(1)
        processModule.topN = cms.uint32(5)
        # The labels file is named after the first "_"-separated token of
        # the model name (e.g. "inception" for "inception_graphdef").
        processModule.imageList = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/{}_labels.txt".format(model, model.split('_')[0]))
    elif "TritonGraph" in module:
        if options.unittest:
            # Unit test mode: use much smaller node/edge ranges.
            processModule.nodeMin = cms.uint32(1)
            processModule.nodeMax = cms.uint32(10)
            processModule.edgeMin = cms.uint32(20)
            processModule.edgeMax = cms.uint32(40)
        else:
            processModule.nodeMin = cms.uint32(100)
            processModule.nodeMax = cms.uint32(4000)
            processModule.edgeMin = cms.uint32(8000)
            processModule.edgeMax = cms.uint32(15000)
        processModule.brief = cms.bool(options.brief)

    process.p += processModule

    if options.testother:
        # Clone the module with the opposite useSharedMemory setting. The
        # suffix names the setting used by the clone. The parentheses are
        # required: without them the conditional expression swallows the
        # concatenation and every non-shared-memory clone would be called
        # just "SHM", so clones of different modules would collide.
        _module2 = module + ("GRPC" if processModule.Client.useSharedMemory else "SHM")
        setattr(process, _module2,
            processModule.clone(
                Client = dict(useSharedMemory = not processModule.Client.useSharedMemory)
            )
        )
        processModule2 = getattr(process, _module2)
        process.p += processModule2

# Hand the fully built process to the shared helper together with the parsed
# options, and keep whatever process object it returns.
process = applyOptions(process, options)