File indexing completed on 2021-08-12 23:11:47
0001 from __future__ import print_function
0002 import FWCore.ParameterSet.Config as cms
0003 import sys
0004
0005
0006
0007
0008
0009
0010 from FWCore.ParameterSet.MassReplace import massSearchReplaceAnyInputTag, MassSearchReplaceAnyInputTagVisitor
0011 from FWCore.ParameterSet.MassReplace import massSearchReplaceParam, MassSearchParamVisitor, MassSearchReplaceParamVisitor
0012
def getPatAlgosToolsTask(process):
    """Return the cms.Task attached to process as ``patAlgosToolsTask``.

    If the process has no such attribute yet, an empty cms.Task is attached
    under that name first and then returned.

    Raises Exception if the attribute exists but is not a cms.Task.
    """
    taskName = "patAlgosToolsTask"
    if not hasattr(process, taskName):
        setattr(process, taskName, cms.Task())
        return getattr(process, taskName)
    existing = getattr(process, taskName)
    if not isinstance(existing, cms.Task):
        raise Exception("patAlgosToolsTask does not have type Task")
    return existing
0023
def associatePatAlgosToolsTask(process):
    """Associate the process's patAlgosToolsTask (created if missing) with process.schedule."""
    # Fetch (and possibly create) the task before touching process.schedule,
    # keeping the original order of side effects.
    patTask = getPatAlgosToolsTask(process)
    schedule = process.schedule
    schedule.associate(patTask)
0027
def addToProcessAndTask(label, module, process, task):
    """Attach module to process under label, then add the attached object to task.

    The object added to task is the one read back from the process, so any
    transformation the process applies on attribute assignment is respected.
    """
    setattr(process, label, module)
    attached = getattr(process, label)
    task.add(attached)
0031
def addESProducers(process,config):
    """Attach every ESProducer defined in the configuration module config to process.

    config is a python module name; "/" separators are accepted and turned
    into ".". Private names, Sequences/Paths, PSets and the special names
    "source", "looper" and "subProcess" are ignored. Matching objects are
    attached under their module-level name.
    """
    moduleName = config.replace("/", ".")
    # __import__ returns the top-level package, so the leaf module is
    # fetched from sys.modules instead.
    __import__(moduleName)
    configModule = sys.modules[moduleName]
    skippedNames = ("source", "looper", "subProcess")
    for name in dir(configModule):
        item = getattr(configModule, name)
        if not isinstance(item, cms._Labelable):
            continue
        if isinstance(item, cms._ModuleSequenceType):
            continue
        if name.startswith('_') or name in skippedNames:
            continue
        if isinstance(item, cms.PSet):
            continue
        if 'ESProducer' in item.type_():
            setattr(process, name, item)
0042
def loadWithPrefix(process,moduleName,prefix='',loadedProducersAndFilters=None):
    """Load moduleName into process via loadWithPrePostfix using only a prefix (empty postfix)."""
    loadWithPrePostfix(process, moduleName,
                       prefix=prefix,
                       postfix='',
                       loadedProducersAndFilters=loadedProducersAndFilters)
0045
def loadWithPostfix(process,moduleName,postfix='',loadedProducersAndFilters=None):
    """Load moduleName into process via loadWithPrePostfix using only a postfix (empty prefix)."""
    loadWithPrePostfix(process, moduleName,
                       prefix='',
                       postfix=postfix,
                       loadedProducersAndFilters=loadedProducersAndFilters)
0048
def loadWithPrePostfix(process,moduleName,prefix='',postfix='',loadedProducersAndFilters=None):
    """Import the configuration module moduleName and merge it into process.

    moduleName may use "/" or "." as separator. The imported module is
    handed to extendWithPrePostfix together with prefix, postfix and
    loadedProducersAndFilters.
    """
    dottedName = moduleName.replace("/", ".")
    # __import__ yields the top-level package; the leaf lives in sys.modules.
    __import__(dottedName)
    loaded = sys.modules[dottedName]
    extendWithPrePostfix(process, loaded, prefix, postfix, loadedProducersAndFilters)
