0001 #!/usr/bin/env python3
0003 import os, sys, json, pathlib, shutil
0004 from collections import OrderedDict
0005 from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, RawTextHelpFormatter, RawDescriptionHelpFormatter, Action, Namespace
0006 from enum import Enum
0007 from google.protobuf import text_format, json_format, message, descriptor
0008 from google.protobuf.internal import type_checkers
0009 from tritonclient import grpc
# Convenience help formatter (taken from ConfigArgParse) that merges the
# behavior of three argparse formatters.
class ArgumentDefaultsRawHelpFormatter(
    ArgumentDefaultsHelpFormatter,
    RawTextHelpFormatter,
    RawDescriptionHelpFormatter,
):
    """Help formatter that appends option defaults to help text and keeps the
    help/description text exactly as written (no automatic line-wrapping)."""
class DictAction(Action):
    """argparse action that turns a flat ``key value key value ...`` list into a dict.

    Each value is converted with ``val_type``. If ``val_type`` was not set on the
    action, the action's own argparse ``type`` is used instead. An odd number of
    tokens is reported through ``parser.error``.
    """
    val_type = None

    def __call__(self, parser, namespace, values, option_string=None):
        if self.val_type is None:
            # no dedicated value converter: fall back to the per-token argparse type
            self.val_type = self.type
        if len(values) % 2:
            parser.error("{} args must come in pairs".format(self.dest))
        keys, raw_values = values[0::2], values[1::2]
        converted = {key: self.val_type(raw) for key, raw in zip(keys, raw_values)}
        setattr(namespace, self.dest, converted)
class TritonChecksumStatus(Enum):
    """Result reported by cfg_checksum when it is asked to return rather than raise."""
    CORRECT = 0    # nothing left to fix (possibly after updating checksums)
    MISSING = 1    # some model files have no stored checksum
    INCORRECT = 2  # some stored checksums do not match the files
# Message classes visible at this point (direct subclasses of protobuf's Message),
# keyed by bare class name; if two message types share a name, the later one wins.
message_classes = dict((msg_cls.__name__, msg_cls) for msg_cls in message.Message.__subclasses__())

# protobuf C++ field type -> Python type of the field's values
_FieldDescriptor = descriptor.FieldDescriptor
cpp_to_python = dict.fromkeys(
    (
        _FieldDescriptor.CPPTYPE_INT32,
        _FieldDescriptor.CPPTYPE_INT64,
        _FieldDescriptor.CPPTYPE_UINT32,
        _FieldDescriptor.CPPTYPE_UINT64,
    ),
    int,
)
cpp_to_python.update({
    _FieldDescriptor.CPPTYPE_DOUBLE: float,
    _FieldDescriptor.CPPTYPE_FLOAT: float,
    _FieldDescriptor.CPPTYPE_BOOL: bool,
    _FieldDescriptor.CPPTYPE_STRING: str,
})

# class of protobuf's scalar value checker -> Python value type
# (built from protobuf's private type_checkers._VALUE_CHECKERS table)
checker_to_type = {
    checker.__class__: cpp_to_python[cpp_type]
    for cpp_type, checker in type_checkers._VALUE_CHECKERS.items()
}
# for some reason, this one is not in the map
checker_to_type[type_checkers.UnicodeValueChecker] = str

# instance-group kind name (e.g. "KIND_CPU") -> enum number
kind_to_int = {
    kind.name: kind.number
    for kind in grpc.model_config_pb2._MODELINSTANCEGROUP_KIND.values
}

