File indexing completed on 2022-02-25 02:40:55
0001 from FWCore.ParameterSet.VarParsing import VarParsing
0002 import FWCore.ParameterSet.Config as cms
0003 import os, sys, json
0004
0005
# Models supported by each test module, keyed by module name.
# The key order is preserved and shows up in the "modules" option help text.
models = dict(
    TritonImageProducer=["inception_graphdef", "densenet_onnx"],
    TritonGraphProducer=["gat_test"],
    TritonGraphFilter=["gat_test"],
    TritonGraphAnalyzer=["gat_test"],
)
0012
0013
# Accepted values for the "mode", "compression" and "device" options.
allowed_modes, allowed_compression, allowed_devices = (
    ["Async", "PseudoAsync", "Sync"],
    ["none", "deflate", "gzip"],
    ["auto", "cpu", "gpu"],
)
0017
# Shorthands for the VarParsing enums used in the option table below.
_singleton = VarParsing.multiplicity.singleton
_multi = VarParsing.multiplicity.list
_int = VarParsing.varType.int
_string = VarParsing.varType.string
_bool = VarParsing.varType.bool

# One row per command-line option: (name, default, multiplicity, type, help text).
# Rows are registered in the order listed.
_option_specs = [
    ("maxEvents", -1, _singleton, _int, "Number of events to process (-1 for all)"),
    ("serverName", "default", _singleton, _string, "name for server (used internally)"),
    ("address", "", _singleton, _string, "server address"),
    ("port", 8001, _singleton, _int, "server port"),
    ("timeout", 30, _singleton, _int, "timeout for requests"),
    ("params", "", _singleton, _string, "json file containing server address/port"),
    ("threads", 1, _singleton, _int, "number of threads"),
    ("streams", 0, _singleton, _int, "number of streams"),
    ("modules", "TritonGraphProducer", _multi, _string, "list of modules to run (choices: {})".format(', '.join(models))),
    ("models", "gat_test", _multi, _string, "list of models (same length as modules, or just 1 entry if all modules use same model)"),
    ("mode", "Async", _singleton, _string, "mode for client (choices: {})".format(', '.join(allowed_modes))),
    ("verbose", False, _singleton, _bool, "enable verbose output"),
    ("brief", False, _singleton, _bool, "briefer output for graph modules"),
    ("fallbackName", "", _singleton, _string, "name for fallback server"),
    ("unittest", False, _singleton, _bool, "unit test mode: reduce input sizes"),
    ("testother", False, _singleton, _bool, "also test gRPC communication if shared memory enabled, or vice versa"),
    ("shm", True, _singleton, _bool, "enable shared memory"),
    ("compression", "", _singleton, _string, "enable I/O compression (choices: {})".format(', '.join(allowed_compression))),
    ("ssl", False, _singleton, _bool, "enable SSL authentication for server communication"),
    ("device", "auto", _singleton, _string, "specify device for fallback server (choices: {})".format(', '.join(allowed_devices))),
    ("docker", False, _singleton, _bool, "use Docker for fallback server"),
    ("tries", 0, _singleton, _int, "number of retries for failed request"),
]

# Register every option, then parse the command line.
options = VarParsing()
for _spec in _option_specs:
    options.register(*_spec)
options.parseArguments()
0042
# When a JSON parameter file is given, its "address" and "port" entries
# replace the corresponding command-line options.
if options.params:
    with open(options.params, 'r') as params_file:
        server_params = json.load(params_file)
    options.address = server_params["address"]
    options.port = int(server_params["port"])
    # report the server taken from the file
    print("server = "+options.address+":"+str(options.port))
0049
0050
# Check that every requested module has a model, and that the pairing is supported.
n_modules = len(options.modules)
if n_modules != len(options.models):
    if len(options.models) != 1:
        raise ValueError("Arguments for modules and models must have same length")
    # A single model is reused for all modules.
    # NOTE(review): only n-1 copies are assigned here; presumably assigning to a
    # list-type VarParsing option appends to the existing entry - confirm.
    options.models = [options.models[0]]*(n_modules-1)

for im, module in enumerate(options.modules):
    # the module must be one of the known test modules
    if module not in models:
        raise ValueError("Unknown module: {}".format(module))
    # the model at the same position must be supported by that module
    model = options.models[im]
    supported = models[module]
    if model not in supported:
        raise ValueError("Unsupported model {} for module {}".format(model,module))
0061
0062
# The client mode must be one of the allowed choices.
mode_is_known = options.mode in allowed_modes
if not mode_is_known:
    raise ValueError("Unknown mode: {}".format(options.mode))

# An empty compression setting is accepted; any other value must be an allowed choice.
compression_is_valid = len(options.compression)==0 or options.compression in allowed_compression
if not compression_is_valid:
    raise ValueError("Unknown compression setting: {}".format(options.compression))

# The device name is lower-cased before it is checked against the allowed choices.
options.device = options.device.lower()
device_is_known = options.device in allowed_devices
if not device_is_known:
    raise ValueError("Unknown device: {}".format(options.device))
