File indexing completed on 2023-03-17 11:16:24
0001 from __future__ import print_function
0002 import FWCore.ParameterSet.Config as cms
0003 import sys
0004
0005
0006
0007
0008
0009
0010 from FWCore.ParameterSet.MassReplace import massSearchReplaceAnyInputTag, MassSearchReplaceAnyInputTagVisitor
0011 from FWCore.ParameterSet.MassReplace import massSearchReplaceParam, MassSearchParamVisitor, MassSearchReplaceParamVisitor
0012
def getPatAlgosToolsTask(process):
    """Return the Task attached to process as "patAlgosToolsTask".

    An empty cms.Task is attached first if the process has no such attribute.
    Raises Exception if the attribute exists but is not a cms.Task.
    """
    taskName = "patAlgosToolsTask"
    if not hasattr(process, taskName):
        setattr(process, taskName, cms.Task())
        # read back through the process rather than returning the local object
        return getattr(process, taskName)
    existing = getattr(process, taskName)
    if not isinstance(existing, cms.Task):
        raise Exception("patAlgosToolsTask does not have type Task")
    return existing
0023
def associatePatAlgosToolsTask(process):
    """Associate the "patAlgosToolsTask" (created on demand) with process.schedule."""
    patTask = getPatAlgosToolsTask(process)
    schedule = process.schedule
    schedule.associate(patTask)
0027
def addToProcessAndTask(label, module, process, task):
    """Attach module to process under label, then add the object now held by
    the process under that label to task."""
    setattr(process, label, module)
    attached = getattr(process, label)
    task.add(attached)
0031
def addTaskToProcess(process, label, task):
    """Attach task to process as label; if that label already exists, add task
    into the existing object instead."""
    if hasattr(process, label):
        getattr(process, label).add(task)
        return
    setattr(process, label, task)
0037
def addESProducers(process,config):
    """Import a configuration module and copy its ESProducers into process.

    config is a module path separated by '.' or '/'. Every attribute of that
    module that is labelable, not a sequence type, not a PSet, not private
    ('_' prefix), not named "source"/"looper"/"subProcess", and whose type_()
    contains 'ESProducer' is attached to process under its attribute name.
    """
    config = config.replace("/",".")
    # imported only so that sys.modules[config] is populated
    __import__(config)
    fragment = sys.modules[config]
    excluded = ("source", "looper", "subProcess")
    for name in dir(fragment):
        item = getattr(fragment, name)
        if (isinstance(item, cms._Labelable)
                and not isinstance(item, cms._ModuleSequenceType)
                and not name.startswith('_')
                and name not in excluded
                and not isinstance(item, cms.PSet)
                and 'ESProducer' in item.type_()):
            setattr(process, name, item)
0048
def loadWithPrefix(process,moduleName,prefix='',loadedProducersAndFilters=None):
    """Same as loadWithPrePostfix with an empty postfix."""
    loadWithPrePostfix(process, moduleName, prefix=prefix, postfix='',
                       loadedProducersAndFilters=loadedProducersAndFilters)
0051
def loadWithPostfix(process,moduleName,postfix='',loadedProducersAndFilters=None):
    """Same as loadWithPrePostfix with an empty prefix."""
    loadWithPrePostfix(process, moduleName, prefix='', postfix=postfix,
                       loadedProducersAndFilters=loadedProducersAndFilters)
0054
def loadWithPrePostfix(process,moduleName,prefix='',postfix='',loadedProducersAndFilters=None):
    """Import moduleName ('/' or '.' separated) and hand the imported module to
    extendWithPrePostfix together with prefix, postfix and loadedProducersAndFilters."""
    dottedName = moduleName.replace("/", ".")
    __import__(dottedName)
    extendWithPrePostfix(process, sys.modules[dottedName], prefix, postfix, loadedProducersAndFilters)
0060
def addToTask(loadedProducersAndFilters, module):
    """Add module to loadedProducersAndFilters when that container is truthy and
    module is an EDProducer or EDFilter; otherwise do nothing."""
    if not loadedProducersAndFilters:
        return
    if isinstance(module, (cms.EDProducer, cms.EDFilter)):
        loadedProducersAndFilters.add(module)