# platform substring -> model parameters that cfg_threadcontrol sets to "1"
thread_control_parameters = {
    "onnx": ["intra_op_thread_count", "inter_op_thread_count"],
}
def get_type(obj):
    """Describe the type of a protobuf field value ``obj``.

    Returns a dict with:
      class       -- ``obj.__class__``
      type        -- the class name; for protobuf containers, followed by
                     ``<entry type name>``
      entry_class -- class of the container's entries/values (None otherwise)
      entry_type  -- name of that entry class (None otherwise)

    Message containers resolve their entry class through ``message_classes`` by
    descriptor name. Repeated scalars use ``checker_to_type``. Scalar maps
    instantiate an entry and inspect the type of its ``value``.
    """
    kind = obj.__class__.__name__
    elem_cls = None
    elem_name = None
    if kind in ("RepeatedCompositeFieldContainer", "MessageMap"):
        elem_name = obj._message_descriptor.name
        elem_cls = message_classes[elem_name]
    elif kind == "RepeatedScalarFieldContainer":
        elem_cls = checker_to_type[obj._type_checker.__class__]
        elem_name = elem_cls.__name__
    elif kind == "ScalarMap":
        elem_cls = obj.GetEntryClass()().value.__class__
        elem_name = elem_cls.__name__

    if elem_name is None:
        label = kind
    else:
        label = "{}<{}>".format(kind, elem_name)
    return {
        "class": obj.__class__,
        "type": label,
        "entry_class": elem_cls,
        "entry_type": elem_name,
    }
def get_fields(obj, name, level=0, verbose=False):
    """Recursively build a schema tree for the protobuf value ``obj``.

    Returns the dict from ``get_type(obj)`` extended with ``name`` and ``fields``.
    ``fields`` is a list of child trees, one per field of the message. For message
    containers, the children describe a freshly constructed entry message.
    Non-message values get an empty list. With ``verbose``, each node is printed
    as "<type> <name>", indented by ``level`` spaces.
    """
    info = dict(name=name, fields=[])
    info.update(get_type(obj))
    if verbose:
        print(' ' * level + info["type"], name)

    entry_cls = info["entry_class"]
    if hasattr(obj, "DESCRIPTOR"):
        template = obj
    elif entry_cls is not None and hasattr(entry_cls, "DESCRIPTOR"):
        template = entry_cls()
    else:
        return info

    for fd in template.DESCRIPTOR.fields:
        info["fields"].append(get_fields(getattr(template, fd.name), fd.name, level + 1, verbose))
    return info
def get_model_info():
    """Return the schema tree (see get_fields) of an empty ModelConfig message."""
    blank_config = grpc.model_config_pb2.ModelConfig()
    return get_fields(blank_config, "ModelConfig")
def msg_json(val, defaults=False):
    """Serialize the protobuf message ``val`` to JSON on a single line.

    Field names are kept as in the .proto file. When ``defaults`` is true, fields
    that are unset or at their default value are printed too.

    Newer protobuf releases (5.26+) removed MessageToJson's
    ``including_default_value_fields`` keyword in favor of
    ``always_print_fields_with_no_presence``. Passing the old keyword there raises
    TypeError, so the keyword this protobuf version supports is chosen at call time.
    """
    import inspect
    params = inspect.signature(json_format.MessageToJson).parameters
    if "always_print_fields_with_no_presence" in params:
        default_option = {"always_print_fields_with_no_presence": defaults}
    else:
        default_option = {"including_default_value_fields": defaults}
    text = json_format.MessageToJson(val, preserving_proto_field_name=True, indent=0, **default_option)
    # indent=0 puts one item per line; fold that back into a single line
    return text.replace(",\n", ", ").replace("\n", "")
def print_fields(obj, info, level=0, json=False, defaults=False):
    """Pretty-print ``obj`` guided by its schema node ``info`` (built by get_fields).

    Each node is printed as "<type> <name>" indented by ``level`` spaces, and its
    contents are printed one level deeper:
      * messages recurse into their fields (all fields if ``defaults``, otherwise
        those reported by ListFields());
      * message lists/maps recurse into each entry, labelled "<index>: "/"<key>: ";
      * anything else is printed with str().
    With ``json``, nested messages (level > 0) and message lists/maps are printed
    as single-line JSON via msg_json instead.

    Note: the ``json`` parameter shadows the json module inside this function.
    """
    pad = ' '
    kind = info["type"]
    inner = pad * (level + 1)

    def show_children(msg, depth):
        if defaults:
            descriptors = msg.DESCRIPTOR.fields
        else:
            descriptors = [pair[0] for pair in msg.ListFields()]
        for fd in descriptors:
            child = getattr(msg, fd.name)
            child_info = next(node for node in info["fields"] if node["name"] == fd.name)
            print_fields(child, child_info, level=depth, json=json, defaults=defaults)

    print(pad * level + kind, info["name"])

    if hasattr(obj, "DESCRIPTOR"):
        if json and level > 0:
            print(inner + msg_json(obj, defaults))
        else:
            show_children(obj, level + 1)
        return

    if kind.startswith("RepeatedCompositeFieldContainer"):
        if json:
            print(inner + str([msg_json(item, defaults) for item in obj]))
            return
        for position, item in enumerate(obj):
            print(inner + "{}: ".format(position))
            show_children(item, level + 2)
        return

    if kind.startswith("MessageMap"):
        if json:
            print(inner + str({key: msg_json(item, defaults) for key, item in obj.items()}))
            return
        for key, item in obj.items():
            print(inner + "{}: ".format(key))
            show_children(item, level + 2)
        return

    print(inner + str(obj))