0054
def addToTask(loadedProducersAndFilters, module):
    """Add module to the task loadedProducersAndFilters if it is an EDProducer or EDFilter.

    loadedProducersAndFilters : a task-like object with an ``add`` method,
                                or None to disable adding.
    module                    : any object; only cms.EDProducer and
                                cms.EDFilter instances are added.
    """
    # Test for None explicitly rather than truthiness, so a task object that
    # happens to evaluate as false (e.g. while still empty) is not skipped.
    if loadedProducersAndFilters is None:
        return
    if isinstance(module, (cms.EDProducer, cms.EDFilter)):
        loadedProducersAndFilters.add(module)
0059
def extendWithPrePostfix(process,other,prefix,postfix,loadedProducersAndFilters=None):
    """Look in other and find types which we can use.

    Every labelable object found as an attribute of ``other`` (typically an
    imported configuration module) is attached to ``process``. Skipped are:
    names starting with '_', the names "source", "looper" and "subProcess",
    Sequences/Paths (cms._ModuleSequenceType), Tasks, Schedules, PSets and
    VPSets.

    If prefix or postfix is non-empty, each object is cloned:
      * ESProducers are attached under their original name;
      * everything else is attached as prefix+name+postfix, and input tags
        inside the renamed sequenceable clones that point at an original
        label are rewritten to the renamed label.
    Otherwise the objects themselves are attached under their original names.

    process                   : cms.Process to extend
    other                     : object whose attributes are scanned (dir())
    prefix, postfix           : strings placed around each cloned label
    loadedProducersAndFilters : name (string) of a cms.Task attribute of
                                process; attached EDProducers/EDFilters are
                                also added to it. None or '' disables this.

    Raises Exception if loadedProducersAndFilters names an attribute that is
    not a cms.Task.
    """
    # Resolve the optional task by name; None means "don't add to any task".
    if loadedProducersAndFilters:
        task = getattr(process, loadedProducersAndFilters)
        if not isinstance(task, cms.Task):
            raise Exception("extendWithPrePostfix argument must be name of Task type object attached to the process or None")
    else:
        task = None

    # Scratch sequence holding the renamed sequenceable clones; its ad-hoc
    # _moduleLabels attribute records their ORIGINAL labels for the input-tag
    # rewrite after the loop.
    sequence = cms.Sequence()
    sequence._moduleLabels = []
    for name in dir(other):
        # 'from XX import *' ignores private names, and so do we.
        if name.startswith('_'):
            continue
        item = getattr(other,name)
        if name == "source" or name == "looper" or name == "subProcess":
            continue
        elif isinstance(item,cms._ModuleSequenceType):
            continue
        elif isinstance(item,cms.Task):
            continue
        elif isinstance(item,cms.Schedule):
            continue
        elif isinstance(item,cms.VPSet) or isinstance(item,cms.PSet):
            continue
        elif isinstance(item,cms._Labelable):
            # Give unlabeled objects their attribute name so item.label()
            # below is usable.
            if not item.hasLabel_():
                item.setLabel(name)
            if prefix != '' or postfix != '':
                newModule = item.clone()
                if isinstance(item,cms.ESProducer):
                    # ESProducers keep their name even when cloned.
                    newName =name
                else:
                    # Tau discriminators are additionally attached under their
                    # unprefixed name. NOTE(review): presumably other configs
                    # refer to them by that name — confirm.
                    if 'TauDiscrimination' in name:
                        process.__setattr__(name,item)
                        addToTask(task, item)
                    newName = prefix+name+postfix
                process.__setattr__(newName,newModule)
                addToTask(task, newModule)
                # Only renamed sequenceable clones take part in the
                # input-tag rewrite below.
                if isinstance(newModule, cms._Sequenceable) and not newName == name:
                    sequence +=getattr(process,newName)
                    sequence._moduleLabels.append(item.label())
            else:
                process.__setattr__(name,item)
                addToTask(task, item)

    # Point input tags of the renamed clones at the renamed labels
    # (module label only; instance/process parts are untouched).
    if prefix != '' or postfix != '':
        for label in sequence._moduleLabels:
            massSearchReplaceAnyInputTag(sequence, label, prefix+label+postfix,verbose=False,moduleLabelOnly=True)
0113
def applyPostfix(process, label, postfix):
    """Return the object attached to process under the name label+postfix.

    Raises ValueError if process has no attribute with that name.
    """
    fullLabel = label + postfix
    if not hasattr(process, fullLabel):
        raise ValueError("Error in <applyPostfix>: No module of name = %s attached to process !!" % fullLabel)
    return getattr(process, fullLabel)
