Back to home page

Project CMSSW displayed by LXR

 
 

    


Warning, /HeterogeneousCore/SonicTriton/scripts/cmsTritonConfigTool is written in an unsupported language. File is not indexed.

0001 #!/usr/bin/env python3
0002 
0003 import os, sys, json, pathlib, shutil
0004 from collections import OrderedDict
0005 from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, RawTextHelpFormatter, RawDescriptionHelpFormatter, Action, Namespace
0006 from enum import Enum
0007 from google.protobuf import text_format, json_format, message, descriptor
0008 from google.protobuf.internal import type_checkers
0009 from tritonclient import grpc
0010 
0011 # convenience definition
0012 # (from ConfigArgParse)
class ArgumentDefaultsRawHelpFormatter(
        ArgumentDefaultsHelpFormatter,
        RawTextHelpFormatter,
        RawDescriptionHelpFormatter):
    """Help formatter combining default-value display with unwrapped text.

    Purely a mix-in combination of the three argparse formatters listed as
    bases; it adds no behavior of its own.
    """
0019 
class DictAction(Action):
    """argparse action that turns a flat ``key value key value ...`` list into a dict.

    ``val_type`` is the callable applied to every value (keys are kept as
    parsed). While it is None, the action's own ``type`` is used and cached
    on the instance the first time the action fires.
    """
    val_type = None

    def __call__(self, parser, namespace, values, option_string=None):
        # lazily fall back to the argparse-level type converter
        if self.val_type is None:
            self.val_type = self.type
        convert = self.val_type
        # an odd number of tokens cannot be split into key/value pairs
        if len(values) % 2 != 0:
            parser.error("{} args must come in pairs".format(self.dest))
        keys = values[0::2]
        raw_vals = values[1::2]
        pairs = {key: convert(raw) for key, raw in zip(keys, raw_vals)}
        setattr(namespace, self.dest, pairs)
0031 
class TritonChecksumStatus(Enum):
    """Result codes that cfg_checksum returns when args.should_return is set."""
    # reached the end of the check without an early MISSING/INCORRECT return
    CORRECT = 0
    # at least one file has no stored checksum and --update was not requested
    MISSING = 1
    # at least one stored checksum differs and --update --force was not requested
    INCORRECT = 2
0036 
# lookup from protobuf message class name to the class itself, built from the
# direct subclasses of message.Message that exist at the time this line runs
message_classes = {cls.__name__ : cls for cls in message.Message.__subclasses__()}

_FieldDescriptor = descriptor.FieldDescriptor
# Python type used to represent each protobuf C++ scalar type
cpp_to_python = {
    _FieldDescriptor.CPPTYPE_INT32: int,
    _FieldDescriptor.CPPTYPE_INT64: int,
    _FieldDescriptor.CPPTYPE_UINT32: int,
    _FieldDescriptor.CPPTYPE_UINT64: int,
    _FieldDescriptor.CPPTYPE_DOUBLE: float,
    _FieldDescriptor.CPPTYPE_FLOAT: float,
    _FieldDescriptor.CPPTYPE_BOOL: bool,
    _FieldDescriptor.CPPTYPE_STRING: str,
}
# lookup from type-checker class to Python type, derived from the
# underscore-prefixed _VALUE_CHECKERS table in protobuf's internal module
# NOTE(review): presumably private API that may change between protobuf versions
checker_to_type = {val.__class__:cpp_to_python[key] for key,val in type_checkers._VALUE_CHECKERS.items()}
# for some reason, this one is not in the map
checker_to_type[type_checkers.UnicodeValueChecker] = str
0053 
# lookup from instance-group kind name (e.g. 'KIND_CPU', as used by
# cfg_threadcontrol) to its numeric enum value
kind_to_int = {v.name:v.number for v in grpc.model_config_pb2._MODELINSTANCEGROUP_KIND.values}
# model parameters that cfg_threadcontrol sets to "1"; each key is matched as a
# substring of the model's platform string
thread_control_parameters = {
    "onnx": ["intra_op_thread_count", "inter_op_thread_count"],
    "tensorflow": ["TF_NUM_INTRA_THREADS", "TF_NUM_INTER_THREADS", "TF_USE_PER_SESSION_THREADS"],
}
0059 
def get_type(obj):
    """Describe the type of a protobuf field value.

    Returns a dict with:
      "class":       the object's class
      "type":        the class name, followed by "<entry type>" for containers
      "entry_class": the element/value class for the recognised protobuf
                     containers, otherwise None
      "entry_type":  the name of that entry class, otherwise None
    """
    cls = obj.__class__
    kind = cls.__name__
    entry_class = None
    entry_type = None

    if kind in ("RepeatedCompositeFieldContainer", "MessageMap"):
        # message-valued containers: resolve the entry class by message name
        entry_type = obj._message_descriptor.name
        entry_class = message_classes[entry_type]
    elif kind=="RepeatedScalarFieldContainer":
        # scalar lists: infer the Python type from the container's type checker
        entry_class = checker_to_type[obj._type_checker.__class__]
        entry_type = entry_class.__name__
    elif kind=="ScalarMap":
        # scalar maps: instantiate an entry and inspect its default value
        entry_class = obj.GetEntryClass()().value.__class__
        entry_type = entry_class.__name__

    if entry_type is None:
        label = kind
    else:
        label = kind+"<"+entry_type+">"

    return {
        "class": cls,
        "type": label,
        "entry_class": entry_class,
        "entry_type": entry_type,
    }