def edit_builtin(model, dest, val):
    """Assign ``val`` directly to the scalar field named ``dest`` on ``model``."""
    setattr(model, dest, val)
def edit_scalar_list(model, dest, val):
    """Replace, in place, the contents of the repeated scalar field ``dest`` with ``val``."""
    container = getattr(model, dest)
    container.clear()
    container.extend(val)
def edit_scalar_map(model, dest, val):
    """Replace, in place, the contents of the scalar map field ``dest`` with the dict ``val``."""
    mapping = getattr(model, dest)
    mapping.clear()
    mapping.update(val)
def edit_msg(model, dest, val):
    """Parse the dict ``val`` (decoded JSON) into the existing message field ``dest``.

    Unlike the other edit_* helpers, the field is not cleared first; the parsing
    is delegated to json_format.ParseDict on the existing sub-message.
    """
    target = getattr(model, dest)
    json_format.ParseDict(val, target)
def edit_msg_list(model, dest, val):
    """Replace the repeated message field ``dest`` with messages parsed from ``val``.

    Args:
        model: message owning the field.
        dest: field name.
        val: iterable of dicts (decoded JSON), one per new entry.
    """
    container = getattr(model, dest)
    container.clear()
    # fixed: the loop used to iterate an undefined name ("vals"), raising NameError
    for entry in val:
        json_format.ParseDict(entry, container.add())
def edit_msg_map(model, dest, val):
    """Replace the message map field ``dest`` with entries parsed from ``val``.

    Args:
        model: message owning the field.
        dest: field name.
        val: dict mapping each key to a dict (decoded JSON) for its message value.
    """
    container = getattr(model, dest)
    container.clear()
    # fixed: the loop used to read an undefined name ("vals"), raising NameError
    for key, entry in val.items():
        json_format.ParseDict(entry, container.get_or_create(key))
def _parse_bool(text):
    """Convert a command-line token to bool; plain ``bool`` would turn "false" into True."""
    lowered = text.strip().lower()
    if lowered in ("1", "true", "yes", "on"):
        return True
    if lowered in ("0", "false", "no", "off"):
        return False
    raise ValueError("invalid boolean value: {!r}".format(text))


def _cli_type(cls):
    """Return the argparse converter to use for values of Python type ``cls``."""
    return _parse_bool if cls is bool else cls


def add_edit_args(parser, model_info):
    """Add one option per top-level ModelConfig field to ``parser``.

    Option names are the field names with "_" replaced by "-". Builtin scalars
    and repeated scalars are parsed with the field's Python type. Scalar maps and
    message maps take "key value" pairs (via DictAction). Message lists and single
    messages take JSON.

    Args:
        parser: ArgumentParser (the edit subcommand) to extend.
        model_info: schema tree from get_model_info().

    Returns:
        (parser, dests): the same parser, and a dict mapping each option's argparse
        ``dest`` to the edit_* function that applies its value to a ModelConfig.
    """
    group = parser.add_argument_group("fields", description="ModelConfig fields to edit")
    dests = {}
    for field in model_info["fields"]:
        argname = "--{}".format(field["name"].replace("_", "-"))
        field_type = field["type"]
        val_type = None
        if field["class"].__module__ == "builtins":
            kwargs = dict(type=_cli_type(field["class"]))
            editor = edit_builtin
        elif field_type.startswith("RepeatedScalarFieldContainer"):
            kwargs = dict(type=_cli_type(field["entry_class"]), nargs='*')
            editor = edit_scalar_list
        elif field_type.startswith("ScalarMap"):
            kwargs = dict(type=str, nargs='*', metavar="key value", action=DictAction)
            val_type = _cli_type(field["entry_class"])
            editor = edit_scalar_map
        elif field_type.startswith("RepeatedCompositeFieldContainer"):
            kwargs = dict(type=json.loads, nargs='*',
                help="provide {} values in json format".format(field["entry_type"])
            )
            editor = edit_msg_list
        elif field_type.startswith("MessageMap"):
            kwargs = dict(type=str, nargs='*', metavar="key value", action=DictAction,
                help="provide {} values in json format".format(field["entry_type"])
            )
            editor = edit_msg_map
            val_type = json.loads
        else:
            kwargs = dict(type=json.loads,
                help="provide {} values in json format".format(field_type)
            )
            # fixed: this was assigned to a misspelled local ("edit"), leaving editor None
            editor = edit_msg
        action = group.add_argument(argname, **kwargs)
        if val_type is not None:
            # DictAction converts map values with val_type (keys stay str)
            action.val_type = val_type
        dests[action.dest] = editor
    return parser, dests
