# Back to home page
#
# Project CMSSW displayed by LXR
#
# File indexing completed on 2022-02-25 02:40:55

import os
import sys
import json

from FWCore.ParameterSet.VarParsing import VarParsing
import FWCore.ParameterSet.Config as cms

# Test configuration for SonicTriton client modules; all behavior is steered by
# the command-line options registered below (parsed via VarParsing).

# Triton models that each test module is allowed to use.
models = {
    "TritonImageProducer": ["inception_graphdef", "densenet_onnx"],
    "TritonGraphProducer": ["gat_test"],
    "TritonGraphFilter": ["gat_test"],
    "TritonGraphAnalyzer": ["gat_test"],
}

# Accepted values for the enumerated string options (checked after parsing).
allowed_modes = ["Async", "PseudoAsync", "Sync"]
allowed_compression = ["none", "deflate", "gzip"]
allowed_devices = ["auto", "cpu", "gpu"]

# Short aliases to keep the option table below readable.
_one = VarParsing.multiplicity.singleton
_many = VarParsing.multiplicity.list
_int = VarParsing.varType.int
_str = VarParsing.varType.string
_bool = VarParsing.varType.bool

# (name, default, multiplicity, type, help); registered in exactly this order.
_option_specs = [
    ("maxEvents", -1, _one, _int, "Number of events to process (-1 for all)"),
    ("serverName", "default", _one, _str, "name for server (used internally)"),
    ("address", "", _one, _str, "server address"),
    ("port", 8001, _one, _int, "server port"),
    ("timeout", 30, _one, _int, "timeout for requests"),
    ("params", "", _one, _str, "json file containing server address/port"),
    ("threads", 1, _one, _int, "number of threads"),
    ("streams", 0, _one, _int, "number of streams"),
    ("modules", "TritonGraphProducer", _many, _str,
        "list of modules to run (choices: {})".format(', '.join(models))),
    ("models", "gat_test", _many, _str,
        "list of models (same length as modules, or just 1 entry if all modules use same model)"),
    ("mode", "Async", _one, _str,
        "mode for client (choices: {})".format(', '.join(allowed_modes))),
    ("verbose", False, _one, _bool, "enable verbose output"),
    ("brief", False, _one, _bool, "briefer output for graph modules"),
    ("fallbackName", "", _one, _str, "name for fallback server"),
    ("unittest", False, _one, _bool, "unit test mode: reduce input sizes"),
    ("testother", False, _one, _bool,
        "also test gRPC communication if shared memory enabled, or vice versa"),
    ("shm", True, _one, _bool, "enable shared memory"),
    ("compression", "", _one, _str,
        "enable I/O compression (choices: {})".format(', '.join(allowed_compression))),
    ("ssl", False, _one, _bool, "enable SSL authentication for server communication"),
    ("device", "auto", _one, _str,
        "specify device for fallback server (choices: {})".format(', '.join(allowed_devices))),
    ("docker", False, _one, _bool, "use Docker for fallback server"),
    ("tries", 0, _one, _int, "number of retries for failed request"),
]

options = VarParsing()
for _spec in _option_specs:
    options.register(*_spec)
options.parseArguments()

# Optionally read the server address and port from a JSON parameter file;
# these override any address/port given on the command line.
if options.params:
    with open(options.params, 'r') as _param_file:
        _server_params = json.load(_param_file)
    options.address = _server_params["address"]
    options.port = int(_server_params["port"])
    print("server = " + options.address + ":" + str(options.port))

# Modules and models must pair up one-to-one, except that a single model may be
# shared by every module.
if len(options.modules) != len(options.models):
    if len(options.models) != 1:
        raise ValueError("Arguments for modules and models must have same length")
    # assigning to VarParsing.multiplicity.list actually appends to existing value(s),
    # so only len(modules)-1 extra copies of the single model are supplied here
    options.models = [options.models[0]] * (len(options.modules) - 1)

for _idx, _module in enumerate(options.modules):
    if _module not in models:
        raise ValueError("Unknown module: {}".format(_module))
    _model = options.models[_idx]
    if _model not in models[_module]:
        raise ValueError("Unsupported model {} for module {}".format(_model, _module))

# Remaining enumerated choices.
if options.mode not in allowed_modes:
    raise ValueError("Unknown mode: {}".format(options.mode))

# an empty compression string means "unset" and is always accepted
if options.compression and options.compression not in allowed_compression:
    raise ValueError("Unknown compression setting: {}".format(options.compression))

# device names are matched case-insensitively: normalize before checking
options.device = options.device.lower()
if options.device not in allowed_devices:
    raise ValueError("Unknown device: {}".format(options.device))

from Configuration.ProcessModifiers.enableSonicTriton_cff import enableSonicTriton