0065
def extendWithPrePostfix(process,other,prefix,postfix,loadedProducersAndFilters=None):
    """Look in other and find types which we can use.

    Attaches every public labelable object found via dir(other) (e.g. an
    imported configuration module) to process. Skipped: names starting with
    '_', "source", "looper", "subProcess", sequence types, Tasks, Schedules,
    PSets and VPSets.

    With an empty prefix and postfix, objects are attached under their own
    names. Otherwise each object is cloned and the clone is attached as
    prefix+name+postfix. ESProducer clones keep the original name. Objects whose
    name contains 'TauDiscrimination' are additionally attached unrenamed. Input
    tags of the renamed sequenceable clones are then rewritten (module label only)
    to point at the renamed labels.

    loadedProducersAndFilters: None, or the *name* of a cms.Task attached to
    process. Objects attached to process are passed to addToTask with that
    Task, which keeps only EDProducers and EDFilters.

    Raises Exception if loadedProducersAndFilters names an attribute that is
    not a cms.Task (getattr raises AttributeError if it does not exist).
    """

    if loadedProducersAndFilters:
        task = getattr(process, loadedProducersAndFilters)
        if not isinstance(task, cms.Task):
            raise Exception("extendWithPrePostfix argument must be name of Task type object attached to the process or None")
    else:
        task = None

    # Scratch sequence collecting the renamed clones so their input tags can be
    # retargeted in one pass at the end; _moduleLabels holds the labels of the
    # corresponding ORIGINAL items.
    sequence = cms.Sequence()
    sequence._moduleLabels = []
    for name in dir(other):
        # private names are ignored, like 'from X import *' does
        if name.startswith('_'):
            continue
        item = getattr(other,name)
        if name == "source" or name == "looper" or name == "subProcess":
            continue
        elif isinstance(item,cms._ModuleSequenceType):
            continue
        elif isinstance(item,cms.Task):
            continue
        elif isinstance(item,cms.Schedule):
            continue
        elif isinstance(item,cms.VPSet) or isinstance(item,cms.PSet):
            continue
        elif isinstance(item,cms._Labelable):
            # unlabelled objects take their attribute name as label
            if not item.hasLabel_():
                item.setLabel(name)
            if prefix != '' or postfix != '':
                newModule = item.clone()
                if isinstance(item,cms.ESProducer):
                    # ESProducers are not renamed; the clone replaces the name
                    newName =name
                else:
                    if 'TauDiscrimination' in name:
                        # special case: also keep the unrenamed original attached
                        # (reason not visible in this file)
                        process.__setattr__(name,item)
                        addToTask(task, item)
                    newName = prefix+name+postfix
                process.__setattr__(newName,newModule)
                addToTask(task, newModule)
                # only actually renamed, sequenceable clones take part in the
                # input-tag rewrite below
                if isinstance(newModule, cms._Sequenceable) and not newName == name:
                    sequence +=getattr(process,newName)
                    sequence._moduleLabels.append(item.label())
            else:
                process.__setattr__(name,item)
                addToTask(task, item)

    # retarget references between the renamed clones: label -> prefix+label+postfix
    if prefix != '' or postfix != '':
        for label in sequence._moduleLabels:
            massSearchReplaceAnyInputTag(sequence, label, prefix+label+postfix,verbose=False,moduleLabelOnly=True)
0119
def applyPostfix(process, label, postfix):
    """Return the object attached to process under label+postfix.

    Raises ValueError if the process has no such attribute.
    """
    fullLabel = label + postfix
    if not hasattr(process, fullLabel):
        raise ValueError("Error in <applyPostfix>: No module of name = %s attached to process !!" % (fullLabel,))
    return getattr(process, fullLabel)
0127
def removeIfInSequence(process, target, sequenceLabel, postfix=""):
    """Remove process.<target+postfix> from process.<sequenceLabel+postfix> if a
    module or sub-sequence with that label is in the sequence; otherwise do nothing."""
    fullTarget = target + postfix
    if fullTarget not in __labelsInSequence(process, sequenceLabel, postfix, True):
        return
    owningSequence = getattr(process, sequenceLabel + postfix)
    owningSequence.remove(getattr(process, fullTarget))
