File indexing completed on 2024-08-06 22:36:41
0001 import FWCore.ParameterSet.Config as cms
0002 import os, sys, json
0003 from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
0004
0005
# Models that each test module may be run with; the keys double as the
# accepted values of --modules.
models = dict(
    TritonImageProducer=["inception_graphdef", "densenet_onnx"],
    TritonGraphProducer=["gat_test"],
    TritonGraphFilter=["gat_test"],
    TritonGraphAnalyzer=["gat_test"],
    TritonIdentityProducer=["ragged_io"],
)
0013
0014
# Accepted values for the --mode, --compression, --device and --container options.
allowed_modes = [
    "Async",
    "PseudoAsync",
    "Sync",
]
allowed_compression = [
    "none",
    "deflate",
    "gzip",
]
allowed_devices = [
    "auto",
    "cpu",
    "gpu",
]
allowed_containers = [
    "apptainer",
    "docker",
    "podman",
    "podman-hpc",
]
0019
# Command-line options for the test job. ArgumentDefaultsHelpFormatter appends
# each option's default value to its help text.
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
_add = parser.add_argument
_switch = dict(default=False, action="store_true")  # shared by every on/off flag

# event loop and server connection
_add("--maxEvents", type=int, default=-1, help="Number of events to process (-1 for all)")
_add("--serverName", type=str, default="default", help="name for server (used internally)")
_add("--address", type=str, default="", help="server address")
_add("--port", type=int, default=8001, help="server port")
_add("--timeout", type=int, default=30, help="timeout for requests")
_add("--timeoutUnit", type=str, default="seconds", help="unit for timeout")
_add("--params", type=str, default="", help="json file containing server address/port")
_add("--threads", type=int, default=1, help="number of threads")
_add("--streams", type=int, default=0, help="number of streams")

# what to run
_add(
    "--modules",
    type=str,
    nargs="+",
    default=["TritonGraphProducer"],
    choices=list(models),
    metavar="MODULES",
    help="list of modules to run (choices: %(choices)s)",
)
_add(
    "--models",
    type=str,
    nargs="+",
    default=["gat_test"],
    help="list of models (same length as modules, or just 1 entry if all modules use same model)",
)
_add("--mode", type=str, default="Async", choices=allowed_modes, help="mode for client")

# verbosity
_add("--verbose", help="enable all verbose output", **_switch)
_add("--verboseClient", help="enable verbose output for clients", **_switch)
_add("--verboseServer", help="enable verbose output for server", **_switch)
_add("--verboseService", help="enable verbose output for TritonService", **_switch)
_add("--verboseDiscovery", help="enable verbose output just for server discovery in TritonService", **_switch)
_add("--brief", help="briefer output for graph modules", **_switch)

# fallback server, test modes and client transport
_add("--fallbackName", type=str, default="", help="name for fallback server")
_add("--unittest", help="unit test mode: reduce input sizes", **_switch)
_add("--testother", help="also test gRPC communication if shared memory enabled, or vice versa", **_switch)
_add("--noShm", help="disable shared memory", **_switch)
_add("--compression", type=str, default="", choices=allowed_compression, help="enable I/O compression")
_add("--ssl", help="enable SSL authentication for server communication", **_switch)
# str.lower as the type makes the --device and --container values case-insensitive
_add("--device", type=str.lower, default="auto", choices=allowed_devices, help="specify device for fallback server")
_add("--container", type=str.lower, default="apptainer", choices=allowed_containers, help="specify container for fallback server")
_add("--tries", type=int, default=0, help="number of retries for failed request")

options = parser.parse_args()
0049
# An optional JSON parameter file supplies the server address and port,
# replacing whatever --address/--port were given on the command line.
if options.params:
    with open(options.params, "r") as params_file:
        server_params = json.load(params_file)
    options.address = server_params["address"]
    options.port = int(server_params["port"])
    print("server = " + options.address + ":" + str(options.port))
0056
0057
# --models needs one entry per module; a single entry is shared by all modules.
n_modules = len(options.modules)
if len(options.models) != n_modules:
    if len(options.models) != 1:
        raise ValueError("Arguments for modules and models must have same length")
    options.models = options.models * n_modules

# Every (module, model) pairing must appear in the table of supported models.
for module, model in zip(options.modules, options.models):
    if model not in models[module]:
        raise ValueError("Unsupported model {} for module {}".format(model, module))
0066
from Configuration.ProcessModifiers.enableSonicTriton_cff import enableSonicTriton

# Create the process with the enableSonicTriton modifier and pull in the
# TritonService configuration.
process = cms.Process('tritonTest', enableSonicTriton)
process.load("HeterogeneousCore.SonicTriton.TritonService_cff")

# No input files: events come from an EmptySource, capped by --maxEvents.
process.source = cms.Source("EmptySource")
process.maxEvents = cms.untracked.PSet(
    input = cms.untracked.int32(options.maxEvents),
)