0079 
def get_fields(obj, name, level=0, verbose=False):
    """Recursively collect type information for obj and all of its subfields.

    Returns the get_type() dict for obj, extended with "name" and a "fields"
    list holding the same structure for every field in the message descriptor.
    For containers whose entry class is a message, the fields of a freshly
    constructed entry are described instead. With verbose=True, each visited
    field is printed, indented by four spaces per nesting level.
    """
    info = {"name": name, "fields": []}
    info.update(get_type(obj))
    if verbose:
        print('    '*level+info["type"],name)

    # pick the object whose descriptor lists the subfields, if there is one
    entry_class = info["entry_class"]
    if hasattr(obj, "DESCRIPTOR"):
        described = obj
    elif entry_class is not None and hasattr(entry_class, "DESCRIPTOR"):
        described = entry_class()
    else:
        described = None

    if described is not None:
        subfield_names = [f.name for f in described.DESCRIPTOR.fields]
        for subfield in subfield_names:
            child = get_fields(getattr(described,subfield),subfield,level+1,verbose)
            info["fields"].append(child)
    return info
0096 
def get_model_info():
    """Return the get_fields() description of a default-constructed ModelConfig."""
    empty_config = grpc.model_config_pb2.ModelConfig()
    return get_fields(empty_config, "ModelConfig")
0099 
def msg_json(val, defaults=False):
    """Render a protobuf message as JSON on a single line.

    Args:
        val: protobuf message to serialize.
        defaults: if True, ask protobuf to also print fields holding default values.

    Returns:
        The JSON text with original proto field names and all newlines removed
        (", " is used between entries).
    """
    import inspect
    # The keyword controlling default-value output differs between protobuf
    # releases: older ones accept including_default_value_fields, newer ones
    # only always_print_fields_with_no_presence. Use whichever this
    # installation supports so the call does not raise TypeError.
    # NOTE(review): the newer flag only covers fields without presence, so the
    # output under defaults=True may differ slightly between protobuf versions.
    supported = inspect.signature(json_format.MessageToJson).parameters
    if "including_default_value_fields" in supported:
        defaults_kwarg = "including_default_value_fields"
    else:
        defaults_kwarg = "always_print_fields_with_no_presence"
    text = json_format.MessageToJson(val, preserving_proto_field_name=True, indent=0, **{defaults_kwarg: defaults})
    return text.replace(",\n",", ").replace("\n","")
