Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:15:48

0001 import FWCore.ParameterSet.Config as cms
0002 import os, sys, json
0003 from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
0004 
# Supported models for each module type. The keys (in this order) are also
# the valid choices for the --modules option.
models = {"TritonImageProducer": ["inception_graphdef", "densenet_onnx"]}
# all graph modules are tested with the same model; each key gets its own list
models.update(
    {graphModule: ["gat_test"] for graphModule in ("TritonGraphProducer", "TritonGraphFilter", "TritonGraphAnalyzer")}
)
models["TritonIdentityProducer"] = ["ragged_io"]

# other choices exposed on the command line
allowed_modes = "Async PseudoAsync Sync".split()
allowed_compression = "none deflate gzip".split()
allowed_devices = "auto cpu gpu".split()
0018 
# Command-line options. ArgumentDefaultsHelpFormatter appends each option's
# default value to its --help text.
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
add_arg = parser.add_argument


def _add_switch(flag, text):
    """Register an on/off command-line switch that is False unless given."""
    add_arg(flag, default=False, action="store_true", help=text)


# job size and server connection
add_arg("--maxEvents", type=int, default=-1, help="Number of events to process (-1 for all)")
add_arg("--serverName", type=str, default="default", help="name for server (used internally)")
add_arg("--address", type=str, default="", help="server address")
add_arg("--port", type=int, default=8001, help="server port")
add_arg("--timeout", type=int, default=30, help="timeout for requests")
add_arg("--timeoutUnit", type=str, default="seconds", help="unit for timeout")
add_arg("--params", type=str, default="", help="json file containing server address/port")
add_arg("--threads", type=int, default=1, help="number of threads")
add_arg("--streams", type=int, default=0, help="number of streams")
# modules to run and the model used by each of them
add_arg("--modules", metavar="MODULES", nargs="+", type=str, default=["TritonGraphProducer"],
        choices=list(models), help="list of modules to run (choices: %(choices)s)")
add_arg("--models", nargs="+", type=str, default=["gat_test"],
        help="list of models (same length as modules, or just 1 entry if all modules use same model)")
# client behaviour and output verbosity
add_arg("--mode", type=str, default="Async", choices=allowed_modes, help="mode for client")
_add_switch("--verbose", "enable all verbose output")
_add_switch("--verboseClient", "enable verbose output for clients")
_add_switch("--verboseServer", "enable verbose output for server")
_add_switch("--verboseService", "enable verbose output for TritonService")
_add_switch("--brief", "briefer output for graph modules")
# fallback server, testing and transport options
add_arg("--fallbackName", type=str, default="", help="name for fallback server")
_add_switch("--unittest", "unit test mode: reduce input sizes")
_add_switch("--testother", "also test gRPC communication if shared memory enabled, or vice versa")
_add_switch("--noShm", "disable shared memory")
# NOTE: the "" default is not one of the listed choices; it is passed on to the client config unchanged
add_arg("--compression", type=str, default="", choices=allowed_compression, help="enable I/O compression")
_add_switch("--ssl", "enable SSL authentication for server communication")
# device names are lower-cased before being checked against the choices
add_arg("--device", type=str.lower, default="auto", choices=allowed_devices, help="specify device for fallback server")
_add_switch("--docker", "use Docker for fallback server")
add_arg("--tries", type=int, default=0, help="number of retries for failed request")

options = parser.parse_args()
0047 
# Optionally take the server address and port from a json file (--params),
# overriding --address/--port.
if options.params:
    with open(options.params, 'r') as paramsFile:
        serverParams = json.load(paramsFile)
    options.address = serverParams["address"]
    options.port = int(serverParams["port"])
    print("server = " + options.address + ":" + str(options.port))
0054 
# check models and modules:
# --models must provide either one model per entry in --modules, or a single
# model that is then used for every module
if len(options.modules) != len(options.models):
    if len(options.models) == 1:
        # builds a new list, so the parser's default list is not modified
        options.models = options.models * len(options.modules)
    else:
        raise ValueError(
            "Arguments for modules and models must have same length "
            "(got {} modules, {} models)".format(len(options.modules), len(options.models))
        )
# every requested model must be one that its module supports
for module, model in zip(options.modules, options.models):
    if model not in models[module]:
        raise ValueError(
            "Unsupported model {} for module {} (supported: {})".format(
                model, module, ", ".join(models[module])
            )
        )
0064 
# Create the process (with the SonicTriton modifier enabled) and configure
# the TritonService from the command-line options.
from Configuration.ProcessModifiers.enableSonicTriton_cff import enableSonicTriton
process = cms.Process('tritonTest', enableSonicTriton)

process.load("HeterogeneousCore.SonicTriton.TritonService_cff")

process.maxEvents = cms.untracked.PSet(input = cms.untracked.int32(options.maxEvents))

process.source = cms.Source("EmptySource")