0134
def __labelsInSequence(process, sequenceLabel, postfix="", keepPostFix=False):
    """Return the labels of all modules, followed by those of all sub-sequences,
    found in process.<sequenceLabel+postfix>.

    If postfix is non-empty and keepPostFix is False, the last len(postfix)
    characters are cut from every label. No check is made that a label actually
    ends with postfix.
    """
    # [:None] keeps the whole label. An empty postfix must keep it too:
    # -len("") == 0 and label[:0] would be ''.
    end = None if (keepPostFix or not postfix) else -len(postfix)
    sequence = getattr(process, sequenceLabel + postfix)
    result = [m.label()[:end] for m in listModules(sequence)]
    result.extend(m.label()[:end] for m in listSequences(sequence))
    return result
0146
0147
class GatherAllModulesVisitor(object):
    """Visitor for a cms.Sequence's visit() that collects, in visiting order, every
    visited object that is an instance of gatheredInstance (cms._Module by default)."""
    def __init__(self, gatheredInstance=cms._Module):
        self._gatheredInstance = gatheredInstance
        self._modules = []

    def enter(self, visitee):
        """Record visitee if it is an instance of the gathered type."""
        if not isinstance(visitee, self._gatheredInstance):
            return
        self._modules.append(visitee)

    def leave(self, visitee):
        """Nothing to do when leaving a node."""
        pass

    def modules(self):
        """Return the list of collected objects (the internal list itself)."""
        return self._modules
0160
class CloneSequenceVisitor(object):
    """Visitor that travels within a cms.Sequence and builds a cloned version of it.

    Every module met during the visit is cloned and attached to the process
    under its label with ``postfix`` appended, after stripping ``removePostfix``,
    which must then be present. Modules whose label is in ``noClones`` are
    reused as-is. Only modules are handled: the cloned sequence is a flat chain
    of the (cloned) modules in visiting order, and nested sequences are not
    recreated. The new sequence is attached to the process in __init__ under
    the renamed label. Call clonedSequence() after the visit to retarget input
    tags to the clones.
    """
    def __init__(self, process, label, postfix, removePostfix="", noClones = None, addToTask = False, verbose = False):
        self._process = process
        self._postfix = postfix
        self._removePostfix = removePostfix
        # None instead of a shared mutable [] default; any container supporting `in` works
        self._noClones = [] if noClones is None else noClones
        self._addToTask = addToTask
        self._verbose = verbose
        # labels of cloned originals, in cloning order (drives the input-tag pass)...
        self._moduleLabels = []
        # ...and the same labels as a set for O(1) "already cloned?" checks
        self._clonedLabels = set()
        self._clonedSequence = cms.Sequence()
        setattr(process, self._newLabel(label), self._clonedSequence)
        if addToTask:
            self._patAlgosToolsTask = getPatAlgosToolsTask(process)

    def enter(self, visitee):
        """Clone (or reuse) a visited module and append it to the cloned sequence."""
        if not isinstance(visitee, cms._Module):
            return
        label = visitee.label()
        if label in self._noClones:
            # explicitly excluded from cloning: reuse the original module
            newModule = getattr(self._process, label)
        elif label in self._clonedLabels:
            # module appears more than once: reuse the clone made the first time
            newModule = getattr(self._process, self._newLabel(label))
        else:
            self._moduleLabels.append(label)
            self._clonedLabels.add(label)
            newModule = visitee.clone()
            newLabel = self._newLabel(label)
            setattr(self._process, newLabel, newModule)
            if self._addToTask:
                self._patAlgosToolsTask.add(getattr(self._process, newLabel))
        self.__appendToTopSequence(newModule)

    def leave(self, visitee):
        pass

    def clonedSequence(self):
        """Retarget input tags inside the cloned sequence from every cloned
        original label to its new label, then return the cloned sequence.

        Also forgets which labels have been cloned.
        """
        for label in self._moduleLabels:
            massSearchReplaceAnyInputTag(self._clonedSequence, label, self._newLabel(label), moduleLabelOnly=True, verbose=self._verbose)
        self._moduleLabels = []
        self._clonedLabels = set()
        return self._clonedSequence

    def _newLabel(self, label):
        """Return label with removePostfix stripped (when set) and postfix appended.

        Raises Exception if removePostfix is set but label does not end with it.
        """
        if self._removePostfix != "":
            if not label.endswith(self._removePostfix):
                raise Exception("Tried to remove postfix %s from label %s, but it wasn't there" % (self._removePostfix, label))
            label = label[:-len(self._removePostfix)]
        return label + self._postfix

    def __appendToTopSequence(self, visitee):
        self._clonedSequence += visitee