# The process is constructed with the enableSonicTriton modifier.
process = cms.Process('tritonTest', enableSonicTriton)
process.load("HeterogeneousCore.SonicTriton.TritonService_cff")
process.maxEvents = cms.untracked.PSet(input = cms.untracked.int32(options.maxEvents))
process.source = cms.Source("EmptySource")

# Push command-line settings into the Triton service and its fallback server.
_triton = process.TritonService
_triton.verbose = options.verbose
_triton.fallback.verbose = options.verbose
_triton.fallback.useDocker = options.docker
if options.fallbackName:
    _triton.fallback.instanceBaseName = options.fallbackName
# "auto" leaves the fallback's useGPU setting untouched
if options.device != "auto":
    _triton.fallback.useGPU = options.device == "gpu"
# An explicit address registers an additional named server (no SSL credentials).
if options.address:
    _triton.servers.append(
        cms.PSet(
            name = cms.untracked.string(options.serverName),
            address = cms.untracked.string(options.address),
            port = cms.untracked.uint32(options.port),
            useSsl = cms.untracked.bool(options.ssl),
            rootCertificates = cms.untracked.string(""),
            privateKey = cms.untracked.string(""),
            certificateChain = cms.untracked.string(""),
        )
    )

# Let it run
process.p = cms.Path()

# EDM module type for each test module, chosen by which key appears in the
# module name (e.g. "TritonGraphFilter" -> cms.EDFilter).
modules = {
    "Producer": cms.EDProducer,
    "Filter": cms.EDFilter,
    "Analyzer": cms.EDAnalyzer,
}

# MessageLogger categories collected for the logger configuration.
keepMsgs = ['TritonClient', 'TritonService']

for im, module in enumerate(options.modules):
    model = options.models[im]
    # module names were validated against `models`, so a matching type always exists
    Module = next(obj for name, obj in modules.items() if name in module)
    setattr(process, module,
        Module(module,
            Client = cms.PSet(
                mode = cms.string(options.mode),
                preferredServer = cms.untracked.string(""),
                timeout = cms.untracked.uint32(options.timeout),
                modelName = cms.string(model),
                modelVersion = cms.string(""),
                modelConfigPath = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/config.pbtxt".format(model)),
                verbose = cms.untracked.bool(options.verbose),
                allowedTries = cms.untracked.uint32(options.tries),
                useSharedMemory = cms.untracked.bool(options.shm),
                compression = cms.untracked.string(options.compression),
            )
        )
    )
    processModule = getattr(process, module)
    if module == "TritonImageProducer":
        processModule.batchSize = cms.int32(1)
        processModule.topN = cms.uint32(5)
        # label file is named after the model prefix, e.g. inception_graphdef -> inception_labels.txt
        processModule.imageList = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/{}_labels.txt".format(model, model.split('_')[0]))
    elif "TritonGraph" in module:
        # input size bounds (node and edge counts)
        if options.unittest:
            # reduce input size for unit test
            processModule.nodeMin = cms.uint32(1)
            processModule.nodeMax = cms.uint32(10)
            processModule.edgeMin = cms.uint32(20)
            processModule.edgeMax = cms.uint32(40)
        else:
            processModule.nodeMin = cms.uint32(100)
            processModule.nodeMax = cms.uint32(4000)
            processModule.edgeMin = cms.uint32(8000)
            processModule.edgeMax = cms.uint32(15000)
        processModule.brief = cms.bool(options.brief)
    process.p += processModule
    keepMsgs.extend([module, module + ':TritonClient'])
    if options.testother:
        # Clone the module using the opposite transport: the original uses shared
        # memory iff options.shm (see Client.useSharedMemory above). The suffix is
        # parenthesized so it is always appended to the module name; otherwise the
        # conditional expression swallows `module+"GRPC"` and yields a bare "SHM".
        _module2 = module + ("GRPC" if options.shm else "SHM")
        setattr(process, _module2,
            processModule.clone(
                Client = dict(useSharedMemory = not options.shm)
            )
        )
        processModule2 = getattr(process, _module2)
        process.p += processModule2
        keepMsgs.extend([_module2, _module2 + ':TritonClient'])

process.load('FWCore/MessageService/MessageLogger_cfi')
_cerr = process.MessageLogger.cerr
_cerr.FwkReport.reportEvery = 500
# Raise the per-category message limit for every collected category.
for _category in keepMsgs:
    setattr(_cerr, _category, cms.untracked.PSet(limit = cms.untracked.int32(10000000)))

# Threading/stream settings are only applied when a positive thread count is given.
if options.threads > 0:
    process.options.numberOfThreads = options.threads
    process.options.numberOfStreams = options.streams