0121
def removeIfInSequence(process, target, sequenceLabel, postfix=""):
    """Remove process.<target+postfix> from process.<sequenceLabel+postfix> if it is in it.

    Membership is checked against the full (postfixed) labels of the modules
    and sub-sequences of the sequence; nothing happens if it is absent.
    """
    fullTarget = target + postfix
    present = __labelsInSequence(process, sequenceLabel, postfix, True)
    if fullTarget not in present:
        return
    owningSequence = getattr(process, sequenceLabel + postfix)
    owningSequence.remove(getattr(process, fullTarget))
0128
def __labelsInSequence(process, sequenceLabel, postfix="", keepPostFix=False):
    """Return the labels of the modules and sub-sequences of process.<sequenceLabel+postfix>.

    Module labels come first, followed by sequence labels, in visiting order.
    Unless keepPostFix is True, the last len(postfix) characters are cut from
    each label; no check is made that the label actually ends with postfix.
    With an empty postfix, labels are returned unchanged.
    """
    sequence = getattr(process, sequenceLabel + postfix)
    # -len("") would be 0 and slice every label down to "", so an empty
    # postfix is treated like keepPostFix: slice nothing (end=None).
    end = None if (keepPostFix or postfix == "") else -len(postfix)
    labels = [m.label()[:end] for m in listModules(sequence)]
    labels.extend(s.label()[:end] for s in listSequences(sequence))
    return labels
0140
0141
class GatherAllModulesVisitor(object):
    """Visitor that walks a cms.Sequence and collects, in visiting order, every
    visited object that is an instance of gatheredInstance (default: cms._Module)."""
    def __init__(self, gatheredInstance=cms._Module):
        self._gatheredInstance = gatheredInstance
        self._modules = []

    def enter(self, visitee):
        if not isinstance(visitee, self._gatheredInstance):
            return
        self._modules.append(visitee)

    def leave(self, visitee):
        pass

    def modules(self):
        """Return the list of collected objects (the internal list, not a copy)."""
        return self._modules
0154
class CloneSequenceVisitor(object):
    """Visitor that travels within a cms.Sequence, and returns a cloned version of the Sequence.

    Every module reached is cloned under a postfixed label (unless listed in
    noClones, in which case the original module is reused) and appended to a
    single new, flat cms.Sequence. Nested sequences are not cloned themselves;
    their modules end up directly in the new sequence. A module visited more
    than once is cloned once and appended each time.

    process       : cms.Process the clones and the new sequence are attached to
    label         : label of the sequence being cloned; the new sequence is
                    attached under _newLabel(label)
    postfix       : string appended to every new label
    removePostfix : optional suffix stripped from each label before postfix is
                    appended; an Exception is raised if a label lacks it
    noClones      : labels of modules to reuse instead of clone (default: none)
    addToTask     : if True, each new clone is also added to patAlgosToolsTask
    verbose       : forwarded to massSearchReplaceAnyInputTag
    """
    def __init__(self, process, label, postfix, removePostfix="", noClones = None, addToTask = False, verbose = False):
        self._process = process
        self._postfix = postfix
        self._removePostfix = removePostfix
        # None replaces the former shared mutable [] default.
        self._noClones = [] if noClones is None else noClones
        self._addToTask = addToTask
        self._verbose = verbose
        # Original labels of the modules cloned so far.
        self._moduleLabels = []
        self._clonedSequence = cms.Sequence()
        setattr(process, self._newLabel(label), self._clonedSequence)
        if addToTask:
            self._patAlgosToolsTask = getPatAlgosToolsTask(process)

    def enter(self, visitee):
        if isinstance(visitee, cms._Module):
            label = visitee.label()
            newModule = None
            if label in self._noClones:
                newModule = getattr(self._process, label)
            elif label in self._moduleLabels:
                # Already cloned earlier in this visit: reuse that clone.
                newModule = getattr(self._process, self._newLabel(label))
            else:
                self._moduleLabels.append(label)
                newModule = visitee.clone()
                setattr(self._process, self._newLabel(label), newModule)
                if self._addToTask:
                    self._patAlgosToolsTask.add(getattr(self._process, self._newLabel(label)))
            self.__appendToTopSequence(newModule)

    def leave(self, visitee):
        pass

    def clonedSequence(self):
        """Rewrite input tags in the clone to point at the cloned labels, then return it.

        The record of cloned labels is cleared afterwards.
        """
        for label in self._moduleLabels:
            massSearchReplaceAnyInputTag(self._clonedSequence, label, self._newLabel(label), moduleLabelOnly=True, verbose=self._verbose)
        self._moduleLabels = []
        return self._clonedSequence

    def _newLabel(self, label):
        """Return label with removePostfix stripped (if configured) and postfix appended."""
        if self._removePostfix != "":
            if label.endswith(self._removePostfix):
                label = label[0:-len(self._removePostfix)]
            else:
                raise Exception("Tried to remove postfix %s from label %s, but it wasn't there" % (self._removePostfix, label))
        return label + self._postfix

    def __appendToTopSequence(self, visitee):
        self._clonedSequence += visitee