0212
def listModules(sequence):
    """Return every cms._Module reached by sequence.visit(), in visiting order."""
    gatherer = GatherAllModulesVisitor(gatheredInstance=cms._Module)
    sequence.visit(gatherer)
    found = gatherer.modules()
    return found
0217
def listSequences(sequence):
    """Return every cms.Sequence reached by sequence.visit(), in visiting order."""
    gatherer = GatherAllModulesVisitor(gatheredInstance=cms.Sequence)
    sequence.visit(gatherer)
    found = gatherer.modules()
    return found
0222
def jetCollectionString(prefix='', algo='', type=''):
    """
    ------------------------------------------------------------------
    return the label of a pat jet collection module built from the
    input values as <head> + 'Jets' + algo + type, where <head> is
    'pat' for an empty prefix and prefix + 'Pat' otherwise. With the
    default arguments the return value is 'patJets'; for example
    jetCollectionString(algo='AK5', type='Calo') == 'patJetsAK5Calo'
    and jetCollectionString('selected', 'AK5', 'Calo') ==
    'selectedPatJetsAK5Calo'.

    algo   : indicating the algorithm type of the jet [expected are
             'AK5', 'IC5', 'SC7', ...]
    type   : indicating the type of constituents of the jet [expec-
             ted are 'Calo', 'PFlow', 'JPT', ...]
    prefix : prefix indicating the type of pat collection module (ex-
             pected are '', 'selected', 'clean').
    ------------------------------------------------------------------
    """
    # `type` shadows the builtin, but it is part of the public keyword interface.
    head = 'pat' if prefix == '' else prefix + 'Pat'
    return head + 'Jets' + algo + type
0246
def contains(sequence, moduleName):
    """
    ------------------------------------------------------------------
    return True if moduleName occurs in the printed form of sequence
    and False otherwise. This is a plain substring test, so it also
    returns True when moduleName is only part of a contained label
    (e.g. 'a' matches a module labelled 'ab').

    sequence   : sequence [e.g. process.patDefaultSequence]
    moduleName : module name as a string
    ------------------------------------------------------------------
    """
    return moduleName in sequence.__str__()
0260
0261
0262
def cloneProcessingSnippet(process, sequence, postfix, removePostfix="", noClones = [], addToTask = False, verbose = False):
    """
    ------------------------------------------------------------------
    copy a sequence plus the modules therein; the copies are renamed
    by appending postfix (after removing removePostfix) and their
    input tags are adjusted to point at the copies. Labels listed in
    noClones are reused instead of copied. An empty postfix returns
    the given sequence unchanged.
    ------------------------------------------------------------------
    """
    if postfix == "":
        return sequence
    visitor = CloneSequenceVisitor(process, sequence.label(), postfix, removePostfix, noClones, addToTask, verbose)
    sequence.visit(visitor)
    return visitor.clonedSequence()
