# File indexing completed on 2024-04-06 12:15:48
0001 import FWCore.ParameterSet.Config as cms
0002 import os, sys, json
0003 from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
0004
0005
# Models accepted by each test module. The keys are the valid choices for
# --modules; each list holds the model names allowed for that module.
models = {
    "TritonImageProducer": ["inception_graphdef", "densenet_onnx"],
    "TritonGraphProducer": ["gat_test"],
    "TritonGraphFilter": ["gat_test"],
    "TritonGraphAnalyzer": ["gat_test"],
    "TritonIdentityProducer": ["ragged_io"],
}

# Allowed values for the --mode, --compression and --device options.
allowed_modes = ["Async", "PseudoAsync", "Sync"]
allowed_compression = ["none", "deflate", "gzip"]
allowed_devices = ["auto", "cpu", "gpu"]
0018
# Command-line options of this test configuration.
# ArgumentDefaultsHelpFormatter appends each option's default to its help text.
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)

# job size
parser.add_argument("--maxEvents", default=-1, type=int, help="Number of events to process (-1 for all)")

# remote server description
parser.add_argument("--serverName", default="default", type=str, help="name for server (used internally)")
parser.add_argument("--address", default="", type=str, help="server address")
parser.add_argument("--port", default=8001, type=int, help="server port")
parser.add_argument("--timeout", default=30, type=int, help="timeout for requests")
parser.add_argument("--timeoutUnit", default="seconds", type=str, help="unit for timeout")
parser.add_argument("--params", default="", type=str, help="json file containing server address/port")

# framework threading
parser.add_argument("--threads", default=1, type=int, help="number of threads")
parser.add_argument("--streams", default=0, type=int, help="number of streams")

# modules to run and the model used by each of them
parser.add_argument(
    "--modules", metavar=("MODULES"), default=["TritonGraphProducer"], nargs='+', type=str,
    choices=list(models), help="list of modules to run (choices: %(choices)s)",
)
parser.add_argument(
    "--models", default=["gat_test"], nargs='+', type=str,
    help="list of models (same length as modules, or just 1 entry if all modules use same model)",
)
parser.add_argument("--mode", default="Async", type=str, choices=allowed_modes, help="mode for client")

# verbosity switches
parser.add_argument("--verbose", default=False, action="store_true", help="enable all verbose output")
parser.add_argument("--verboseClient", default=False, action="store_true", help="enable verbose output for clients")
parser.add_argument("--verboseServer", default=False, action="store_true", help="enable verbose output for server")
parser.add_argument("--verboseService", default=False, action="store_true", help="enable verbose output for TritonService")
parser.add_argument("--brief", default=False, action="store_true", help="briefer output for graph modules")

# fallback server and test behaviour
parser.add_argument("--fallbackName", default="", type=str, help="name for fallback server")
parser.add_argument("--unittest", default=False, action="store_true", help="unit test mode: reduce input sizes")
parser.add_argument("--testother", default=False, action="store_true", help="also test gRPC communication if shared memory enabled, or vice versa")
parser.add_argument("--noShm", default=False, action="store_true", help="disable shared memory")
# NOTE: the default "" is not one of the choices; argparse only checks values
# given on the command line against `choices`, so the default is accepted.
parser.add_argument("--compression", default="", type=str, choices=allowed_compression, help="enable I/O compression")
parser.add_argument("--ssl", default=False, action="store_true", help="enable SSL authentication for server communication")
# type=str.lower lower-cases the value before it is compared with the choices
parser.add_argument("--device", default="auto", type=str.lower, choices=allowed_devices, help="specify device for fallback server")
parser.add_argument("--docker", default=False, action="store_true", help="use Docker for fallback server")
parser.add_argument("--tries", default=0, type=int, help="number of retries for failed request")

# Parsed options, used by the rest of this configuration.
options = parser.parse_args()
0047
# If a JSON parameter file was given with --params, take the server address
# and port from it (overriding --address/--port) and report the result.
if len(options.params) > 0:
    with open(options.params, 'r') as pfile:
        pdict = json.load(pfile)
    options.address = pdict["address"]
    # the port is converted explicitly, so the file may store it as a string or a number
    options.port = int(pdict["port"])
    print("server = "+options.address+":"+str(options.port))
0054
0055
# Pair every module with a model: a single --models entry is reused for all
# modules; any other length mismatch is an error.
if len(options.modules) != len(options.models):
    if len(options.models) == 1:
        options.models = [options.models[0]] * len(options.modules)
    else:
        raise ValueError("Arguments for modules and models must have same length")

# Reject module/model combinations that are not listed in `models`.
for module, model in zip(options.modules, options.models):
    if model not in models[module]:
        raise ValueError("Unsupported model {} for module {}".format(model, module))
0064
# Create the process with the enableSonicTriton modifier and load the
# TritonService configuration.
from Configuration.ProcessModifiers.enableSonicTriton_cff import enableSonicTriton
process = cms.Process('tritonTest', enableSonicTriton)

process.load("HeterogeneousCore.SonicTriton.TritonService_cff")

