Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-08-06 22:36:41

0001 import FWCore.ParameterSet.Config as cms
0002 import os, sys, json
0003 from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
0004 
# module/model correspondence: each module name maps to the list of models accepted for it
models = {
    "TritonImageProducer": [
        "inception_graphdef",
        "densenet_onnx",
    ],
    "TritonGraphProducer": [
        "gat_test",
    ],
    "TritonGraphFilter": [
        "gat_test",
    ],
    "TritonGraphAnalyzer": [
        "gat_test",
    ],
    "TritonIdentityProducer": [
        "ragged_io",
    ],
}

# permitted values for the --mode, --compression, --device and --container options
allowed_modes = [
    "Async",
    "PseudoAsync",
    "Sync",
]
allowed_compression = [
    "none",
    "deflate",
    "gzip",
]
allowed_devices = [
    "auto",
    "cpu",
    "gpu",
]
allowed_containers = [
    "apptainer",
    "docker",
    "podman",
    "podman-hpc",
]
0019 
# command-line interface; ArgumentDefaultsHelpFormatter appends each default to its help text
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)

# event count and server connection
parser.add_argument("--maxEvents", type=int, default=-1, help="Number of events to process (-1 for all)")
parser.add_argument("--serverName", type=str, default="default", help="name for server (used internally)")
parser.add_argument("--address", type=str, default="", help="server address")
parser.add_argument("--port", type=int, default=8001, help="server port")
parser.add_argument("--timeout", type=int, default=30, help="timeout for requests")
parser.add_argument("--timeoutUnit", type=str, default="seconds", help="unit for timeout")
parser.add_argument("--params", type=str, default="", help="json file containing server address/port")

# framework threading
parser.add_argument("--threads", type=int, default=1, help="number of threads")
parser.add_argument("--streams", type=int, default=0, help="number of streams")

# what to run, and in which client mode
parser.add_argument("--modules", type=str, nargs='+', default=["TritonGraphProducer"], choices=list(models), metavar=("MODULES"), help="list of modules to run (choices: %(choices)s)")
parser.add_argument("--models", type=str, nargs='+', default=["gat_test"], help="list of models (same length as modules, or just 1 entry if all modules use same model)")
parser.add_argument("--mode", type=str, default="Async", choices=allowed_modes, help="mode for client")

# verbosity switches ("store_true" flags are False unless given on the command line)
parser.add_argument("--verbose", action="store_true", help="enable all verbose output")
parser.add_argument("--verboseClient", action="store_true", help="enable verbose output for clients")
parser.add_argument("--verboseServer", action="store_true", help="enable verbose output for server")
parser.add_argument("--verboseService", action="store_true", help="enable verbose output for TritonService")
parser.add_argument("--verboseDiscovery", action="store_true", help="enable verbose output just for server discovery in TritonService")
parser.add_argument("--brief", action="store_true", help="briefer output for graph modules")

# fallback server, testing and transport settings
parser.add_argument("--fallbackName", type=str, default="", help="name for fallback server")
parser.add_argument("--unittest", action="store_true", help="unit test mode: reduce input sizes")
parser.add_argument("--testother", action="store_true", help="also test gRPC communication if shared memory enabled, or vice versa")
parser.add_argument("--noShm", action="store_true", help="disable shared memory")
parser.add_argument("--compression", type=str, default="", choices=allowed_compression, help="enable I/O compression")
parser.add_argument("--ssl", action="store_true", help="enable SSL authentication for server communication")
# str.lower as the type makes the device/container choices case-insensitive
parser.add_argument("--device", type=str.lower, default="auto", choices=allowed_devices, help="specify device for fallback server")
parser.add_argument("--container", type=str.lower, default="apptainer", choices=allowed_containers, help="specify container for fallback server")
parser.add_argument("--tries", type=int, default=0, help="number of retries for failed request")

options = parser.parse_args()
0049 
# a json parameter file, when given, overrides --address and --port
if options.params:
    with open(options.params, 'r') as pfile:
        server_params = json.load(pfile)
    options.address = server_params["address"]
    options.port = int(server_params["port"])
    print("server = " + options.address + ":" + str(options.port))
0056 
# check models and modules
if len(options.models) != len(options.modules):
    if len(options.models) != 1:
        raise ValueError("Arguments for modules and models must have same length")
    # a single model entry is reused for every requested module
    options.models = options.models * len(options.modules)
# every (module, model) pair must appear in the correspondence table
for module, model in zip(options.modules, options.models):
    if model not in models[module]:
        raise ValueError("Unsupported model {} for module {}".format(model, module))
0066 
from Configuration.ProcessModifiers.enableSonicTriton_cff import enableSonicTriton

# the process is created with the enableSonicTriton modifier
process = cms.Process('tritonTest', enableSonicTriton)
process.load("HeterogeneousCore.SonicTriton.TritonService_cff")