0206
def listModules(sequence):
    """Return every cms._Module met while visiting sequence, in visiting order."""
    collector = GatherAllModulesVisitor(cms._Module)
    sequence.visit(collector)
    found = collector.modules()
    return found
0211
def listSequences(sequence):
    """Return every cms.Sequence met while visiting sequence, in visiting order."""
    collector = GatherAllModulesVisitor(cms.Sequence)
    sequence.visit(collector)
    found = collector.modules()
    return found
0216
def jetCollectionString(prefix='', algo='', type=''):
    """
    ------------------------------------------------------------------
    return the string of the jet collection module depending on the
    input values. With all arguments at their defaults the result is
    'patJets'; e.g. algo='AK5', type='Calo' gives 'patJetsAK5Calo'.

    algo   : indicating the algorithm type of the jet [expected are
             'AK5', 'IC5', 'SC7', ...]
    type   : indicating the type of constituents of the jet [expec-
             ted are 'Calo', 'PFlow', 'JPT', ...]
    prefix : prefix indicating the type of pat collection module (ex-
             pected are '', 'selected', 'clean'). A non-empty prefix is
             followed by 'Pat', e.g. 'selectedPatJetsAK5Calo'.
    ------------------------------------------------------------------
    """
    # 'type' shadows the builtin but stays: callers pass it by keyword.
    head = 'pat' if prefix == '' else prefix + 'Pat'
    return head + 'Jets' + algo + type
0240
def contains(sequence, moduleName):
    """
    ------------------------------------------------------------------
    return True if moduleName occurs in the string form of sequence
    and False otherwise. This is a plain substring test, so it also
    returns True when moduleName is only part of a longer label.

    sequence   : sequence [e.g. process.patDefaultSequence]
    moduleName : module name as a string
    ------------------------------------------------------------------
    """
    rendered = sequence.__str__()
    return moduleName in rendered
0254
0255
0256
def cloneProcessingSnippet(process, sequence, postfix, removePostfix="", noClones = None, addToTask = False, verbose = False):
    """
    ------------------------------------------------------------------
    copy a sequence plus the modules and sequences therein
    both are renamed by getting a postfix
    input tags are automatically adjusted
    ------------------------------------------------------------------
    process       : cms.Process the copies are attached to
    sequence      : the sequence to copy
    postfix       : string appended to every new label; if empty, sequence
                    itself is returned and nothing is cloned
    removePostfix : optional suffix stripped from labels before postfix is added
    noClones      : labels of modules reused instead of cloned (default: none)
    addToTask     : if True, new clones are also added to patAlgosToolsTask
    verbose       : forwarded to the input-tag replacement
    returns the new sequence (or sequence itself for an empty postfix)
    """
    if postfix == "":
        return sequence
    # None replaces the former shared mutable [] default.
    if noClones is None:
        noClones = []
    visitor = CloneSequenceVisitor(process, sequence.label(), postfix, removePostfix, noClones, addToTask, verbose)
    sequence.visit(visitor)
    return visitor.clonedSequence()