def get_checksum(filename, chunksize=4096):
    """Return the hex MD5 digest of the file ``filename``, read ``chunksize`` bytes at a time."""
    import hashlib
    with open(filename, 'rb') as handle:
        digest = hashlib.md5()
        for piece in iter(lambda: handle.read(chunksize), b''):
            digest.update(piece)
    return digest.hexdigest()
def get_checksum_update_cmd(force=False):
    """Return the shell command a user should run to update the checksums.

    This is the current command line (sys.argv) plus "--update" (and "--force" if
    ``force``), skipping flags already present. Each token is shell-quoted so that
    arguments containing spaces survive copy-paste, and no trailing space is left
    when nothing needs to be appended.
    """
    import shlex
    extra_args = ["--update"]
    if force:
        extra_args.append("--force")
    command = list(sys.argv) + [arg for arg in extra_args if arg not in sys.argv]
    return shlex.join(command)
def update_config(args):
    """Write ``args.model`` as text-format protobuf.

    If ``args.copy`` is truthy, the output is ``config.pbtxt`` inside the directory
    ``args.copy`` when it is a string, or in the current directory otherwise (e.g.
    a bare ``--copy``). ``args.config`` is updated to that path so that a later
    view shows the written file. Without ``args.copy``, ``args.config`` is
    overwritten in place.
    """
    if args.copy:
        out_dir = args.copy if isinstance(args.copy, str) else ""
        args.config = os.path.join(out_dir, "config.pbtxt")
    with open(args.config, 'w') as outfile:
        text_format.PrintMessage(args.model, outfile, use_short_repeated_primitives=True)
def cfg_common(args):
    """Set-up shared by every subcommand.

    Ensures ``args.model_info`` holds the ModelConfig schema, computing it only
    when the attribute is absent. Then sets ``args.model`` to a fresh ModelConfig,
    filled from the text-format file ``args.config`` when that attribute exists.
    """
    if not hasattr(args, 'model_info'):
        args.model_info = get_model_info()
    args.model = grpc.model_config_pb2.ModelConfig()
    if not hasattr(args, 'config'):
        return
    with open(args.config, 'r') as infile:
        contents = infile.read()
    text_format.Parse(contents, args.model)
def cfg_schema(args):
    """Print the ModelConfig schema: one "<type> <name>" line per field, indented by depth."""
    get_fields(obj=args.model, name="ModelConfig", level=0, verbose=True)
def cfg_view(args):
    """Print the path of the current config followed by its contents."""
    header = "Contents of {}".format(args.config)
    print(header)
    print_fields(args.model, args.model_info, json=args.json, defaults=args.defaults)
def cfg_edit(args):
    """Apply every field option given on the command line, write the config, optionally view it.

    ``args.edit_dests`` maps argparse dests to editor functions (see add_edit_args).
    Options whose value is None (not given) are skipped. All values are collected
    before any editor runs.
    """
    pending = []
    for dest, editor in args.edit_dests.items():
        value = getattr(args, dest)
        if value is not None:
            pending.append((editor, dest, value))
    for editor, dest, value in pending:
        editor(args.model, dest, value)

    update_config(args)

    if args.view:
        cfg_view(args)