0102 
def print_fields(obj, info, level=0, json=False, defaults=False):
    """Print obj and, recursively, its subfields, guided by its get_fields() info.

    Args:
        obj: the value to print (message, container or plain value).
        info: the get_fields() description matching obj.
        level: nesting depth; every level indents by four spaces.
        json: print nested messages as one-line JSON instead of expanding them
            (the top-level message, level 0, is always expanded).
        defaults: include fields that still hold their default value.
    """
    indent = '    '
    kind = info["type"]

    def print_subfields(msg, sublevel):
        # all declared fields, or only those reported by ListFields()
        if defaults:
            descriptors = msg.DESCRIPTOR.fields
        else:
            descriptors = [pair[0] for pair in msg.ListFields()]
        for fd in descriptors:
            print_fields(
                getattr(msg,fd.name),
                next(f for f in info["fields"] if f["name"]==fd.name),
                level=sublevel, json=json, defaults=defaults,
            )

    print(indent*level+kind,info["name"])
    inner = indent*(level+1)

    if hasattr(obj, "DESCRIPTOR"):
        # a single message
        if json and level>0:
            print(inner+msg_json(obj, defaults))
        else:
            print_subfields(obj,level+1)
        return

    if kind.startswith("RepeatedCompositeFieldContainer"):
        # list of messages
        if json:
            print(inner+str([msg_json(val, defaults) for val in obj]))
            return
        for position,entry in enumerate(obj):
            print(inner+"{}: ".format(position))
            print_subfields(entry,level+2)
        return

    if kind.startswith("MessageMap"):
        # map with message values
        if json:
            print(inner+str({key:msg_json(val, defaults) for key,val in obj.items()}))
            return
        for key,val in obj.items():
            print(inner+"{}: ".format(key))
            print_subfields(val,level+2)
        return

    # plain values and scalar containers
    print(inner+str(obj))
0132 
def edit_builtin(model,dest,val):
    """Editor for plain-valued fields: assign val to attribute dest of model."""
    setattr(model, dest, val)
0135 
def edit_scalar_list(model,dest,val):
    """Editor for scalar list fields: empty the container in place, then append every item of val."""
    container = getattr(model, dest)
    container.clear()
    container.extend(val)
0140 
def edit_scalar_map(model,dest,val):
    """Editor for scalar map fields: empty the container in place, then copy in every entry of val."""
    container = getattr(model, dest)
    container.clear()
    container.update(val)
0145 
def edit_msg(model,dest,val):
    """Editor for single message fields: parse the dict val into the existing message.

    The target message is not cleared beforehand.
    """
    target = getattr(model, dest)
    json_format.ParseDict(val, target)
0149 
def edit_msg_list(model,dest,val):
    """Editor for repeated message fields.

    Empties the container stored in attribute dest of model, then adds one new
    message per element of val, filling each from its dict via
    json_format.ParseDict.

    Args:
        model: object holding the repeated field.
        dest: attribute name of the repeated field.
        val: iterable of dicts, one per message to create.
    """
    item = getattr(model,dest)
    item.clear()
    # iterate the 'val' parameter (the previous code referenced an undefined 'vals')
    for entry in val:
        m = item.add()
        json_format.ParseDict(entry,m)
0156 
def edit_msg_map(model,dest,val):
    """Editor for map fields with message values.

    Empties the map stored in attribute dest of model, then creates one entry
    per key of val, filling each message from its dict via
    json_format.ParseDict.

    Args:
        model: object holding the map field.
        dest: attribute name of the map field.
        val: dict from map key to the dict describing that entry's message.
    """
    item = getattr(model,dest)
    item.clear()
    # iterate the 'val' parameter (the previous code referenced an undefined 'vals')
    for k,v in val.items():
        m = item.get_or_create(k)
        json_format.ParseDict(v,m)