process.maxEvents = cms.untracked.PSet(
    input = cms.untracked.int32(options.maxEvents),
)
process.source = cms.Source("EmptySource")

# service-level verbosity is switched on by any of --verbose, --verboseService, --verboseDiscovery
service = process.TritonService
service.verbose = options.verbose or options.verboseService or options.verboseDiscovery

# fallback server settings
fallback = service.fallback
fallback.verbose = options.verbose or options.verboseServer
fallback.container = options.container
fallback.device = options.device
if options.fallbackName:
    fallback.instanceBaseName = options.fallbackName

# a server entry is registered only when an address was supplied
if options.address:
    server = cms.PSet(
        name = cms.untracked.string(options.serverName),
        address = cms.untracked.string(options.address),
        port = cms.untracked.uint32(options.port),
        useSsl = cms.untracked.bool(options.ssl),
        rootCertificates = cms.untracked.string(""),
        privateKey = cms.untracked.string(""),
        certificateChain = cms.untracked.string(""),
    )
    service.servers.append(server)
0094 
# Let it run: start from an empty path; the requested modules are added to it later
process.p = cms.Path()

# cms module classes, looked up by the suffix contained in the module name
modules = {
    "Producer": cms.EDProducer,
    "Filter": cms.EDFilter,
    "Analyzer": cms.EDAnalyzer,
}

# message categories to keep: --verbose enables all of them, the specific flags enable one each
keepMsgs = [
    category
    for category, wanted in (
        ('TritonDiscovery', options.verbose or options.verboseDiscovery),
        ('TritonClient', options.verbose or options.verboseClient),
        ('TritonService', options.verbose or options.verboseService),
    )
    if wanted
]
0111 
# Configure each requested module with its model, add it to the path, and
# optionally add a clone that uses the other communication method.
for im,module in enumerate(options.modules):
    model = options.models[im]
    # pick the cms class whose key ("Producer", "Filter", "Analyzer") is contained in the module name
    Module = [obj for name,obj in modules.items() if name in module][0]
    setattr(process, module,
        Module(module,
            Client = cms.PSet(
                mode = cms.string(options.mode),
                preferredServer = cms.untracked.string(""),
                timeout = cms.untracked.uint32(options.timeout),
                timeoutUnit = cms.untracked.string(options.timeoutUnit),
                modelName = cms.string(model),
                modelVersion = cms.string(""),
                modelConfigPath = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/config.pbtxt".format(model)),
                verbose = cms.untracked.bool(options.verbose or options.verboseClient),
                allowedTries = cms.untracked.uint32(options.tries),
                useSharedMemory = cms.untracked.bool(not options.noShm),
                compression = cms.untracked.string(options.compression),
            )
        )
    )
    processModule = getattr(process, module)
    if module=="TritonImageProducer":
        processModule.batchSize = cms.int32(1)
        processModule.topN = cms.uint32(5)
        # the labels file name uses the part of the model name before the first underscore
        processModule.imageList = cms.FileInPath("HeterogeneousCore/SonicTriton/data/models/{}/{}_labels.txt".format(model,model.split('_')[0]))
    elif "TritonGraph" in module:
        if options.unittest:
            # reduce input size for unit test
            processModule.nodeMin = cms.uint32(1)
            processModule.nodeMax = cms.uint32(10)
            processModule.edgeMin = cms.uint32(20)
            processModule.edgeMax = cms.uint32(40)
        else:
            processModule.nodeMin = cms.uint32(100)
            processModule.nodeMax = cms.uint32(4000)
            processModule.edgeMin = cms.uint32(8000)
            processModule.edgeMax = cms.uint32(15000)
        processModule.brief = cms.bool(options.brief)
    process.p += processModule
    if options.verbose or options.verboseClient:
        keepMsgs.extend([module,module+':TritonClient'])
    if options.testother:
        # clone modules to test both gRPC and shared memory
        # The parentheses matter: a conditional expression binds more loosely than "+",
        # so without them the clone would be named just "SHM" when shared memory is
        # disabled, and clones of different modules would overwrite each other.
        _module2 = module+("GRPC" if processModule.Client.useSharedMemory else "SHM")
        setattr(process, _module2,
            processModule.clone(
                Client = dict(useSharedMemory = not processModule.Client.useSharedMemory)
            )
        )
        processModule2 = getattr(process, _module2)
        process.p += processModule2
        if options.verbose or options.verboseClient:
            keepMsgs.extend([_module2,_module2+':TritonClient'])
0165 
process.load('FWCore/MessageService/MessageLogger_cfi')
cerr = process.MessageLogger.cerr
cerr.FwkReport.reportEvery = 500
# every kept category gets its own PSet with a very large message limit
for category in keepMsgs:
    setattr(cerr, category, cms.untracked.PSet(limit = cms.untracked.int32(10000000)))

# thread/stream counts are applied only for a positive --threads value
if options.threads > 0:
    process.options.numberOfThreads = options.threads
    process.options.numberOfStreams = options.streams
0178