Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2025-04-11 03:31:17

0001 import FWCore.ParameterSet.Config as cms
0002 
def getDefaultClientPSet():
    """Return the ``Client`` PSet of a default-constructed TritonGraphAnalyzer.

    TritonGraphAnalyzer is instantiated with no arguments, and only its
    ``Client`` attribute is returned.
    """
    from HeterogeneousCore.SonicTriton.TritonGraphAnalyzer import TritonGraphAnalyzer
    analyzer = TritonGraphAnalyzer()
    defaultClient = analyzer.Client
    return defaultClient
0007 
def getParser():
    """Create the ArgumentParser shared by the SonicTriton test configurations.

    The options cover:
      * event and thread counts;
      * the remote server (address, port, SSL, a JSON params file);
      * client request settings (timeout, compression, shared memory, retries);
      * verbosity switches;
      * the fallback server (device, container, name, image, temp directory).

    ``--device`` and ``--container`` values are lower-cased before being
    checked against their allowed choices. Defaults are appended to each
    help line by ArgumentDefaultsHelpFormatter.

    Returns:
        argparse.ArgumentParser: the configured, not yet parsed, parser.
    """
    from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter

    compressionChoices = ["none", "deflate", "gzip"]
    deviceChoices = ["auto", "cpu", "gpu"]
    containerChoices = ["apptainer", "docker", "podman", "podman-hpc"]

    def switch(helpText):
        # on/off flag: False unless given on the command line
        return dict(default=False, action="store_true", help=helpText)

    # (flag, add_argument keywords) in the order the options appear in --help
    specs = [
        ("--maxEvents", dict(default=-1, type=int, help="Number of events to process (-1 for all)")),
        ("--serverName", dict(default="default", type=str, help="name for server (used internally)")),
        ("--address", dict(default="", type=str, help="server address")),
        ("--port", dict(default=8001, type=int, help="server port")),
        ("--timeout", dict(default=30, type=int, help="timeout for requests")),
        ("--timeoutUnit", dict(default="seconds", type=str, help="unit for timeout")),
        ("--params", dict(default="", type=str, help="json file containing server address/port")),
        ("--threads", dict(default=1, type=int, help="number of threads")),
        ("--streams", dict(default=0, type=int, help="number of streams")),
        ("--verbose", switch("enable all verbose output")),
        ("--verboseClient", switch("enable verbose output for clients")),
        ("--verboseServer", switch("enable verbose output for server")),
        ("--verboseService", switch("enable verbose output for TritonService")),
        ("--verboseDiscovery", switch("enable verbose output just for server discovery in TritonService")),
        ("--noShm", switch("disable shared memory")),
        ("--compression", dict(default="", type=str, choices=compressionChoices, help="enable I/O compression")),
        ("--ssl", switch("enable SSL authentication for server communication")),
        ("--tries", dict(default=0, type=int, help="number of retries for failed request")),
        ("--device", dict(default="auto", type=str.lower, choices=deviceChoices, help="specify device for fallback server")),
        ("--container", dict(default="apptainer", type=str.lower, choices=containerChoices, help="specify container for fallback server")),
        ("--fallbackName", dict(default="", type=str, help="name for fallback server")),
        ("--imageName", dict(default="", type=str, help="container image name for fallback server")),
        ("--tempDir", dict(default="", type=str, help="temp directory for fallback server")),
    ]

    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    for flag, keywords in specs:
        parser.add_argument(flag, **keywords)
    return parser
0040 
def getOptions(parser, verbose=False, args=None):
    """Parse command-line options, optionally taking the server from a JSON file.

    Args:
        parser: an argparse.ArgumentParser, e.g. from getParser(). It must
            define ``--params``, ``--address`` and ``--port``.
        verbose: if True, print the server address read from the JSON file.
        args: list of argument strings to parse. None (the default) lets
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        The parsed argparse.Namespace. When ``--params`` names a file, its
        "address" and "port" entries replace ``options.address`` and
        ``options.port``; the port is converted with ``int()``.

    Raises:
        OSError: if the params file cannot be opened.
        json.JSONDecodeError: if the params file is not valid JSON.
        KeyError: if "address" or "port" is missing from the JSON object.
    """
    # json is not imported at module level, so import it here (previously NameError)
    import json

    options = parser.parse_args(args)

    if options.params:
        with open(options.params, 'r') as pfile:
            pdict = json.load(pfile)
        options.address = pdict["address"]
        options.port = int(pdict["port"])
        if verbose:
            print(f"server = {options.address}:{options.port}")

    return options
0052 
def applyOptions(process, options, applyToModules=False):
    """Configure a cms.Process from options produced by getOptions().

    Steps, in order:
      * set ``maxEvents.input`` from ``--maxEvents``;
      * if ``--threads`` > 0, set ``numberOfThreads`` and ``numberOfStreams``;
      * set up logging: every category for ``--verbose``, otherwise only the
        individually requested ones;
      * if the process has a TritonService, copy the fallback-server settings
        onto it and, when ``--address`` is non-empty, append that server to
        ``TritonService.servers``;
      * if ``applyToModules`` is True, apply the client options
        (getClientOptions) to every module that has a ``Client`` PSet.

    Returns:
        The configured process.
    """
    process.maxEvents.input = cms.untracked.int32(options.maxEvents)

    if options.threads > 0:
        # --streams is only applied together with a positive thread count
        process.options.numberOfThreads = options.threads
        process.options.numberOfStreams = options.streams

    if not options.verbose:
        configureLogging(
            process,
            client=options.verboseClient,
            server=options.verboseServer,
            service=options.verboseService,
            discovery=options.verboseDiscovery,
        )
    else:
        configureLoggingAll(process)

    if hasattr(process, 'TritonService'):
        service = process.TritonService
        fallback = service.fallback
        fallback.container = options.container
        fallback.imageName = options.imageName
        fallback.tempDir = options.tempDir
        fallback.device = options.device
        if len(options.fallbackName) > 0:
            fallback.instanceBaseName = options.fallbackName
        if len(options.address) > 0:
            # add the user-specified server after any servers already listed
            service.servers.append(dict(
                name=options.serverName,
                address=options.address,
                port=options.port,
                useSsl=options.ssl,
            ))

    if applyToModules:
        process = configureModules(process, **getClientOptions(options))

    return process