0163 
def add_edit_args(parser, model_info):
    """Add one command-line option per top-level ModelConfig field.

    Each field in model_info["fields"] becomes "--field-name" (underscores
    replaced by dashes) in a "fields" argument group; how the option is parsed
    and which edit_* function later applies it depends on the field's kind.

    Args:
        parser: the ArgumentParser to extend.
        model_info: get_fields() description of ModelConfig.

    Returns:
        (parser, dests), where dests maps each argparse destination name to
        the editor function called as editor(model, dest, value).
    """
    group = parser.add_argument_group("fields", description="ModelConfig fields to edit")
    dests = {}
    for field in model_info["fields"]:
        argname = "--{}".format(field["name"].replace("_","-"))
        val_type = None
        editor = None
        if field["class"].__module__=="builtins":
            # plain value: the field's own class converts the string
            # NOTE(review): for a bool field, bool("false") is True, so any
            # non-empty string enables it - confirm this is intended
            kwargs = dict(type=field["class"])
            editor = edit_builtin
        elif field["type"].startswith("RepeatedScalarFieldContainer"):
            kwargs = dict(type=field["entry_class"], nargs='*')
            editor = edit_scalar_list
        elif field["type"].startswith("ScalarMap"):
            kwargs = dict(type=str, nargs='*', metavar="key value", action=DictAction)
            val_type = field["entry_class"]
            editor = edit_scalar_map
        elif field["type"].startswith("RepeatedCompositeFieldContainer"):
            kwargs = dict(type=json.loads, nargs='*',
                help="provide {} values in json format".format(field["entry_type"])
            )
            editor = edit_msg_list
        elif field["type"].startswith("MessageMap"):
            kwargs = dict(type=str, nargs='*', metavar="key value", action=DictAction,
                help="provide {} values in json format".format(field["entry_type"])
            )
            editor = edit_msg_map
            val_type = json.loads
        else:
            # single message field
            kwargs = dict(type=json.loads,
                help="provide {} values in json format".format(field["type"])
            )
            # must bind 'editor' (the previous code bound an unused name 'edit',
            # leaving these fields registered with a None editor)
            editor = edit_msg
        action = group.add_argument(argname, **kwargs)
        # DictAction converts map values with val_type once it is set
        if val_type is not None: action.val_type = val_type
        dests[action.dest] = editor
    return parser, dests
0201 
def get_checksum(filename, chunksize=4096):
    """Return the hexadecimal MD5 digest of a file, read chunksize bytes at a time."""
    import hashlib
    with open(filename, 'rb') as stream:
        digest = hashlib.md5()
        # iter() with a sentinel stops at the first empty read (end of file)
        for block in iter(lambda: stream.read(chunksize), b""):
            digest.update(block)
    return digest.hexdigest()
0209 
def get_checksum_update_cmd(force=False):
    """Build the command line a user should re-run to update checksums.

    Returns the current sys.argv joined by spaces, followed by a space and
    whichever of "--update" (always wanted) and "--force" (wanted if force is
    True) are not already present in sys.argv.
    """
    wanted = ["--update", "--force"] if force else ["--update"]
    to_add = [flag for flag in wanted if flag not in sys.argv]
    current = " ".join(sys.argv)
    return current + " " + " ".join(to_add)
0215 
def update_config(args):
    """Write args.model to disk in protobuf text format.

    If args.copy is truthy, args.config is first redirected to the output
    location: "config.pbtxt" inside the directory args.copy when it is a
    string, otherwise "config.pbtxt" relative to the current directory.
    Without args.copy, the file at args.config is overwritten.
    """
    # update config path to be output path (in case view is called)
    if args.copy:
        target = "config.pbtxt"
        if isinstance(args.copy,str):
            target = os.path.join(args.copy, target)
        args.config = target

    with open(args.config,'w') as outfile:
        text_format.PrintMessage(args.model, outfile, use_short_repeated_primitives=True)
0225 
def cfg_common(args):
    """Prepare args for any command.

    Ensures args.model_info exists, stores a new ModelConfig in args.model,
    and, when args has a 'config' attribute, parses that text-format file
    into the model.
    """
    if not hasattr(args,'model_info'):
        args.model_info = get_model_info()

    model = grpc.model_config_pb2.ModelConfig()
    args.model = model

    # commands without a config attribute keep the empty model
    if not hasattr(args,'config'):
        return
    with open(args.config,'r') as infile:
        text_format.Parse(infile.read(), model)
0233 
def cfg_schema(args):
    """'schema' command: print every ModelConfig field with its type information."""
    get_fields(args.model, "ModelConfig", level=0, verbose=True)