def _list_model_files(config_dir):
    """Return sorted, de-duplicated real paths of the regular files in the subdirectories of ``config_dir``.

    All depths are scanned. Files directly in ``config_dir`` (e.g. config.pbtxt
    itself) are excluded. As with glob's default, names starting with a dot are
    not matched.
    """
    from glob import glob
    found = set()
    # "*/**" with recursive=True: each subdirectory, then everything beneath it
    for path in glob(os.path.join(config_dir, "*", "**"), recursive=True):
        # evaluate symbolic links, then keep only regular files
        # (directories used to be passed to get_checksum and crash it)
        path = os.path.realpath(path)
        if os.path.isfile(path):
            found.add(path)
    return sorted(found)


def cfg_checksum(args):
    """Verify, and optionally update, MD5 checksums of the model files next to ``args.config``.

    Checksums are stored as parameters of the "checksum" repository agent, which
    is created if absent. Each key is "MD5:<file path relative to the config dir>".

    - ``args.update``: store checksums for files that have none.
    - ``args.update`` and ``args.force``: also overwrite mismatching checksums.
    - The config is rewritten (see update_config) only when something was updated.
    - ``args.quiet``: suppress the lists of missing/incorrect files.
    - ``args.view``: print the config afterwards.

    Returns:
        A TritonChecksumStatus if ``args.should_return`` is true (used internally
        by versioncheck); otherwise None.

    Raises:
        RuntimeError: if checksums are incorrect or missing and were not updated,
            unless ``args.should_return`` is true.
    """
    # internal parameter
    if not hasattr(args, "should_return"):
        args.should_return = False

    agents = args.model.model_repository_agents.agents
    checksum_agent = next((agent for agent in agents if agent.name == "checksum"), None)
    if checksum_agent is None:
        checksum_agent = agents.add(name="checksum")
    stored = checksum_agent.parameters

    incorrect = []
    missing = []

    # evaluate symbolic links
    config_dir = os.path.realpath(os.path.dirname(args.config))
    for filepath in _list_model_files(config_dir):
        checksum = get_checksum(filepath)
        # key = algorithm:[filename relative to config.pbtxt dir]
        filename = os.path.relpath(filepath, config_dir)
        filekey = "MD5:{}".format(filename)
        if filekey in stored:
            if checksum != stored[filekey]:
                incorrect.append(filename)
                if args.update and args.force:
                    stored[filekey] = checksum
        else:
            missing.append(filename)
            if args.update:
                stored[filekey] = checksum

    needs_update = len(missing) > 0
    needs_force_update = len(incorrect) > 0

    if not args.quiet:
        if needs_update:
            print("\n".join(["Missing checksums:"] + missing))
        if needs_force_update:
            print("\n".join(["Incorrect checksums:"] + incorrect))

    if needs_force_update:
        if not (args.update and args.force):
            if args.should_return:
                return TritonChecksumStatus.INCORRECT
            raise RuntimeError("\n".join([
                "Incorrect checksum(s) found, indicating existing model file(s) has been changed, which violates policy.",
                "To override, run the following command (and provide a justification in your PR):",
                get_checksum_update_cmd(force=True)
            ]))
        update_config(args)
    elif needs_update:
        if not args.update:
            if args.should_return:
                return TritonChecksumStatus.MISSING
            raise RuntimeError("\n".join([
                "Missing checksum(s) found, indicating new model file(s).",
                "To update, run the following command:",
                get_checksum_update_cmd(force=False)
            ]))
        update_config(args)

    if args.view:
        cfg_view(args)

    if args.should_return:
        return TritonChecksumStatus.CORRECT