tritonService = process.TritonService
fallbackServer = tritonService.fallback
tritonService.verbose = options.verbose or options.verboseService
fallbackServer.verbose = options.verbose or options.verboseServer
fallbackServer.useDocker = options.docker
if options.fallbackName:
    fallbackServer.instanceBaseName = options.fallbackName
if options.device != "auto":
    # an explicit device choice overrides the fallback server's useGPU setting
    fallbackServer.useGPU = options.device == "gpu"
if options.address:
    # register the explicitly given server; certificate/key fields are left empty
    tritonService.servers.append(
        cms.PSet(
            name = cms.untracked.string(options.serverName),
            address = cms.untracked.string(options.address),
            port = cms.untracked.uint32(options.port),
            useSsl = cms.untracked.bool(options.ssl),
            rootCertificates = cms.untracked.string(""),
            privateKey = cms.untracked.string(""),
            certificateChain = cms.untracked.string(""),
        )
    )
0093 
# Let it run: start from an empty path; the requested modules are added to it afterwards.
process.p = cms.Path()

# framework module type, keyed by the suffix that appears in the module name
modules = dict(
    Producer=cms.EDProducer,
    Filter=cms.EDFilter,
    Analyzer=cms.EDAnalyzer,
)

# message categories that get a raised MessageLogger limit (extended per module)
keepMsgs = ["TritonClient", "TritonService"]
0104 
# Instantiate each requested module with its Triton client configuration,
# schedule it on the path, and keep its log messages. The earlier
# models/modules check guarantees one model per module.
for module, model in zip(options.modules, options.models):
    # framework module type chosen by the suffix in the module name
    # (e.g. "TritonGraphFilter" -> cms.EDFilter)
    Module = [obj for name, obj in modules.items() if name in module][0]
    setattr(process, module,
        Module(module,
            Client = cms.PSet(
                mode = cms.string(options.mode),
                preferredServer = cms.untracked.string(""),
                timeout = cms.untracked.uint32(options.timeout),
                timeoutUnit = cms.untracked.string(options.timeoutUnit),
                modelName = cms.string(model),
                modelVersion = cms.string(""),
                modelConfigPath = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/config.pbtxt".format(model)),
                verbose = cms.untracked.bool(options.verbose or options.verboseClient),
                allowedTries = cms.untracked.uint32(options.tries),
                useSharedMemory = cms.untracked.bool(not options.noShm),
                compression = cms.untracked.string(options.compression),
            )
        )
    )
    processModule = getattr(process, module)
    if module == "TritonImageProducer":
        processModule.batchSize = cms.int32(1)
        processModule.topN = cms.uint32(5)
        # label file is named after the model prefix, e.g. inception_graphdef -> inception_labels.txt
        processModule.imageList = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/{}_labels.txt".format(model, model.split('_')[0]))
    elif "TritonGraph" in module:
        if options.unittest:
            # reduce input size for unit test
            processModule.nodeMin = cms.uint32(1)
            processModule.nodeMax = cms.uint32(10)
            processModule.edgeMin = cms.uint32(20)
            processModule.edgeMax = cms.uint32(40)
        else:
            processModule.nodeMin = cms.uint32(100)
            processModule.nodeMax = cms.uint32(4000)
            processModule.edgeMin = cms.uint32(8000)
            processModule.edgeMax = cms.uint32(15000)
        processModule.brief = cms.bool(options.brief)
    process.p += processModule
    keepMsgs.extend([module, module + ':TritonClient'])
    if options.testother:
        # clone modules to test both gRPC and shared memory: the clone uses the
        # opposite transport, and its label suffix names the clone's transport
        useShm = processModule.Client.useSharedMemory.value()
        # the suffix must be parenthesized; without the parentheses the
        # conditional binds looser than "+", so every shared-memory clone
        # would be labelled just "SHM"
        _module2 = module + ("GRPC" if useShm else "SHM")
        setattr(process, _module2,
            processModule.clone(
                Client = dict(useSharedMemory = not useShm)
            )
        )
        processModule2 = getattr(process, _module2)
        process.p += processModule2
        keepMsgs.extend([_module2, _module2 + ':TritonClient'])
0156 
# Configure the MessageLogger: set FwkReport.reportEvery to 500 and give every
# category collected in keepMsgs a limit of 10000000 messages on cerr.
process.load('FWCore/MessageService/MessageLogger_cfi')
cerrLogger = process.MessageLogger.cerr
cerrLogger.FwkReport.reportEvery = 500
for category in keepMsgs:
    setattr(cerrLogger, category, cms.untracked.PSet(limit = cms.untracked.int32(10000000)))
0165 
# Thread and stream counts are only set when a positive thread count is
# requested; the stream count is passed through as given.
if options.threads >= 1:
    process.options.numberOfThreads, process.options.numberOfStreams = options.threads, options.streams
0169