0277
def listDependencyChain(process, module, sources, verbose=False):
    """
    Walk up the dependencies of a module to find any that depend on any of the listed sources

    Starting from ``module``, its InputTag / VInputTag parameters are followed
    (recursing into PSets and VPSets, and into every dependency attached to
    ``process``) to build a dependency graph. Only tags whose process name is
    empty or equal to process.name_() are followed. Every module in that graph
    that depends, directly or indirectly, on a label in ``sources`` and is
    attached to ``process`` is collected. The modules are ordered so that each
    one precedes the modules that depend on it.

    With verbose=True each discovered dependency is printed along with the
    parameter path it was found through.

    Returns (cms.Task, cms.Sequence(task)) holding the ordered modules, or
    None if no attached module depends on the sources.
    Raises RuntimeError if the computed order fails validation.
    """
    def allDirectInputModules(moduleOrPSet, moduleName, attrName):
        """Labels referenced by the input tags of moduleOrPSet (recursively).

        moduleName is used in messages; attrName is the parameter path so far.
        """
        ret = set()
        for name, value in moduleOrPSet.parameters_().items():
            typeName = value.pythonTypeName()
            path = attrName + "." + name
            if typeName == 'cms.PSet':
                # extend the full path (using moduleName dropped nested PSet names)
                ret.update(allDirectInputModules(value, moduleName, path))
            elif typeName == 'cms.VPSet':
                for i, ps in enumerate(value):
                    ret.update(allDirectInputModules(ps, moduleName, "%s[%d]" % (path, i)))
            elif typeName == 'cms.VInputTag':
                inputs = [MassSearchReplaceAnyInputTagVisitor.standardizeInputTagFmt(it) for it in value]
                inputLabels = [tag.moduleLabel for tag in inputs if tag.processName == '' or tag.processName == process.name_()]
                ret.update(inputLabels)
                if verbose and inputLabels: print("%s depends on %s via %s" % (moduleName, inputLabels, path))
            elif typeName.endswith('.InputTag'):
                if value.processName == '' or value.processName == process.name_():
                    ret.add(value.moduleLabel)
                    if verbose: print("%s depends on %s via %s" % (moduleName, value.moduleLabel, path))
        ret.discard("")
        return ret

    def fillDirectDepGraphs(root, fwdepgraph, revdepgraph):
        """Recursively fill the forward (label -> its inputs) and reverse
        (label -> labels reading it) graphs starting at root."""
        rootLabel = root.label_()
        if rootLabel in fwdepgraph: return
        deps = allDirectInputModules(root, rootLabel, rootLabel)
        fwdepgraph[rootLabel] = []
        for d in deps:
            fwdepgraph[rootLabel].append(d)
            revdepgraph.setdefault(d, []).append(rootLabel)
            depmodule = getattr(process, d, None)
            if depmodule:
                fillDirectDepGraphs(depmodule, fwdepgraph, revdepgraph)
        return (fwdepgraph, revdepgraph)

    _, revdepgraph = fillDirectDepGraphs(module, {}, {})

    def flattenRevDeps(flatgraph, revdepgraph, tip):
        """Make a graph that for each module lists all the ones that depend on it, directly or indirectly"""
        # already done, or nobody depends on tip
        if tip in flatgraph: return
        if tip not in revdepgraph: return
        mydeps = set()
        # depth-first: flatten each direct dependent, then merge its closure
        for d in revdepgraph[tip]:
            flattenRevDeps(flatgraph, revdepgraph, d)
            mydeps.add(d)
            if d in flatgraph:
                mydeps.update(flatgraph[d])
        flatgraph[tip] = mydeps

    flatdeps = {}
    allmodules = set()
    for s in sources:
        flattenRevDeps(flatdeps, revdepgraph, s)
        if s in flatdeps: allmodules.update(flatdeps[s])
    # sorted: iteration order of a set of strings changes between interpreter
    # runs (hash randomization); sorting makes the resulting order reproducible
    livemodules = sorted(a for a in allmodules if hasattr(process, a))
    if not livemodules: return None
    modulelist = [livemodules.pop()]
    for label in livemodules:
        for i, m in enumerate(modulelist):
            if label in flatdeps and m in flatdeps[label]:
                # m depends on label, so label must come first
                modulelist.insert(i, label)
                break
        else:
            modulelist.append(label)
    # Validate: no module may precede one it depends on
    for i, m1 in enumerate(modulelist):
        for m2 in modulelist[i + 1:]:
            if m2 in flatdeps and m1 in flatdeps[m2]:
                raise RuntimeError("BAD ORDER %s BEFORE %s" % (m1, m2))
    modules = [getattr(process, p) for p in modulelist]

    task = cms.Task()
    for mod in modules:
        task.add(mod)
    return task, cms.Sequence(task)