# --verbose switches everything on; the specific flags switch on one part each.
service_verbose = options.verbose or options.verboseService or options.verboseDiscovery
server_verbose = options.verbose or options.verboseServer
process.TritonService.verbose = service_verbose

fallback = process.TritonService.fallback
fallback.verbose = server_verbose
fallback.container = options.container
fallback.device = options.device
if options.fallbackName:
    fallback.instanceBaseName = options.fallbackName

# Register an explicit server only when an address was provided.
if options.address:
    server = cms.PSet(
        name = cms.untracked.string(options.serverName),
        address = cms.untracked.string(options.address),
        port = cms.untracked.uint32(options.port),
        useSsl = cms.untracked.bool(options.ssl),
        rootCertificates = cms.untracked.string(""),
        privateKey = cms.untracked.string(""),
        certificateChain = cms.untracked.string(""),
    )
    process.TritonService.servers.append(server)
0094
0095
# Path that the configured test modules are appended to.
process.p = cms.Path()

# Framework module type to instantiate, selected by the keyword contained in
# the module name.
modules = {
    "Producer": cms.EDProducer,
    "Filter": cms.EDFilter,
    "Analyzer": cms.EDAnalyzer,
}

# Message categories to keep in the log: --verbose enables all of them,
# each specific flag enables its own.
keepMsgs = [
    category
    for enabled, category in (
        (options.verboseDiscovery, 'TritonDiscovery'),
        (options.verboseClient, 'TritonClient'),
        (options.verboseService, 'TritonService'),
    )
    if options.verbose or enabled
]
0111
# Configure one framework module per requested (module, model) pair, add it to
# the path and, with --testother, add a clone that uses the other transport.
for im, module in enumerate(options.modules):
    model = options.models[im]
    # pick the EDProducer/EDFilter/EDAnalyzer type whose keyword appears in the module name
    Module = [obj for name, obj in modules.items() if name in module][0]
    setattr(process, module,
        Module(module,
            Client = cms.PSet(
                mode = cms.string(options.mode),
                preferredServer = cms.untracked.string(""),
                timeout = cms.untracked.uint32(options.timeout),
                timeoutUnit = cms.untracked.string(options.timeoutUnit),
                modelName = cms.string(model),
                modelVersion = cms.string(""),
                modelConfigPath = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/config.pbtxt".format(model)),
                verbose = cms.untracked.bool(options.verbose or options.verboseClient),
                allowedTries = cms.untracked.uint32(options.tries),
                useSharedMemory = cms.untracked.bool(not options.noShm),
                compression = cms.untracked.string(options.compression),
            )
        )
    )
    processModule = getattr(process, module)

    # module-specific parameters
    if module == "TritonImageProducer":
        processModule.batchSize = cms.int32(1)
        processModule.topN = cms.uint32(5)
        # the labels file is named after the first "_"-separated token of the model name
        processModule.imageList = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/{}_labels.txt".format(model, model.split('_')[0]))
    elif "TritonGraph" in module:
        if options.unittest:
            # reduced input sizes for unit test mode
            processModule.nodeMin = cms.uint32(1)
            processModule.nodeMax = cms.uint32(10)
            processModule.edgeMin = cms.uint32(20)
            processModule.edgeMax = cms.uint32(40)
        else:
            processModule.nodeMin = cms.uint32(100)
            processModule.nodeMax = cms.uint32(4000)
            processModule.edgeMin = cms.uint32(8000)
            processModule.edgeMax = cms.uint32(15000)
        processModule.brief = cms.bool(options.brief)

    process.p += processModule
    if options.verbose or options.verboseClient:
        keepMsgs.extend([module, module+':TritonClient'])

    if options.testother:
        # Clone the module with the opposite shared-memory setting. The suffix
        # choice is parenthesized so the clone is always labelled after its
        # module ("<module>GRPC" or "<module>SHM"); unparenthesized, the
        # conditional expression binds looser than "+", which yields the bare
        # label "SHM" for every module when shared memory is disabled.
        _module2 = module + ("GRPC" if processModule.Client.useSharedMemory else "SHM")
        setattr(process, _module2,
            processModule.clone(
                Client = dict(useSharedMemory = not processModule.Client.useSharedMemory)
            )
        )
        processModule2 = getattr(process, _module2)
        process.p += processModule2
        if options.verbose or options.verboseClient:
            keepMsgs.extend([_module2, _module2+':TritonClient'])
0165
# Message logging: report progress every 500 events and raise the message
# limit for each category that was selected for verbose output.
process.load('FWCore/MessageService/MessageLogger_cfi')
cerr = process.MessageLogger.cerr
cerr.FwkReport.reportEvery = 500
for category in keepMsgs:
    # a fresh PSet per category
    category_settings = cms.untracked.PSet(
        limit = cms.untracked.int32(10000000),
    )
    setattr(cerr, category, category_settings)

# Thread and stream counts are only applied for a positive --threads value.
if options.threads > 0:
    process.options.numberOfThreads = options.threads
    process.options.numberOfStreams = options.streams
0178