def cfg_versioncheck(args):
    """Check model-file checksums of every config.pbtxt under CMSSW_SEARCH_PATH.

    Each ':'-separated entry of the CMSSW_SEARCH_PATH environment variable is
    walked recursively, following symlinks. Every config.pbtxt found is checked
    with cfg_checksum, quietly and returning a status instead of raising.
    ``args.update`` and ``args.force`` are passed through, so checksums can be
    updated in bulk.

    Raises:
        RuntimeError: if CMSSW_SEARCH_PATH is not set, or if any config still has
            missing or incorrect checksums. The message lists the affected files
            and the command to fix them.
    """
    search_path = os.environ.get('CMSSW_SEARCH_PATH')
    if search_path is None:
        raise RuntimeError("CMSSW_SEARCH_PATH is not set; it is needed to locate the config.pbtxt files to check")

    # reuse the schema built once for the command line instead of rebuilding it per config
    # (cfg_common only computes model_info when the attribute is absent)
    shared = {"model_info": args.model_info} if hasattr(args, "model_info") else {}

    incorrect = []
    missing = []

    for path in search_path.split(':'):
        if args.verbose:
            print("Checking: " + path)
        for dirpath, _dirnames, filenames in os.walk(path, followlinks=True):
            if "config.pbtxt" not in filenames:
                continue
            filepath = os.path.join(dirpath, "config.pbtxt")
            if args.verbose:
                print(filepath)
            checksum_args = Namespace(
                config=filepath, should_return=True,
                copy=False, json=False, defaults=False, view=False,
                update=args.update, force=args.force, quiet=True,
                **shared
            )
            cfg_common(checksum_args)
            status = cfg_checksum(checksum_args)
            if status == TritonChecksumStatus.INCORRECT:
                incorrect.append(filepath)
            elif status == TritonChecksumStatus.MISSING:
                missing.append(filepath)

    msg = []
    instr = []
    if len(missing) > 0:
        msg.extend(["","The following files have missing checksum(s), indicating new model file(s):"]+missing)
        instr.extend(["","To update missing checksums, run the following command:",get_checksum_update_cmd(force=False)])
    if len(incorrect) > 0:
        msg.extend(["","The following files have incorrect checksum(s), indicating existing model file(s) have been changed, which violates policy:"]+incorrect)
        instr.extend(["","To override incorrect checksums, run the following command (and provide a justification in your PR):",get_checksum_update_cmd(force=True)])

    if len(msg) > 0:
        raise RuntimeError("\n".join(msg + instr))
def cfg_threadcontrol(args):
    """Copy a model into the local repository ``args.copy`` and limit its threading.

    The whole model directory containing ``args.config`` is copied to
    ``args.copy/<model directory name>``.

    For an ensemble, each distinct model named in its scheduling steps (looked up
    as a sibling of the ensemble directory) is processed recursively, and the
    ensemble config is copied unmodified.

    Otherwise, a CPU instance group with ``count=args.nThreads`` is appended. For
    platforms/backends matching a key of ``thread_control_parameters``, those model
    parameters are set to "1". The modified config is written into the copy and
    optionally viewed.

    Side effects: mutates ``args.config``, ``args.copy`` and ``args.model``.
    Raises: FileExistsError (from shutil.copytree) if a destination directory
    already exists.
    """
    # copy the entire model, not just config.pbtxt
    # (abspath so that a bare "config.pbtxt" still yields a real directory and name)
    config_dir = os.path.dirname(os.path.abspath(args.config))
    copy_dir = args.copy
    new_config_dir = os.path.join(copy_dir, pathlib.Path(config_dir).name)
    shutil.copytree(config_dir, new_config_dir)

    # models may declare a backend instead of a platform
    platform = args.model.platform or args.model.backend
    if platform == "ensemble":
        repo_dir = pathlib.Path(config_dir).parent
        # collect names before recursing (the recursive calls replace args.model);
        # dict.fromkeys drops repeats, so a model used by several steps is copied
        # only once instead of failing in copytree on the second copy
        step_models = list(dict.fromkeys(step.model_name for step in args.model.ensemble_scheduling.step))
        for model_name in step_models:
            # update args and run recursively
            args.config = os.path.join(repo_dir, model_name, "config.pbtxt")
            args.copy = copy_dir
            cfg_common(args)
            cfg_threadcontrol(args)
        return

    # NOTE(review): the CPU instance group is appended (existing groups are kept) even
    # when no thread parameters are known for this platform -- confirm this is intended.
    args.model.instance_group.add(count=args.nThreads, kind=kind_to_int['KIND_CPU'])

    found_params = False
    for key, params in thread_control_parameters.items():
        if key in platform:  # partial matching: the key only has to occur in the name
            for param in params:
                args.model.parameters.get_or_create(param).string_value = "1"
            found_params = True
            break
    if not found_params:
        print("Warning: thread (instance) control not implemented for {}".format(platform))

    args.copy = new_config_dir
    update_config(args)

    if args.view:
        cfg_view(args)