0236 
def cfg_view(args):
    """'view' command: print a header naming args.config, then the contents of args.model."""
    header = "Contents of {}".format(args.config)
    print(header)
    print_fields(args.model, args.model_info, json=args.json, defaults=args.defaults)
0240 
def cfg_edit(args):
    """'edit' command: apply every field option given on the command line, then write the config.

    Options left at None are skipped. With args.view set, the resulting file
    is printed afterwards.
    """
    # collect all requested edits first, then apply them
    pending = []
    for dest,editor in args.edit_dests.items():
        if getattr(args,dest) is not None:
            pending.append((dest,editor,getattr(args,dest)))
    for dest,editor,val in pending:
        editor(args.model,dest,val)

    update_config(args)

    if args.view:
        cfg_view(args)
0249 
def cfg_checksum(args):
    """'checksum' command: compare model file MD5 sums with those stored in the config.

    Checksums live in the parameters of the model repository agent named
    "checksum" (added to the model if absent), under keys
    "MD5:<path relative to the config directory>". Every path matching "*/*"
    below the config directory is checked.

    Missing checksums are stored when args.update is set; incorrect ones are
    only overwritten when both args.update and args.force are set. The config
    file is rewritten when such an update was allowed and needed.

    Returns:
        A TritonChecksumStatus if args.should_return is true, otherwise None.

    Raises:
        RuntimeError: if args.should_return is false and checksums are
            incorrect (without update+force) or missing (without update).
    """
    from glob import glob

    # internal parameter
    if not hasattr(args, "should_return"):
        args.should_return = False

    # find the checksum agent, creating it if the config has none
    agents = args.model.model_repository_agents.agents
    checksum_agent = None
    for agent in agents:
        if agent.name=="checksum":
            checksum_agent = agent
            break
    if checksum_agent is None:
        checksum_agent = agents.add(name="checksum")

    incorrect = []
    missing = []

    # evaluate symbolic links
    config_dir = os.path.realpath(os.path.dirname(args.config))
    pattern = os.path.join(config_dir,"*/*")
    for match in glob(pattern):
        # evaluate symbolic links again
        resolved = os.path.realpath(match)
        checksum = get_checksum(resolved)
        # key = algorithm:[filename relative to config.pbtxt dir]
        relname = os.path.relpath(resolved, config_dir)
        filekey = "MD5:{}".format(relname)
        # test membership before indexing so the lookup never touches an absent key
        if filekey not in checksum_agent.parameters:
            missing.append(relname)
            if args.update:
                checksum_agent.parameters[filekey] = checksum
        elif checksum_agent.parameters[filekey]!=checksum:
            incorrect.append(relname)
            if args.update and args.force:
                checksum_agent.parameters[filekey] = checksum

    if not args.quiet:
        if missing:
            print("\n".join(["Missing checksums:"]+missing))
        if incorrect:
            print("\n".join(["Incorrect checksums:"]+incorrect))

    # incorrect checksums take precedence over missing ones
    if incorrect:
        if args.update and args.force:
            update_config(args)
        elif args.should_return:
            return TritonChecksumStatus.INCORRECT
        else:
            raise RuntimeError("\n".join([
                "Incorrect checksum(s) found, indicating existing model file(s) has been changed, which violates policy.",
                "To override, run the following command (and provide a justification in your PR):",
                get_checksum_update_cmd(force=True)
            ]))
    elif missing:
        if args.update:
            update_config(args)
        elif args.should_return:
            return TritonChecksumStatus.MISSING
        else:
            raise RuntimeError("\n".join([
                "Missing checksum(s) found, indicating new model file(s).",
                "To update, run the following command:",
                get_checksum_update_cmd(force=False)
            ]))

    if args.view:
        cfg_view(args)

    if args.should_return:
        return TritonChecksumStatus.CORRECT