0359
def addKeepStatement(process, oldKeep, newKeeps, verbose=False):
    """Add new keep statements to any PoolOutputModule of the process that has the old keep statements

    For every output module of type 'PoolOutputModule' that has an
    outputCommands parameter containing oldKeep, newKeeps is appended to
    outputCommands. With verbose=True the added statements are printed.
    """
    for name, out in process.outputModules.items():
        isPool = out.type_() == 'PoolOutputModule' and hasattr(out, "outputCommands")
        if not isPool or oldKeep not in out.outputCommands:
            continue
        out.outputCommands += newKeeps
        if not verbose:
            continue
        print("Adding the following keep statements to output module %s: " % name)
        for keep in newKeeps:
            print("\t'%s'," % keep)
0369
0370
if __name__=="__main__":
    # Self-tests, run when this file is executed directly.
    # NOTE(review): unittest.main() calls sys.exit() by default, so anything defined
    # after this block in the file (CloneTaskVisitor, cloneProcessingSnippetTask) is
    # never defined in that mode and is not covered by these tests.
    import unittest
    def _lineDiff(newString, oldString):
        """Return, joined by newlines, the non-empty lines of newString that are not
        matched by an in-order walk over the non-empty lines of oldString.

        Each newString line equal to the next unconsumed oldString line is dropped
        and advances the oldString cursor; every other line is kept.
        """
        newString = ( x for x in newString.split('\n') if len(x) > 0)
        oldString = [ x for x in oldString.split('\n') if len(x) > 0]
        diff = []
        oldStringLine = 0
        for l in newString:
            if oldStringLine >= len(oldString):
                diff.append(l)
                continue
            if l == oldString[oldStringLine]:
                oldStringLine +=1
                continue
            diff.append(l)
        return "\n".join( diff )

    class TestModuleCommand(unittest.TestCase):
        def setUp(self):
            """Nothing to do """
            pass
        def testCloning(self):
            # Cloning s with postfix "New" must add aNew/bNew/cNew (inputs retargeted
            # to the clones, instance labels kept), a patAlgosToolsTask holding them,
            # and sNew; the duplicated p.a maps to the single clone aNew.
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("b", src=cms.InputTag("a"))
            p.c = cms.EDProducer("c", src=cms.InputTag("b","instance"))
            p.s = cms.Sequence(p.a*p.b*p.c *p.a)
            cloneProcessingSnippet(p, p.s, "New", addToTask = True)
            # NOTE(review): the continuation lines of the expected string below have no
            # leading whitespace here; confirm this matches the indentation that
            # dumpPython() actually produces.
            self.assertEqual(_lineDiff(p.dumpPython(), cms.Process("test").dumpPython()),
            """process.a = cms.EDProducer("a",
src = cms.InputTag("gen")
)
process.aNew = cms.EDProducer("a",
src = cms.InputTag("gen")
)
process.b = cms.EDProducer("b",
src = cms.InputTag("a")
)
process.bNew = cms.EDProducer("b",
src = cms.InputTag("aNew")
)
process.c = cms.EDProducer("c",
src = cms.InputTag("b","instance")
)
process.cNew = cms.EDProducer("c",
src = cms.InputTag("bNew","instance")
)
process.patAlgosToolsTask = cms.Task(process.aNew, process.bNew, process.cNew)
process.s = cms.Sequence(process.a+process.b+process.c+process.a)
process.sNew = cms.Sequence(process.aNew+process.bNew+process.cNew+process.aNew)""")
        def testContains(self):
            # contains() matches labels in the printed sequence, not module types
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s1 = cms.Sequence(p.a*p.b*p.c)
            p.s2 = cms.Sequence(p.b*p.c)
            self.assertTrue( contains(p.s1, "a") )
            self.assertTrue( not contains(p.s2, "a") )
        def testJetCollectionString(self):
            self.assertEqual(jetCollectionString(algo = 'Foo', type = 'Bar'), 'patJetsFooBar')
            self.assertEqual(jetCollectionString(prefix = 'prefix', algo = 'Foo', type = 'Bar'), 'prefixPatJetsFooBar')
        def testListModules(self):
            # modules come back in sequence order
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s = cms.Sequence(p.a*p.b*p.c)
            self.assertEqual([p.a,p.b,p.c], listModules(p.s))

    unittest.main()