if __name__=="__main__":
    # initial common operations
    model_info = get_model_info()
    edit_dests = None

    # arguments shared by every subcommand that reads a single config.pbtxt
    _parser_common = ArgumentParser(add_help=False)
    _parser_common.add_argument("-c", "--config", type=str, default="", required=True, help="path to input config.pbtxt file")

    parser = ArgumentParser(formatter_class=ArgumentDefaultsRawHelpFormatter)
    # required=True: with no subcommand there is no args.func, so the dispatch at the end
    # used to fail with AttributeError; argparse now reports a usage error instead
    subparsers = parser.add_subparsers(dest="command", required=True)
    # add_parser() does not inherit formatter_class from the top-level parser, so pass it
    # explicitly: subcommand help then keeps explicit line breaks and shows defaults
    _subparser_kwargs = dict(formatter_class=ArgumentDefaultsRawHelpFormatter)

    parser_schema = subparsers.add_parser("schema", help="view ModelConfig schema",
        description="""Display all fields in the ModelConfig object, with type information.
(For collection types, the subfields of the entry type are shown.)""",
        **_subparser_kwargs
    )
    parser_schema.set_defaults(func=cfg_schema)

    _parser_view_args = ArgumentParser(add_help=False)
    _parser_view_args.add_argument("--json", default=False, action="store_true", help="display in json format")
    _parser_view_args.add_argument("--defaults", default=False, action="store_true", help="show fields with default values")

    parser_view = subparsers.add_parser("view", parents=[_parser_common, _parser_view_args], help="view config.pbtxt contents", **_subparser_kwargs)
    parser_view.set_defaults(func=cfg_view)

    _parser_copy_view = ArgumentParser(add_help=False)
    _parser_copy_view.add_argument("--view", default=False, action="store_true", help="view file after editing")

    _parser_copy = ArgumentParser(add_help=False, parents=[_parser_copy_view])
    _parser_copy.add_argument("--copy", metavar="dir", default=False, const=True, nargs='?', type=str,
        help="make a copy of config.pbtxt instead of editing in place ([dir] = output path for copy; if omitted, current directory is used)"
    )

    parser_edit = subparsers.add_parser("edit", parents=[_parser_common, _parser_copy, _parser_view_args], help="edit config.pbtxt contents", **_subparser_kwargs)
    parser_edit, edit_dests = add_edit_args(parser_edit, model_info)
    parser_edit.set_defaults(func=cfg_edit)

    _parser_checksum_update = ArgumentParser(add_help=False)
    _parser_checksum_update.add_argument("--update", default=False, action="store_true", help="update missing checksums")
    _parser_checksum_update.add_argument("--force", default=False, action="store_true", help="force update all checksums")

    parser_checksum = subparsers.add_parser("checksum", parents=[_parser_common, _parser_copy, _parser_view_args, _parser_checksum_update], help="handle model file checksums", **_subparser_kwargs)
    parser_checksum.add_argument("--quiet", default=False, action="store_true", help="suppress printouts")
    parser_checksum.set_defaults(func=cfg_checksum)

    parser_versioncheck = subparsers.add_parser("versioncheck", parents=[_parser_checksum_update], help="check all model checksums", **_subparser_kwargs)
    parser_versioncheck.add_argument("--verbose", default=False, action="store_true", help="verbose output (show all files checked)")
    parser_versioncheck.set_defaults(func=cfg_versioncheck)

    _parser_copy_req = ArgumentParser(add_help=False, parents=[_parser_copy_view])
    _parser_copy_req.add_argument("--copy", metavar="dir", type=str, required=True,
        help="local model repository directory to copy model(s)"
    )

    parser_threadcontrol = subparsers.add_parser("threadcontrol", parents=[_parser_common, _parser_copy_req, _parser_view_args], help="enable thread controls", **_subparser_kwargs)
    parser_threadcontrol.add_argument("--nThreads", type=int, required=True, help="number of threads")
    parser_threadcontrol.set_defaults(func=cfg_threadcontrol)

    args = parser.parse_args()
    args.model_info = model_info
    if edit_dests is not None:
        args.edit_dests = edit_dests

    cfg_common(args)

    args.func(args)