0271
def listDependencyChain(process, module, sources, verbose=False):
    """
    Walk up the dependencies of a module to find any that depend on any of the listed sources

    process : cms.Process holding the modules
    module  : module whose input-tag dependency graph is explored
    sources : labels of the source modules of interest
    verbose : print every discovered input-tag dependency

    Returns a tuple (task, cms.Sequence(task)) holding, in dependency order,
    every module attached to process that (directly or indirectly) consumes
    one of the sources, as found in the graph reachable from module.
    Returns None if there is no such module.
    Raises RuntimeError if the computed ordering fails its final check.
    """
    def allDirectInputModules(moduleOrPSet,moduleName,attrName):
        # Collect module labels referenced by InputTag/VInputTag parameters of
        # moduleOrPSet, recursing into PSets/VPSets. Only tags with an empty
        # process name or this process's name count. attrName is the
        # parameter path used in the verbose printout.
        ret = set()
        for name,value in moduleOrPSet.parameters_().items():
            # NOTE: 'type' shadows the builtin inside this helper.
            type = value.pythonTypeName()
            if type == 'cms.PSet':
                ret.update(allDirectInputModules(value,moduleName,moduleName+"."+name))
            elif type == 'cms.VPSet':
                for (i,ps) in enumerate(value):
                    ret.update(allDirectInputModules(ps,moduleName,"%s.%s[%d]"%(moduleName,name,i)))
            elif type == 'cms.VInputTag':
                inputs = [ MassSearchReplaceAnyInputTagVisitor.standardizeInputTagFmt(it) for it in value ]
                inputLabels = [ tag.moduleLabel for tag in inputs if tag.processName == '' or tag.processName == process.name_() ]
                ret.update(inputLabels)
                if verbose and inputLabels: print("%s depends on %s via %s" % (moduleName, inputLabels, attrName+"."+name))
            elif type.endswith('.InputTag'):
                if value.processName == '' or value.processName == process.name_():
                    ret.add(value.moduleLabel)
                    if verbose: print("%s depends on %s via %s" % (moduleName, value.moduleLabel, attrName+"."+name))
        # Empty input tags carry an empty module label; drop it.
        ret.discard("")
        return ret
    def fillDirectDepGraphs(root,fwdepgraph,revdepgraph):
        # Depth-first fill of:
        #   fwdepgraph[label]  -> labels that 'label' reads from
        #   revdepgraph[label] -> labels that read from 'label'
        # Recursion only follows labels that are attributes of process.
        # Already-visited roots return early (None); only the outermost call
        # returns the pair of graphs.
        if root.label_() in fwdepgraph: return
        deps = allDirectInputModules(root,root.label_(),root.label_())
        fwdepgraph[root.label_()] = []
        for d in deps:
            fwdepgraph[root.label_()].append(d)
            if d not in revdepgraph: revdepgraph[d] = []
            revdepgraph[d].append(root.label_())
            depmodule = getattr(process,d,None)
            if depmodule:
                fillDirectDepGraphs(depmodule,fwdepgraph,revdepgraph)
        return (fwdepgraph,revdepgraph)
    fwdepgraph, revdepgraph = fillDirectDepGraphs(module, {}, {})
    def flattenRevDeps(flatgraph, revdepgraph, tip):
        """Make a graph that for each module lists all the ones that depend on it, directly or indirectly"""
        # Each tip is processed only once.
        if tip in flatgraph: return
        # Nothing depends on this module: nothing to record.
        if tip not in revdepgraph: return
        # NOTE(review): flatgraph[tip] is only set after its dependents are
        # processed, so a dependency cycle would recurse without bound.
        mydeps = set()
        for d in revdepgraph[tip]:
            # Depth-first: flatten the dependent first...
            flattenRevDeps(flatgraph, revdepgraph, d)
            # ...then take it plus everything that depends on it.
            mydeps.add(d)
            if d in flatgraph:
                mydeps.update(flatgraph[d])
        flatgraph[tip] = mydeps
    flatdeps = {}
    allmodules = set()
    for s in sources:
        flattenRevDeps(flatdeps, revdepgraph, s)
        if s in flatdeps: allmodules.update(f for f in flatdeps[s])
    # Keep only labels that are actually attached to the process.
    livemodules = [ a for a in allmodules if hasattr(process,a) ]
    if not livemodules: return None
    # Insertion ordering: place each module before the first already-placed
    # module that depends on it; otherwise append it at the end.
    # NOTE: the loop variable reuses the name of the 'module' argument.
    modulelist = [livemodules.pop()]
    for module in livemodules:
        for i,m in enumerate(modulelist):
            if module in flatdeps and m in flatdeps[module]:
                modulelist.insert(i, module)
                break
        if module not in modulelist:
            modulelist.append(module)

    # Check the ordering: no module may come before one it depends on.
    for i,m1 in enumerate(modulelist):
        for j,m2 in enumerate(modulelist):
            if j <= i: continue
            if m2 in flatdeps and m1 in flatdeps[m2]:
                raise RuntimeError("BAD ORDER %s BEFORE %s" % (m1,m2))
    modules = [ getattr(process,p) for p in modulelist ]

    task = cms.Task()
    for mod in modules:
        task.add(mod)
    return task,cms.Sequence(task)