process.maxEvents = cms.untracked.PSet(input = cms.untracked.int32(options.maxEvents))

process.source = cms.Source("EmptySource")

# --verbose switches on every verbosity flag; the specific flags enable only their own part
process.TritonService.verbose = options.verbose or options.verboseService
process.TritonService.fallback.verbose = options.verbose or options.verboseServer
process.TritonService.fallback.useDocker = options.docker
# the fallback settings below are only overridden when the option was actually given
if len(options.fallbackName) > 0:
    process.TritonService.fallback.instanceBaseName = options.fallbackName
if options.device != "auto":
    process.TritonService.fallback.useGPU = options.device == "gpu"
# register a server entry only when an address was supplied (--address or --params)
if len(options.address) > 0:
    process.TritonService.servers.append(
        cms.PSet(
            name = cms.untracked.string(options.serverName),
            address = cms.untracked.string(options.address),
            port = cms.untracked.uint32(options.port),
            useSsl = cms.untracked.bool(options.ssl),
            rootCertificates = cms.untracked.string(""),
            privateKey = cms.untracked.string(""),
            certificateChain = cms.untracked.string(""),
        )
    )
0093
0094
# Path that the test modules are appended to.
process.p = cms.Path()

# Maps a substring of the module name to the cms constructor used to create it.
modules = {
    "Producer": cms.EDProducer,
    "Filter": cms.EDFilter,
    "Analyzer": cms.EDAnalyzer,
}

# Names that later get their own message limit in process.MessageLogger.cerr;
# extended with one pair of entries per created module.
keepMsgs = ['TritonClient', 'TritonService']
0104
# Create one module per --modules entry, configure it, append it to the path
# and register its message categories in keepMsgs.
for im, module in enumerate(options.modules):
    model = options.models[im]
    # pick the EDProducer/EDFilter/EDAnalyzer constructor whose key occurs in the module name
    Module = [obj for name, obj in modules.items() if name in module][0]
    setattr(process, module,
        Module(module,
            Client = cms.PSet(
                mode = cms.string(options.mode),
                preferredServer = cms.untracked.string(""),
                timeout = cms.untracked.uint32(options.timeout),
                timeoutUnit = cms.untracked.string(options.timeoutUnit),
                modelName = cms.string(model),
                modelVersion = cms.string(""),
                modelConfigPath = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/config.pbtxt".format(model)),
                verbose = cms.untracked.bool(options.verbose or options.verboseClient),
                allowedTries = cms.untracked.uint32(options.tries),
                useSharedMemory = cms.untracked.bool(not options.noShm),
                compression = cms.untracked.string(options.compression),
            )
        )
    )
    processModule = getattr(process, module)

    # module-specific parameters
    if module == "TritonImageProducer":
        processModule.batchSize = cms.int32(1)
        processModule.topN = cms.uint32(5)
        # the label file name uses the part of the model name before the first underscore
        processModule.imageList = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/{}_labels.txt".format(model, model.split('_')[0]))
    elif "TritonGraph" in module:
        if options.unittest:
            # unit test mode: reduce input sizes
            processModule.nodeMin = cms.uint32(1)
            processModule.nodeMax = cms.uint32(10)
            processModule.edgeMin = cms.uint32(20)
            processModule.edgeMax = cms.uint32(40)
        else:
            processModule.nodeMin = cms.uint32(100)
            processModule.nodeMax = cms.uint32(4000)
            processModule.edgeMin = cms.uint32(8000)
            processModule.edgeMax = cms.uint32(15000)
        processModule.brief = cms.bool(options.brief)

    process.p += processModule
    keepMsgs.extend([module, module+':TritonClient'])

    if options.testother:
        # Clone the module with the opposite useSharedMemory setting so both
        # transports are exercised. The suffix has to be parenthesized:
        # `module+"GRPC" if c else "SHM"` parses as `(module+"GRPC") if c else "SHM"`,
        # which labelled every clone plain "SHM" when shared memory was disabled.
        _module2 = module + ("GRPC" if processModule.Client.useSharedMemory else "SHM")
        setattr(process, _module2,
            processModule.clone(
                Client = dict(useSharedMemory = not processModule.Client.useSharedMemory)
            )
        )
        processModule2 = getattr(process, _module2)
        process.p += processModule2
        keepMsgs.extend([_module2, _module2+':TritonClient'])
0156
# Load the MessageLogger, report progress every 500 events, and give each
# name collected in keepMsgs its own (very high) message limit on cerr.
process.load('FWCore/MessageService/MessageLogger_cfi')
process.MessageLogger.cerr.FwkReport.reportEvery = 500
for msg in keepMsgs:
    setattr(process.MessageLogger.cerr, msg,
        cms.untracked.PSet(
            limit = cms.untracked.int32(10000000),
        )
    )
0165
# Thread and stream counts are only set when --threads is positive;
# --streams defaults to 0.
if options.threads > 0:
    process.options.numberOfThreads = options.threads
    process.options.numberOfStreams = options.streams
0169