0323 
def cfg_versioncheck(args):
    """'versioncheck' command: run the checksum check on every config.pbtxt in CMSSW_SEARCH_PATH.

    Each colon-separated entry of the CMSSW_SEARCH_PATH environment variable
    is walked (following links); every file named config.pbtxt is loaded and
    passed to cfg_checksum in quiet, status-returning mode with this
    command's update/force settings.

    Raises:
        RuntimeError: listing all configs with missing and/or incorrect
            checksums, plus the commands to fix them.
    """
    incorrect = []
    missing = []

    search_paths = os.environ['CMSSW_SEARCH_PATH'].split(':')
    for path in search_paths:
        if args.verbose: print("Checking: "+path)
        for dirpath, _dirnames, filenames in os.walk(path, followlinks=True):
            # only directories that hold a model config are of interest
            if "config.pbtxt" not in filenames:
                continue
            filepath = os.path.join(dirpath,"config.pbtxt")
            if args.verbose: print(filepath)
            checksum_args = Namespace(
                config=filepath, should_return=True,
                copy=False, json=False, defaults=False, view=False,
                update=args.update, force=args.force, quiet=True
            )
            cfg_common(checksum_args)
            status = cfg_checksum(checksum_args)
            if status==TritonChecksumStatus.INCORRECT:
                incorrect.append(filepath)
            elif status==TritonChecksumStatus.MISSING:
                missing.append(filepath)

    # problems first, then the instructions for resolving them
    msg = []
    instr = []
    if missing:
        msg += ["","The following files have missing checksum(s), indicating new model file(s):"]+missing
        instr += ["","To update missing checksums, run the following command:",get_checksum_update_cmd(force=False)]
    if incorrect:
        msg += ["","The following files have incorrect checksum(s), indicating existing model file(s) have been changed, which violates policy:"]+incorrect
        instr += ["","To override incorrect checksums, run the following command (and provide a justification in your PR):",get_checksum_update_cmd(force=True)]

    if msg:
        raise RuntimeError("\n".join(msg+instr))
0358 
def cfg_threadcontrol(args):
    """'threadcontrol' command: copy a model and restrict it to a fixed number of CPU instances.

    The directory containing args.config is copied into args.copy under the
    same directory name. For an "ensemble" platform, every model named in the
    ensemble steps is processed recursively (looked up next to the ensemble's
    directory) and the ensemble's own config copy is left unmodified. For any
    other platform, a KIND_CPU instance group with args.nThreads instances is
    added, the parameters listed in thread_control_parameters for the first
    matching platform key are set to "1", and the copied config.pbtxt is
    rewritten.

    Side effects on args: args.config and args.copy are overwritten, and for
    ensembles args.model is reloaded for each step.
    """
    # copy the entire model, not just config.pbtxt
    # abspath first: a bare "config.pbtxt" has an empty dirname and "../config.pbtxt"
    # has dirname "..", neither of which yields the model directory's name
    # (symlinks are deliberately not resolved, so the name used on disk is kept)
    config_dir = os.path.dirname(os.path.abspath(args.config))
    copy_dir = args.copy
    new_config_dir = os.path.join(copy_dir, pathlib.Path(config_dir).name)
    shutil.copytree(config_dir, new_config_dir)

    platform = args.model.platform
    if platform=="ensemble":
        repo_dir = pathlib.Path(config_dir).parent
        # snapshot the names first: cfg_common replaces args.model inside the loop
        step_models = [step.model_name for step in args.model.ensemble_scheduling.step]
        for model_name in step_models:
            # update args and run recursively
            args.config = os.path.join(repo_dir,model_name,"config.pbtxt")
            args.copy = copy_dir
            cfg_common(args)
            cfg_threadcontrol(args)
        return

    # is it correct to do this even if found_params is false at the end?
    args.model.instance_group.add(count=args.nThreads, kind=kind_to_int['KIND_CPU'])

    found_params = False
    for key,val in thread_control_parameters.items():
        if key in platform: # partial matching
            for param in val:
                item = args.model.parameters.get_or_create(param)
                item.string_value = "1"
            found_params = True
            break
    if not found_params:
        print("Warning: thread (instance) control not implemented for {}".format(platform))

    # point update_config at the copied model directory
    args.copy = new_config_dir
    update_config(args)

    if args.view:
        cfg_view(args)