0442
class CloneTaskVisitor(object):
    """Visitor that travels within a cms.Task and builds a cloned version of it.

    Every module met during the visit is cloned and attached to the process
    under its label with ``postfix`` appended, after stripping ``removePostfix``,
    which must then be present. Modules whose label is in ``noClones`` are
    reused as-is. All (cloned) modules are added directly to one new Task. That
    Task is attached to the process in __init__ under the renamed label. Call
    clonedTask() after the visit to retarget input tags to the clones.
    """
    def __init__(self, process, label, postfix, removePostfix="", noClones = None, verbose = False):
        self._process = process
        self._postfix = postfix
        self._removePostfix = removePostfix
        # None instead of a shared mutable [] default; any container supporting `in` works
        self._noClones = [] if noClones is None else noClones
        self._verbose = verbose
        # labels of cloned originals, in cloning order (drives the input-tag pass)...
        self._moduleLabels = []
        # ...and the same labels as a set for O(1) "already cloned?" checks
        self._clonedLabels = set()
        self._clonedTask = cms.Task()
        setattr(process, self._newLabel(label), self._clonedTask)

    def enter(self, visitee):
        """Clone (or reuse) a visited module and add it to the cloned task."""
        if not isinstance(visitee, cms._Module):
            return
        label = visitee.label()
        if label in self._noClones:
            # explicitly excluded from cloning: reuse the original module
            newModule = getattr(self._process, label)
        elif label in self._clonedLabels:
            # already cloned earlier in the visit: reuse that clone
            newModule = getattr(self._process, self._newLabel(label))
        else:
            self._moduleLabels.append(label)
            self._clonedLabels.add(label)
            newModule = visitee.clone()
            setattr(self._process, self._newLabel(label), newModule)
        self.__appendToTopTask(newModule)

    def leave(self, visitee):
        pass

    def clonedTask(self):
        """Retarget input tags inside the cloned task from every cloned original
        label to its new label, then return the cloned task.

        Also forgets which labels have been cloned.
        """
        for label in self._moduleLabels:
            massSearchReplaceAnyInputTag(self._clonedTask, label, self._newLabel(label), moduleLabelOnly=True, verbose=self._verbose)
        self._moduleLabels = []
        self._clonedLabels = set()
        return self._clonedTask

    def _newLabel(self, label):
        """Return label with removePostfix stripped (when set) and postfix appended.

        Raises Exception if removePostfix is set but label does not end with it.
        """
        if self._removePostfix != "":
            if not label.endswith(self._removePostfix):
                raise Exception("Tried to remove postfix %s from label %s, but it wasn't there" % (self._removePostfix, label))
            label = label[:-len(self._removePostfix)]
        return label + self._postfix

    def __appendToTopTask(self, visitee):
        self._clonedTask.add(visitee)
0489
def cloneProcessingSnippetTask(process, task, postfix, removePostfix="", noClones = [], verbose = False):
    """
    ------------------------------------------------------------------
    copy a task plus the modules reached when visiting it (including
    modules in subtasks); the copies are renamed by appending postfix
    (after removing removePostfix) and their input tags are adjusted
    to point at the copies. Labels listed in noClones are reused
    instead of copied. An empty postfix returns the given task as is.
    ------------------------------------------------------------------
    """
    if postfix == "":
        return task
    visitor = CloneTaskVisitor(process, task.label(), postfix, removePostfix, noClones, verbose)
    task.visit(visitor)
    return visitor.clonedTask()