0353
def addKeepStatement(process, oldKeep, newKeeps, verbose=False):
    """Add new keep statements to any PoolOutputModule of the process that has the old keep statements.

    For every output module of type PoolOutputModule that has outputCommands
    containing oldKeep, newKeeps is appended to its outputCommands. With
    verbose=True the added statements are printed per module.
    """
    for name, out in process.outputModules.items():
        if out.type_() != 'PoolOutputModule' or not hasattr(out, "outputCommands"):
            continue
        if oldKeep not in out.outputCommands:
            continue
        out.outputCommands += newKeeps
        if not verbose:
            continue
        print("Adding the following keep statements to output module %s: " % name)
        for keep in newKeeps:
            print("\t'%s'," % keep)
0363
0364
# Self-tests, run when this file is executed as a script.
# NOTE: unittest.main() exits the interpreter when it finishes, so anything
# defined after this block in the file is not executed in that mode; none of
# the tests below use it.
if __name__=="__main__":
    import unittest
    def _lineDiff(newString, oldString):
        # Return, joined by newlines, the non-empty lines of newString that are
        # not matched, in order, against the non-empty lines of oldString
        # (greedy in-order match). Used to strip the boilerplate of an empty
        # process from a full dumpPython().
        newString = ( x for x in newString.split('\n') if len(x) > 0)
        oldString = [ x for x in oldString.split('\n') if len(x) > 0]
        diff = []
        oldStringLine = 0
        for l in newString:
            if oldStringLine >= len(oldString):
                diff.append(l)
                continue
            if l == oldString[oldStringLine]:
                oldStringLine +=1
                continue
            diff.append(l)
        return "\n".join( diff )

    class TestModuleCommand(unittest.TestCase):
        def setUp(self):
            """Nothing to do """
            pass
        def testCloning(self):
            # Cloning a->b->c (with 'a' appearing twice) must produce postfixed
            # clones whose input tags point at the clones, fill
            # patAlgosToolsTask, and repeat aNew in the cloned sequence.
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("b", src=cms.InputTag("a"))
            p.c = cms.EDProducer("c", src=cms.InputTag("b","instance"))
            p.s = cms.Sequence(p.a*p.b*p.c *p.a)
            cloneProcessingSnippet(p, p.s, "New", addToTask = True)
            # NOTE(review): dumpPython() normally indents parameter lines; the
            # expected text below may have lost that leading whitespace in this
            # copy of the file — confirm against the canonical version.
            self.assertEqual(_lineDiff(p.dumpPython(), cms.Process("test").dumpPython()),
"""process.a = cms.EDProducer("a",
src = cms.InputTag("gen")
)
process.aNew = cms.EDProducer("a",
src = cms.InputTag("gen")
)
process.b = cms.EDProducer("b",
src = cms.InputTag("a")
)
process.bNew = cms.EDProducer("b",
src = cms.InputTag("aNew")
)
process.c = cms.EDProducer("c",
src = cms.InputTag("b","instance")
)
process.cNew = cms.EDProducer("c",
src = cms.InputTag("bNew","instance")
)
process.patAlgosToolsTask = cms.Task(process.aNew, process.bNew, process.cNew)
process.s = cms.Sequence(process.a+process.b+process.c+process.a)
process.sNew = cms.Sequence(process.aNew+process.bNew+process.cNew+process.aNew)""")
        def testContains(self):
            # contains() matches on the string form of the sequence (labels).
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s1 = cms.Sequence(p.a*p.b*p.c)
            p.s2 = cms.Sequence(p.b*p.c)
            self.assertTrue( contains(p.s1, "a") )
            self.assertTrue( not contains(p.s2, "a") )
        def testJetCollectionString(self):
            self.assertEqual(jetCollectionString(algo = 'Foo', type = 'Bar'), 'patJetsFooBar')
            self.assertEqual(jetCollectionString(prefix = 'prefix', algo = 'Foo', type = 'Bar'), 'prefixPatJetsFooBar')
        def testListModules(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s = cms.Sequence(p.a*p.b*p.c)
            self.assertEqual([p.a,p.b,p.c], listModules(p.s))

    unittest.main()
0436
class CloneTaskVisitor(object):
    """Visitor that travels within a cms.Task, and returns a cloned version of the Task.

    Every module reached is cloned under a postfixed label (unless listed in
    noClones, in which case the original module is reused) and added to a
    single new cms.Task.

    process       : cms.Process the clones and the new task are attached to
    label         : label of the task being cloned; the new task is attached
                    under _newLabel(label)
    postfix       : string appended to every new label
    removePostfix : optional suffix stripped from each label before postfix is
                    appended; an Exception is raised if a label lacks it
    noClones      : labels of modules to reuse instead of clone (default: none)
    verbose       : forwarded to massSearchReplaceAnyInputTag
    """
    def __init__(self, process, label, postfix, removePostfix="", noClones = None, verbose = False):
        self._process = process
        self._postfix = postfix
        self._removePostfix = removePostfix
        # None replaces the former shared mutable [] default.
        self._noClones = [] if noClones is None else noClones
        self._verbose = verbose
        # Original labels of the modules cloned so far.
        self._moduleLabels = []
        self._clonedTask = cms.Task()
        setattr(process, self._newLabel(label), self._clonedTask)

    def enter(self, visitee):
        if isinstance(visitee, cms._Module):
            label = visitee.label()
            newModule = None
            if label in self._noClones:
                newModule = getattr(self._process, label)
            elif label in self._moduleLabels:
                # Already cloned earlier in this visit: reuse that clone.
                newModule = getattr(self._process, self._newLabel(label))
            else:
                self._moduleLabels.append(label)
                newModule = visitee.clone()
                setattr(self._process, self._newLabel(label), newModule)
            self.__appendToTopTask(newModule)

    def leave(self, visitee):
        pass

    def clonedTask(self):
        """Rewrite input tags in the clone to point at the cloned labels, then return it.

        The record of cloned labels is cleared afterwards.
        """
        for label in self._moduleLabels:
            massSearchReplaceAnyInputTag(self._clonedTask, label, self._newLabel(label), moduleLabelOnly=True, verbose=self._verbose)
        self._moduleLabels = []
        return self._clonedTask

    def _newLabel(self, label):
        """Return label with removePostfix stripped (if configured) and postfix appended."""
        if self._removePostfix != "":
            if label.endswith(self._removePostfix):
                label = label[0:-len(self._removePostfix)]
            else:
                raise Exception("Tried to remove postfix %s from label %s, but it wasn't there" % (self._removePostfix, label))
        return label + self._postfix

    def __appendToTopTask(self, visitee):
        self._clonedTask.add(visitee)
0483
def cloneProcessingSnippetTask(process, task, postfix, removePostfix="", noClones = None, verbose = False):
    """
    ------------------------------------------------------------------
    copy a task plus the modules therein (including modules in subtasks)
    both are renamed by getting a postfix
    input tags are automatically adjusted
    ------------------------------------------------------------------
    process       : cms.Process the copies are attached to
    task          : the task to copy
    postfix       : string appended to every new label; if empty, task
                    itself is returned and nothing is cloned
    removePostfix : optional suffix stripped from labels before postfix is added
    noClones      : labels of modules reused instead of cloned (default: none)
    verbose       : forwarded to the input-tag replacement
    returns the new task (or task itself for an empty postfix)
    """
    if postfix == "":
        return task
    # None replaces the former shared mutable [] default.
    if noClones is None:
        noClones = []
    visitor = CloneTaskVisitor(process, task.label(), postfix, removePostfix, noClones, verbose)
    task.visit(visitor)
    return visitor.clonedTask()
0497 return result