0396 
if __name__=="__main__":
    # initial common operations
    model_info = get_model_info()
    edit_dests = None

    # parent parser: options for commands that operate on one config file
    _parser_common = ArgumentParser(add_help=False)
    _parser_common.add_argument("-c", "--config", type=str, default="", required=True, help="path to input config.pbtxt file")

    parser = ArgumentParser(formatter_class=ArgumentDefaultsRawHelpFormatter)
    # required=True: without a command there is no 'func' default to call, so
    # let argparse report a usage error instead of failing later on args.func
    subparsers = parser.add_subparsers(dest="command", required=True)

    parser_schema = subparsers.add_parser("schema", help="view ModelConfig schema",
        description="""Display all fields in the ModelConfig object, with type information.
    (For collection types, the subfields of the entry type are shown.)""",
    )
    parser_schema.set_defaults(func=cfg_schema)

    # parent parser: display options used by cfg_view
    _parser_view_args = ArgumentParser(add_help=False)
    _parser_view_args.add_argument("--json", default=False, action="store_true", help="display in json format")
    _parser_view_args.add_argument("--defaults", default=False, action="store_true", help="show fields with default values")

    parser_view = subparsers.add_parser("view", parents=[_parser_common, _parser_view_args], help="view config.pbtxt contents")
    parser_view.set_defaults(func=cfg_view)

    # parent parser: --view, shared by both variants of --copy below
    _parser_copy_view = ArgumentParser(add_help=False)
    _parser_copy_view.add_argument("--view", default=False, action="store_true", help="view file after editing")

    # parent parser: optional --copy (False if absent, True if given without a directory)
    _parser_copy = ArgumentParser(add_help=False, parents=[_parser_copy_view])
    _parser_copy.add_argument("--copy", metavar="dir", default=False, const=True, nargs='?', type=str,
        help="make a copy of config.pbtxt instead of editing in place ([dir] = output path for copy; if omitted, current directory is used)"
    )

    parser_edit = subparsers.add_parser("edit", parents=[_parser_common, _parser_copy, _parser_view_args], help="edit config.pbtxt contents")
    parser_edit, edit_dests = add_edit_args(parser_edit, model_info)
    parser_edit.set_defaults(func=cfg_edit)

    # parent parser: checksum update switches
    _parser_checksum_update = ArgumentParser(add_help=False)
    _parser_checksum_update.add_argument("--update", default=False, action="store_true", help="update missing checksums")
    _parser_checksum_update.add_argument("--force", default=False, action="store_true", help="force update all checksums")

    parser_checksum = subparsers.add_parser("checksum", parents=[_parser_common, _parser_copy, _parser_view_args, _parser_checksum_update], help="handle model file checksums")
    parser_checksum.add_argument("--quiet", default=False, action="store_true", help="suppress printouts")
    parser_checksum.set_defaults(func=cfg_checksum)

    parser_versioncheck = subparsers.add_parser("versioncheck", parents=[_parser_checksum_update], help="check all model checksums")
    parser_versioncheck.add_argument("--verbose", default=False, action="store_true", help="verbose output (show all files checked)")
    parser_versioncheck.set_defaults(func=cfg_versioncheck)

    # parent parser: mandatory --copy with a directory, for threadcontrol
    _parser_copy_req = ArgumentParser(add_help=False, parents=[_parser_copy_view])
    _parser_copy_req.add_argument("--copy", metavar="dir", type=str, required=True,
        help="local model repository directory to copy model(s)"
    )

    parser_threadcontrol = subparsers.add_parser("threadcontrol", parents=[_parser_common, _parser_copy_req, _parser_view_args], help="enable thread controls")
    parser_threadcontrol.add_argument("--nThreads", type=int, required=True, help="number of threads")
    parser_threadcontrol.set_defaults(func=cfg_threadcontrol)

    args = parser.parse_args()
    # reuse the schema computed above so cfg_common does not rebuild it
    args.model_info = model_info
    if edit_dests is not None:
        args.edit_dests = edit_dests

    cfg_common(args)

    args.func(args)