0074
# Create the process with the enableSonicTriton process modifier.
from Configuration.ProcessModifiers.enableSonicTriton_cff import enableSonicTriton
process = cms.Process('tritonTest',enableSonicTriton)

process.load("HeterogeneousCore.SonicTriton.TritonService_cff")

# number of events comes from the command line; the source is an EmptySource
process.maxEvents = cms.untracked.PSet( input = cms.untracked.int32(options.maxEvents) )
process.source = cms.Source("EmptySource")

# Service-wide and fallback-server settings taken from the command line.
process.TritonService.verbose = options.verbose
process.TritonService.fallback.verbose = options.verbose
process.TritonService.fallback.useDocker = options.docker
if options.fallbackName:
    process.TritonService.fallback.instanceBaseName = options.fallbackName
# "auto" leaves the fallback's useGPU setting untouched
if options.device != "auto":
    process.TritonService.fallback.useGPU = options.device=="gpu"

# Add a server entry only when an address was provided.
if options.address:
    _server = cms.PSet(
        name = cms.untracked.string(options.serverName),
        address = cms.untracked.string(options.address),
        port = cms.untracked.uint32(options.port),
        useSsl = cms.untracked.bool(options.ssl),
        rootCertificates = cms.untracked.string(""),
        privateKey = cms.untracked.string(""),
        certificateChain = cms.untracked.string(""),
    )
    process.TritonService.servers.append(_server)
0103
0104
# Path that the configured test modules are appended to.
process.p = cms.Path()

# Module-name fragment -> cms module type used to construct the module.
modules = dict(
    Producer=cms.EDProducer,
    Filter=cms.EDFilter,
    Analyzer=cms.EDAnalyzer,
)

# MessageLogger categories to keep; extended with one entry pair per configured module.
keepMsgs = ['TritonClient', 'TritonService']
0114
# Configure each requested module with its Triton client settings, add it to
# the path, and (with testother) add a clone using the other transport.
for im,module in enumerate(options.modules):
    model = options.models[im]
    # pick the cms module type whose key ("Producer", "Filter", "Analyzer")
    # occurs in the module name
    Module = [obj for name,obj in modules.items() if name in module][0]
    setattr(process, module,
        Module(module,
            Client = cms.PSet(
                mode = cms.string(options.mode),
                preferredServer = cms.untracked.string(""),
                timeout = cms.untracked.uint32(options.timeout),
                modelName = cms.string(model),
                modelVersion = cms.string(""),
                modelConfigPath = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/config.pbtxt".format(model)),
                verbose = cms.untracked.bool(options.verbose),
                allowedTries = cms.untracked.uint32(options.tries),
                useSharedMemory = cms.untracked.bool(options.shm),
                compression = cms.untracked.string(options.compression),
            )
        )
    )
    processModule = getattr(process, module)
    if module=="TritonImageProducer":
        processModule.batchSize = cms.int32(1)
        processModule.topN = cms.uint32(5)
        # label file is named after the part of the model name before the first underscore
        processModule.imageList = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/{}_labels.txt".format(model,model.split('_')[0]))
    elif "TritonGraph" in module:
        if options.unittest:
            # unit test mode: reduce input sizes
            processModule.nodeMin = cms.uint32(1)
            processModule.nodeMax = cms.uint32(10)
            processModule.edgeMin = cms.uint32(20)
            processModule.edgeMax = cms.uint32(40)
        else:
            processModule.nodeMin = cms.uint32(100)
            processModule.nodeMax = cms.uint32(4000)
            processModule.edgeMin = cms.uint32(8000)
            processModule.edgeMax = cms.uint32(15000)
        processModule.brief = cms.bool(options.brief)
    process.p += processModule
    keepMsgs.extend([module,module+':TritonClient'])
    if options.testother:
        # Clone the module with the opposite useSharedMemory setting.
        # The suffix choice is parenthesized so the module name is kept as a
        # prefix in both cases: a conditional expression binds more loosely
        # than "+", so without the parentheses the else-branch produced the
        # bare label "SHM" for every module.
        _module2 = module + ("GRPC" if processModule.Client.useSharedMemory else "SHM")
        setattr(process, _module2,
            processModule.clone(
                Client = dict(useSharedMemory = not processModule.Client.useSharedMemory)
            )
        )
        processModule2 = getattr(process, _module2)
        process.p += processModule2
        keepMsgs.extend([_module2,_module2+':TritonClient'])
0165
# Load the MessageLogger and report framework progress every 500 events.
process.load('FWCore/MessageService/MessageLogger_cfi')
_cerr = process.MessageLogger.cerr
_cerr.FwkReport.reportEvery = 500

# Give every collected category its own PSet with a limit of 10000000 messages.
for _category in keepMsgs:
    _category_settings = cms.untracked.PSet(
        limit = cms.untracked.int32(10000000),
    )
    setattr(_cerr, _category, _category_settings)
0174
# Thread and stream counts are only applied when a positive thread count is requested.
if options.threads > 0:
    for _setting, _value in (
        ("numberOfThreads", options.threads),
        ("numberOfStreams", options.streams),
    ):
        setattr(process.options, _setting, _value)
0178