0091 
def getClientOptions(options):
    """Translate parsed options into untracked client parameters.

    Returns:
        dict with keys ``compression``, ``useSharedMemory`` (the negation of
        ``--noShm``), ``timeout``, ``timeoutUnit`` and ``allowedTries`` (from
        ``--tries``). The dict is suitable as keyword arguments for
        configureClient / configureModules.
    """
    untracked = cms.untracked
    return {
        "compression": untracked.string(options.compression),
        "useSharedMemory": untracked.bool(not options.noShm),
        "timeout": untracked.uint32(options.timeout),
        "timeoutUnit": untracked.string(options.timeoutUnit),
        "allowedTries": untracked.uint32(options.tries),
    }
0100 
def applyClientOptions(client, options):
    """Apply the command-line client settings to one client PSet and return it."""
    clientSettings = getClientOptions(options)
    updatedClient = configureClient(client, **clientSettings)
    return updatedClient
0103 
def configureModules(process, modules=None, returnConfigured=False, **kwargs):
    """Apply client parameter overrides to every module that has a ``Client``.

    Args:
        process: the cms.Process. Its producers, filters and analyzers are
            used when ``modules`` is None.
        modules: optional mapping of module label -> module to restrict the
            update to.
        returnConfigured: if True, also return the labels that were updated.
        **kwargs: parameter overrides forwarded to configureClient.

    Modules without a ``Client`` attribute are skipped. The others have their
    ``Client`` replaced by the configureClient result, in place.

    Returns:
        ``process``, or ``(process, labels)`` when ``returnConfigured`` is True.
    """
    if modules is None:
        modules = {}
        for collection in (process.producers_(), process.filters_(), process.analyzers_()):
            modules.update(collection)

    configured = []
    for label, module in modules.items():
        if not hasattr(module, 'Client'):
            continue
        module.Client = configureClient(module.Client, **kwargs)
        configured.append(label)

    return (process, configured) if returnConfigured else process
0119 
def configureClient(client, **kwargs):
    """Overwrite parameters of a client PSet in place.

    Args:
        client: object providing ``update_``, typically a module's ``Client``
            parameter set.
        **kwargs: parameter name -> value pairs, passed as a single dict to
            ``client.update_``.

    Returns:
        The same ``client`` object, so callers can reassign it directly.
    """
    overrides = kwargs
    client.update_(overrides)
    return client
0123 
def configureLogging(process, client=False, server=False, service=False, discovery=False):
    """Turn on verbose Triton output for the selected categories.

    Args:
        process: the cms.Process to modify.
        client: set ``verbose`` on every module ``Client`` (via
            configureModules). Keeps 'TritonClient' plus each configured
            module's '<label>' and '<label>:TritonClient' message categories.
        server: set ``TritonService.fallback.verbose``.
        service: keep 'TritonService' messages.
        discovery: keep 'TritonDiscovery' messages.

    When any flag is set and the process has a TritonService, its ``verbose``
    is set to ``service or discovery`` and ``fallback.verbose`` to ``server``.
    ``MessageLogger.cerr.FwkReport.reportEvery`` is set to 500, and each kept
    category gets ``limit = 10000000`` on ``MessageLogger.cerr``.

    Returns:
        The process, including when no flag is set. In that case the process
        is left untouched.
    """
    if not any((client, server, service, discovery)):
        # Nothing requested. Still return the process so that
        # ``process = configureLogging(process, ...)`` never yields None.
        return process

    keepMsgs = []
    if discovery:
        keepMsgs.append('TritonDiscovery')
    if client:
        keepMsgs.append('TritonClient')
    if service:
        keepMsgs.append('TritonService')

    if hasattr(process, 'TritonService'):
        # discovery output comes from the service, so either flag enables it
        process.TritonService.verbose = service or discovery
        process.TritonService.fallback.verbose = server
    if client:
        process, configured = configureModules(process, returnConfigured=True, verbose=True)
        # keep each module's own category and its '<label>:TritonClient' category
        for module in configured:
            keepMsgs.extend([module, module + ':TritonClient'])

    process.MessageLogger.cerr.FwkReport.reportEvery = 500
    for msg in keepMsgs:
        setattr(process.MessageLogger.cerr, msg, dict(limit=10000000))

    return process
0153 
# dedicated functions for cmsDriver customization:
# each one takes the process and delegates to configureLogging with a fixed selection

def configureLoggingClient(process):
    """Enable verbose output for Triton clients only."""
    return configureLogging(process, client=True, server=False, service=False, discovery=False)

def configureLoggingServer(process):
    """Enable verbose output for the fallback server only."""
    return configureLogging(process, client=False, server=True, service=False, discovery=False)

def configureLoggingService(process):
    """Enable verbose output for TritonService only."""
    return configureLogging(process, client=False, server=False, service=True, discovery=False)

def configureLoggingDiscovery(process):
    """Enable verbose output for TritonService server discovery only."""
    return configureLogging(process, client=False, server=False, service=False, discovery=True)

def configureLoggingAll(process):
    """Enable verbose output for clients, server, service and discovery."""
    # positional order: client, server, service, discovery
    return configureLogging(process, True, True, True, True)