Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2025-04-11 03:31:18

0001 import FWCore.ParameterSet.Config as cms
0002 import os, sys, json
0003 from HeterogeneousCore.SonicTriton.customize import getDefaultClientPSet, getParser, getOptions, applyOptions, applyClientOptions
0004 
# Models that each test module may be run with (module name -> model names).
models = dict(
    TritonImageProducer=["inception_graphdef", "densenet_onnx"],
    TritonGraphProducer=["gat_test"],
    TritonGraphFilter=["gat_test"],
    TritonGraphAnalyzer=["gat_test"],
    TritonIdentityProducer=["ragged_io"],
)

# Values accepted for the client mode.
allowed_modes = [
    "Async",
    "PseudoAsync",
    "Sync",
]
0016 
# Start from the shared parser and add the options specific to this test.
parser = getParser()
parser.add_argument(
    "--modules",
    nargs='+',
    type=str,
    default=["TritonGraphProducer"],
    choices=list(models),
    metavar="MODULES",
    help="list of modules to run (choices: %(choices)s)",
)
parser.add_argument(
    "--models",
    nargs='+',
    type=str,
    default=["gat_test"],
    help="list of models (same length as modules, or just 1 entry if all modules use same model)",
)
parser.add_argument(
    "--mode",
    type=str,
    default="Async",
    choices=allowed_modes,
    help="mode for client",
)
# boolean switches: all default to off
parser.add_argument(
    "--brief",
    action="store_true",
    default=False,
    help="briefer output for graph modules",
)
parser.add_argument(
    "--unittest",
    action="store_true",
    default=False,
    help="unit test mode: reduce input sizes",
)
parser.add_argument(
    "--testother",
    action="store_true",
    default=False,
    help="also test gRPC communication if shared memory enabled, or vice versa",
)

options = getOptions(parser, verbose=True)
0026 
# Validate the module/model pairing before building the process.
_n_modules = len(options.modules)
_n_models = len(options.models)
if _n_models != _n_modules:
    if _n_models != 1:
        raise ValueError("Arguments for modules and models must have same length")
    # a single model is shared by all modules: repeat it once per module
    options.models = options.models * _n_modules
# every module must be paired with one of the models listed for it
for module, model in zip(options.modules, options.models):
    if model in models[module]:
        continue
    raise ValueError("Unsupported model {} for module {}".format(model,module))
0036 
0037 from Configuration.ProcessModifiers.enableSonicTriton_cff import enableSonicTriton
# Create the test process with the enableSonicTriton process modifier applied.
process = cms.Process("tritonTest", enableSonicTriton)
process.load("HeterogeneousCore.SonicTriton.TritonService_cff")

# No input data is read; the path starts empty and modules are appended to it below.
process.source = cms.Source("EmptySource")
process.p = cms.Path()

# Module class to instantiate, selected by the substring found in the module name.
modules = dict(
    Producer=cms.EDProducer,
    Filter=cms.EDFilter,
    Analyzer=cms.EDAnalyzer,
)

# Client settings shared by all modules: a copy of the default PSet with the
# command-line client options applied.
_clientTemplate = getDefaultClientPSet().clone()
defaultClient = applyClientOptions(_clientTemplate, options)
0053 
# Build one module per requested (module, model) pair, configure it, and append it to the path.
for module, model in zip(options.modules, options.models):
    # choose the module class whose key ("Producer", "Filter", "Analyzer") occurs in the module name
    Module = [obj for name,obj in modules.items() if name in module][0]
    setattr(process, module,
        Module(module,
            Client = defaultClient.clone(
                mode = cms.string(options.mode),
                preferredServer = cms.untracked.string(""),
                modelName = cms.string(model),
                modelVersion = cms.string(""),
                modelConfigPath = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/config.pbtxt".format(model)),
            )
        )
    )
    processModule = getattr(process, module)
    if module=="TritonImageProducer":
        processModule.batchSize = cms.int32(1)
        processModule.topN = cms.uint32(5)
        # label file name uses the part of the model name before the first underscore
        processModule.imageList = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/{}_labels.txt".format(model,model.split('_')[0]))
    elif "TritonGraph" in module:
        if options.unittest:
            # reduce input size for unit test
            processModule.nodeMin = cms.uint32(1)
            processModule.nodeMax = cms.uint32(10)
            processModule.edgeMin = cms.uint32(20)
            processModule.edgeMax = cms.uint32(40)
        else:
            processModule.nodeMin = cms.uint32(100)
            processModule.nodeMax = cms.uint32(4000)
            processModule.edgeMin = cms.uint32(8000)
            processModule.edgeMax = cms.uint32(15000)
        processModule.brief = cms.bool(options.brief)
    process.p += processModule
    if options.testother:
        # clone modules to test both gRPC and shared memory
        # The suffix choice is parenthesized on purpose: a conditional expression
        # binds more loosely than "+", so without the parentheses the clone would be
        # named just "SHM" (no module prefix) whenever shared memory is disabled,
        # and clones of different modules would overwrite each other.
        _module2 = module+("GRPC" if processModule.Client.useSharedMemory else "SHM")
        setattr(process, _module2,
            processModule.clone(
                # the clone uses the opposite communication setting from the original
                Client = dict(useSharedMemory = not processModule.Client.useSharedMemory)
            )
        )
        processModule2 = getattr(process, _module2)
        process.p += processModule2

# apply the remaining command-line options to the fully built process
process = applyOptions(process, options)