File indexing completed on 2024-11-27 03:17:59
0001 import FWCore.ParameterSet.Config as cms
0002 import sys
0003
0004
0005
0006
0007
0008
0009 from FWCore.ParameterSet.MassReplace import massSearchReplaceAnyInputTag, MassSearchReplaceAnyInputTagVisitor
0010 from FWCore.ParameterSet.MassReplace import massSearchReplaceParam, MassSearchParamVisitor, MassSearchReplaceParamVisitor
0011
def getPatAlgosToolsTask(process):
    """Return the process-level ``patAlgosToolsTask``, creating it on first use.

    If *process* already has an attribute of that name it must be a
    ``cms.Task``, otherwise an ``Exception`` is raised. If it is missing, an
    empty ``cms.Task`` is attached under that name and the attached object
    (as read back from the process) is returned.
    """
    taskName = "patAlgosToolsTask"
    if not hasattr(process, taskName):
        setattr(process, taskName, cms.Task())
        return getattr(process, taskName)
    existing = getattr(process, taskName)
    if not isinstance(existing, cms.Task):
        raise Exception("patAlgosToolsTask does not have type Task")
    return existing
0022
def associatePatAlgosToolsTask(process):
    """Associate the shared ``patAlgosToolsTask`` with ``process.schedule``.

    The task is fetched (or created) first, then handed to the schedule's
    ``associate`` method. NOTE(review): presumably fails if the process has
    no schedule set yet — confirm.
    """
    toolsTask = getPatAlgosToolsTask(process)
    schedule = process.schedule
    schedule.associate(toolsTask)
0026
def addToProcessAndTask(label, module, process, task):
    """Attach *module* to *process* as *label* and register it in *task*.

    The object added to *task* is the one read back from the process after
    assignment, so any wrapping done by the process's attribute setter is
    respected.
    """
    setattr(process, label, module)
    attached = getattr(process, label)
    task.add(attached)
0030
def addTaskToProcess(process, label, task):
    """Attach *task* to *process* under *label*, or merge it into an existing one.

    When *label* already exists on *process*, ``add(task)`` is called on the
    existing attribute instead of replacing it.
    """
    if hasattr(process, label):
        getattr(process, label).add(task)
        return
    setattr(process, label, task)
0036
def addESProducers(process,config):
    """Import a configuration module and attach its ES producers to *process*.

    Args:
        process: the cms.Process to extend.
        config: module path of the configuration to load; ``/`` separators
            are accepted and converted to ``.`` before importing.

    Every public top-level attribute of the imported module that is a
    ``cms._Labelable`` and whose ``type_()`` string contains ``'ESProducer'``
    is attached to *process* under its attribute name. The following are
    skipped:

    * private names and ``source``/``looper``/``subProcess``;
    * sequences/paths, Tasks, Schedules, PSets and VPSets (the same
      containers ``extendWithPrePostfix`` skips);
    * any remaining object that has no callable ``type_``.
    """
    config = config.replace("/",".")
    # __import__ returns the top-level package, so the module itself is
    # looked up in sys.modules afterwards.
    __import__(config)
    configModule = sys.modules[config]
    for name in dir(configModule):
        if name.startswith('_') or name in ("source", "looper", "subProcess"):
            continue
        item = getattr(configModule, name)
        if not isinstance(item, cms._Labelable):
            continue
        # Containers and parameter sets are never ES producers; skipping them
        # here keeps them away from the type_() call below.
        if isinstance(item, (cms._ModuleSequenceType, cms.Task, cms.Schedule, cms.PSet, cms.VPSet)):
            continue
        # Guard against any other labelable object that lacks type_().
        if not callable(getattr(item, "type_", None)):
            continue
        if 'ESProducer' in item.type_():
            setattr(process, name, item)
0047
def loadWithPrefix(process,moduleName,prefix='',loadedProducersAndFilters=None):
    """Load *moduleName* into *process*, prepending *prefix* to copied labels.

    Shorthand for ``loadWithPrePostfix`` with an empty postfix.
    """
    loadWithPrePostfix(process, moduleName, prefix=prefix, postfix='',
                       loadedProducersAndFilters=loadedProducersAndFilters)
0050
def loadWithPostfix(process,moduleName,postfix='',loadedProducersAndFilters=None):
    """Load *moduleName* into *process*, appending *postfix* to copied labels.

    Shorthand for ``loadWithPrePostfix`` with an empty prefix.
    """
    loadWithPrePostfix(process, moduleName, prefix='', postfix=postfix,
                       loadedProducersAndFilters=loadedProducersAndFilters)
0053
def loadWithPrePostfix(process,moduleName,prefix='',postfix='',loadedProducersAndFilters=None):
    """Import *moduleName* and add its contents to *process* with renamed labels.

    Slashes in *moduleName* are converted to dots before importing. The
    imported module is then passed to ``extendWithPrePostfix`` together with
    *prefix*, *postfix* and the optional task name
    *loadedProducersAndFilters*.
    """
    dottedName = moduleName.replace("/", ".")
    # Only the side effect matters: the module ends up in sys.modules.
    __import__(dottedName)
    extendWithPrePostfix(process, sys.modules[dottedName], prefix, postfix, loadedProducersAndFilters)
0059
def addToTask(loadedProducersAndFilters, module):
    """Add *module* to the task *loadedProducersAndFilters* if it is a producer or filter.

    Nothing happens when the task argument is falsy (e.g. ``None``), or when
    *module* is neither a ``cms.EDProducer`` nor a ``cms.EDFilter``.
    """
    if not loadedProducersAndFilters:
        return
    if isinstance(module, (cms.EDProducer, cms.EDFilter)):
        loadedProducersAndFilters.add(module)
0064
def extendWithPrePostfix(process,other,prefix,postfix,loadedProducersAndFilters=None):
    """Look in other and find types which we can use, and copy them onto process.

    Every public attribute of *other* (typically an imported configuration
    module) is inspected. Skipped are:

    * names starting with '_';
    * 'source', 'looper' and 'subProcess';
    * sequences/paths (cms._ModuleSequenceType), Tasks and Schedules;
    * PSets and VPSets.

    Each remaining cms._Labelable item is given its attribute name as label
    if it has none yet, then:

    * with an empty prefix and postfix, it is attached to *process* under
      its own name;
    * otherwise, a clone is attached as prefix+name+postfix (ESProducers
      keep their original name). Input tags inside the renamed clones that
      are sequenceable are then rewritten so that references to an original
      label point to its renamed counterpart.

    Args:
        process: the cms.Process being extended.
        other: object whose attributes are scanned.
        prefix: string prepended to each (non-ESProducer) label.
        postfix: string appended to each (non-ESProducer) label.
        loadedProducersAndFilters: optional *name* of a cms.Task attribute
            of *process*. Attached EDProducers/EDFilters are also added to
            that task (see addToTask).

    Raises:
        Exception: if loadedProducersAndFilters names an attribute that is
            not a cms.Task.
    """

    if loadedProducersAndFilters:
        task = getattr(process, loadedProducersAndFilters)
        if not isinstance(task, cms.Task):
            raise Exception("extendWithPrePostfix argument must be name of Task type object attached to the process or None")
    else:
        task = None

    # Temporary sequence holding the renamed sequenceable clones, so that
    # their input tags can be rewritten in one pass at the end.
    # _moduleLabels is an ad-hoc attribute recording their ORIGINAL labels.
    sequence = cms.Sequence()
    sequence._moduleLabels = []
    for name in dir(other):
        # Like 'from X import *', ignore private names.
        if name.startswith('_'):
            continue
        item = getattr(other,name)
        if name == "source" or name == "looper" or name == "subProcess":
            continue
        elif isinstance(item,cms._ModuleSequenceType):
            continue
        elif isinstance(item,cms.Task):
            continue
        elif isinstance(item,cms.Schedule):
            continue
        elif isinstance(item,cms.VPSet) or isinstance(item,cms.PSet):
            continue
        elif isinstance(item,cms._Labelable):
            if not item.hasLabel_():
                item.setLabel(name)
            if prefix != '' or postfix != '':
                newModule = item.clone()
                if isinstance(item,cms.ESProducer):
                    # ESProducers are not renamed.
                    newName =name
                else:
                    # Special case: for names containing 'TauDiscrimination'
                    # the un-renamed original is attached as well.
                    # NOTE(review): the reason is not visible here.
                    if 'TauDiscrimination' in name:
                        process.__setattr__(name,item)
                        addToTask(task, item)
                    newName = prefix+name+postfix
                process.__setattr__(newName,newModule)
                addToTask(task, newModule)
                # Only renamed sequenceable clones need input-tag rewriting.
                if isinstance(newModule, cms._Sequenceable) and not newName == name:
                    sequence +=getattr(process,newName)
                    sequence._moduleLabels.append(item.label())
            else:
                process.__setattr__(name,item)
                addToTask(task, item)

    # Redirect references to each original label to its renamed version,
    # but only within the renamed clones collected above.
    if prefix != '' or postfix != '':
        for label in sequence._moduleLabels:
            massSearchReplaceAnyInputTag(sequence, label, prefix+label+postfix,verbose=False,moduleLabelOnly=True)
0118
def applyPostfix(process, label, postfix):
    """Return the object attached to *process* as ``label + postfix``.

    Raises:
        ValueError: if *process* has no attribute of that name.
    """
    fullLabel = label + postfix
    if not hasattr(process, fullLabel):
        raise ValueError("Error in <applyPostfix>: No module of name = %s attached to process !!" % fullLabel)
    return getattr(process, fullLabel)
0126
def removeIfInSequence(process, target, sequenceLabel, postfix=""):
    """Remove module *target* from sequence *sequenceLabel* if it is present.

    Both names are resolved on *process* with *postfix* appended. Membership
    is tested against the full (postfixed) labels of the modules and
    sequences contained in the sequence; if absent, nothing happens.
    """
    fullTarget = target + postfix
    presentLabels = __labelsInSequence(process, sequenceLabel, postfix, True)
    if fullTarget not in presentLabels:
        return
    owner = getattr(process, sequenceLabel + postfix)
    owner.remove(getattr(process, fullTarget))
0133
def __labelsInSequence(process, sequenceLabel, postfix="", keepPostFix=False):
    """Return the labels of all modules and sequences inside a sequence.

    The sequence is looked up on *process* as ``sequenceLabel + postfix``.
    Module labels come first, followed by sequence labels, each in visit
    order.

    Unless *keepPostFix* is True, the last ``len(postfix)`` characters are
    sliced off every label. The label is NOT checked to actually end with
    *postfix*. With an empty *postfix*, labels are returned unchanged.
    """
    sequence = getattr(process, sequenceLabel + postfix)
    # A slice end of -0 would yield empty strings, so an empty postfix (like
    # keepPostFix) keeps the whole label. This way the sequence is traversed
    # only once.
    end = -len(postfix) if (postfix and not keepPostFix) else None
    labels = [m.label()[:end] for m in listModules(sequence)]
    labels.extend(s.label()[:end] for s in listSequences(sequence))
    return labels
0145
0146
class GatherAllModulesVisitor(object):
    """Sequence visitor that collects every visited object of a given type.

    Pass an instance to a sequence's ``visit`` method. Afterwards,
    ``modules()`` returns, in visit order, every visitee that is an instance
    of *gatheredInstance* (``cms._Module`` by default; despite the method
    name, any type can be gathered).
    """
    def __init__(self, gatheredInstance=cms._Module):
        self._gatheredInstance = gatheredInstance
        self._modules = []

    def enter(self, visitee):
        # Matches are recorded on entry only, so they keep visit order.
        if not isinstance(visitee, self._gatheredInstance):
            return
        self._modules.append(visitee)

    def leave(self, visitee):
        # Nothing to do when leaving a node.
        return None

    def modules(self):
        """Return the collected objects (the internal list, not a copy)."""
        return self._modules
0159
class CloneSequenceVisitor(object):
    """Visitor that travels within a cms.Sequence and builds a cloned copy of it.

    A new, empty ``cms.Sequence`` is attached to the process as
    ``_newLabel(label)`` when the visitor is created. Every ``cms._Module``
    reached during the visit is cloned and attached under its new label,
    then appended to that new sequence. Modules listed in *noClones* are
    reused unchanged instead of being cloned. Nested sequences are not
    cloned as such: their modules end up directly in the new top-level
    sequence. Call ``clonedSequence()`` after the visit to rewrite input tags
    between cloned modules and obtain the result.

    Args:
        process: the cms.Process the clones are attached to.
        label: label of the sequence being cloned.
        postfix: string appended to every new label.
        removePostfix: if non-empty, this suffix must be present on every
            cloned label and is stripped before *postfix* is appended.
        noClones: labels of modules to reuse instead of cloning
            (``None`` means none).
        addToTask: if True, every new clone is also added to the process's
            ``patAlgosToolsTask``.
        verbose: forwarded to ``massSearchReplaceAnyInputTag``.
    """
    def __init__(self, process, label, postfix, removePostfix="", noClones = None, addToTask = False, verbose = False):
        self._process = process
        self._postfix = postfix
        self._removePostfix = removePostfix
        # None default avoids sharing a mutable default list between calls.
        self._noClones = [] if noClones is None else noClones
        self._addToTask = addToTask
        self._verbose = verbose
        # Original labels of the modules cloned so far.
        self._moduleLabels = []
        self._clonedSequence = cms.Sequence()
        setattr(process, self._newLabel(label), self._clonedSequence)
        if addToTask:
            self._patAlgosToolsTask = getPatAlgosToolsTask(process)

    def enter(self, visitee):
        """Clone (or reuse) *visitee* if it is a module and append it to the new sequence."""
        if isinstance(visitee, cms._Module):
            label = visitee.label()
            newModule = None
            if label in self._noClones:
                # Keep the original module unchanged.
                newModule = getattr(self._process, label)
            elif label in self._moduleLabels:
                # Already cloned earlier in this visit: reuse that clone.
                newModule = getattr(self._process, self._newLabel(label))
            else:
                self._moduleLabels.append(label)
                newModule = visitee.clone()
                setattr(self._process, self._newLabel(label), newModule)
                if self._addToTask:
                    self._patAlgosToolsTask.add(getattr(self._process, self._newLabel(label)))
            self.__appendToTopSequence(newModule)

    def leave(self, visitee):
        pass

    def clonedSequence(self):
        """Point input tags inside the copy at the cloned modules and return it.

        The record of cloned labels is cleared afterwards.
        """
        for label in self._moduleLabels:
            massSearchReplaceAnyInputTag(self._clonedSequence, label, self._newLabel(label), moduleLabelOnly=True, verbose=self._verbose)
        self._moduleLabels = []
        return self._clonedSequence

    def _newLabel(self, label):
        """Return *label* with removePostfix stripped (if set) and postfix appended.

        Raises:
            Exception: if removePostfix is set but *label* does not end with it.
        """
        if self._removePostfix != "":
            if label.endswith(self._removePostfix):
                label = label[0:-len(self._removePostfix)]
            else:
                raise Exception("Tried to remove postfix %s from label %s, but it wasn't there" % (self._removePostfix, label))
        return label + self._postfix

    def __appendToTopSequence(self, visitee):
        self._clonedSequence += visitee
0211
def listModules(sequence):
    """Return, in visit order, every ``cms._Module`` reached by visiting *sequence*."""
    collector = GatherAllModulesVisitor(gatheredInstance=cms._Module)
    sequence.visit(collector)
    found = collector.modules()
    return found
0216
def listSequences(sequence):
    """Return, in visit order, every ``cms.Sequence`` reached by visiting *sequence*."""
    collector = GatherAllModulesVisitor(gatheredInstance=cms.Sequence)
    sequence.visit(collector)
    found = collector.modules()
    return found
0221
def jetCollectionString(prefix='', algo='', type=''):
    """
    ------------------------------------------------------------------
    return the string of the jet collection module depending on the
    input values. The name is built as

      'patJets' + algo + type                if prefix is empty
      prefix + 'PatJets' + algo + type       otherwise

    so the default return value is 'patJets', and e.g. algo='AK5',
    type='Calo' gives 'patJetsAK5Calo'.

    algo   : indicating the algorithm type of the jet [expected are
             'AK5', 'IC5', 'SC7', ...]
    type   : indicating the type of constituents of the jet [expec-
             ted are 'Calo', 'PFlow', 'JPT', ...]
    prefix : prefix indicating the type of pat collection module (ex-
             pected are '', 'selected', 'clean').
    ------------------------------------------------------------------
    """
    # 'type' shadows the builtin but is part of the keyword interface.
    head = 'pat' if prefix == '' else prefix + 'Pat'
    return head + 'Jets' + algo + type
0245
def contains(sequence, moduleName):
    """
    ------------------------------------------------------------------
    return True if the string 'moduleName' occurs anywhere in the
    string representation of 'sequence' and False otherwise. This
    version is not so nice as it also returns True for any substring
    of the name of a contained module (or of any other text in the
    sequence's string form).

    sequence   : sequence [e.g. process.patDefaultSequence]
    moduleName : module name as a string
    ------------------------------------------------------------------
    """
    return moduleName in sequence.__str__()
0259
0260
0261
def cloneProcessingSnippet(process, sequence, postfix, removePostfix="", noClones = None, addToTask = False, verbose = False):
    """
    ------------------------------------------------------------------
    copy a sequence plus the modules therein (modules of nested
    sequences are appended directly to the copy). The new sequence and
    every cloned module are renamed by appending a postfix, and input
    tags between cloned modules are automatically adjusted.

    process       : the cms.Process the clones are attached to
    sequence      : the sequence to copy
    postfix       : string appended to every new label; if empty,
                    nothing is cloned and 'sequence' itself is returned
    removePostfix : suffix stripped from each label before 'postfix'
                    is appended (an exception is raised if it is absent)
    noClones      : labels of modules to reuse instead of cloning
                    (None means none)
    addToTask     : also add the clones to process.patAlgosToolsTask
    verbose       : verbose input-tag replacement
    ------------------------------------------------------------------
    """
    if postfix == "":
        return sequence
    # None default avoids sharing a mutable default list between calls.
    if noClones is None:
        noClones = []
    visitor = CloneSequenceVisitor(process, sequence.label(), postfix, removePostfix, noClones, addToTask, verbose)
    sequence.visit(visitor)
    return visitor.clonedSequence()
0276
def listDependencyChain(process, module, sources, verbose=False):
    """
    Walk up the dependencies of a module to find any that depend on any of the listed sources

    Starting from *module*, input-tag parameters are followed recursively
    through the modules attached to *process* to build a dependency graph.
    This covers InputTag and VInputTag parameters, including those inside
    nested PSets and VPSets, whether tracked or untracked. Only tags whose
    process name is empty or equal to ``process.name_()`` are followed.

    Every module of that graph that depends, directly or indirectly, on one
    of the labels in *sources* and is attached to *process* is collected.
    The modules are ordered so that no module precedes a module it depends
    on.

    Args:
        process: the cms.Process holding the modules.
        module: the module whose upstream dependencies are scanned.
        sources: labels of the modules of interest.
        verbose: print every dependency edge as it is found.

    Returns:
        ``None`` if no such module exists. Otherwise a tuple
        ``(task, sequence)``: a ``cms.Task`` with the ordered modules and a
        ``cms.Sequence`` wrapping that task.

    Raises:
        RuntimeError: if the computed order violates a dependency.
    """
    def parameterKind(value):
        # pythonTypeName() looks like 'cms.InputTag'. Untracked parameters
        # presumably read 'cms.untracked.<Type>' (TODO confirm against
        # FWCore.ParameterSet.Mixins). Keep only the last component so that
        # tracked and untracked parameters are handled alike.
        return value.pythonTypeName().rsplit('.', 1)[-1]

    def allDirectInputModules(moduleOrPSet, moduleName, attrName):
        """Labels of same-process modules referenced by input tags of moduleOrPSet.

        attrName is the dotted parameter path used in verbose messages.
        """
        ret = set()
        for name, value in moduleOrPSet.parameters_().items():
            kind = parameterKind(value)
            path = attrName + "." + name
            if kind == 'PSet':
                ret.update(allDirectInputModules(value, moduleName, path))
            elif kind == 'VPSet':
                for (i, ps) in enumerate(value):
                    ret.update(allDirectInputModules(ps, moduleName, "%s[%d]" % (path, i)))
            elif kind == 'VInputTag':
                inputs = [ MassSearchReplaceAnyInputTagVisitor.standardizeInputTagFmt(it) for it in value ]
                inputLabels = [ tag.moduleLabel for tag in inputs if tag.processName == '' or tag.processName == process.name_() ]
                ret.update(inputLabels)
                if verbose and inputLabels: print("%s depends on %s via %s" % (moduleName, inputLabels, path))
            elif kind == 'InputTag':
                if value.processName == '' or value.processName == process.name_():
                    ret.add(value.moduleLabel)
                    if verbose: print("%s depends on %s via %s" % (moduleName, value.moduleLabel, path))
        # An empty module label is not a real dependency.
        ret.discard("")
        return ret

    def fillDirectDepGraphs(root, fwdepgraph, revdepgraph):
        """Record forward (module -> inputs) and reverse (input -> consumers) edges recursively."""
        if root.label_() in fwdepgraph: return
        deps = allDirectInputModules(root, root.label_(), root.label_())
        fwdepgraph[root.label_()] = []
        for d in deps:
            fwdepgraph[root.label_()].append(d)
            revdepgraph.setdefault(d, []).append(root.label_())
            depmodule = getattr(process, d, None)
            if depmodule:
                fillDirectDepGraphs(depmodule, fwdepgraph, revdepgraph)
        return (fwdepgraph, revdepgraph)

    fwdepgraph, revdepgraph = fillDirectDepGraphs(module, {}, {})

    def flattenRevDeps(flatgraph, revdepgraph, tip):
        """Make a graph that for each module lists all the ones that depend on it, directly or indirectly"""
        # Already done for this module.
        if tip in flatgraph: return
        # Nobody depends on this module: nothing to record.
        if tip not in revdepgraph: return
        mydeps = set()
        # Depth-first: process each direct dependent, then add it and its own
        # (already flattened) dependents.
        for d in revdepgraph[tip]:
            flattenRevDeps(flatgraph, revdepgraph, d)
            mydeps.add(d)
            if d in flatgraph:
                mydeps.update(flatgraph[d])
        flatgraph[tip] = mydeps

    flatdeps = {}
    allmodules = set()
    for s in sources:
        flattenRevDeps(flatdeps, revdepgraph, s)
        if s in flatdeps:
            allmodules.update(flatdeps[s])
    livemodules = [ a for a in allmodules if hasattr(process, a) ]
    if not livemodules:
        return None

    # Insertion ordering: each candidate goes right before the first placed
    # module that depends on it, or at the end if none does.
    modulelist = [livemodules.pop()]
    for candidate in livemodules:
        dependents = flatdeps.get(candidate, ())
        for i, m in enumerate(modulelist):
            if m in dependents:
                modulelist.insert(i, candidate)
                break
        else:
            modulelist.append(candidate)

    # Sanity check: no module may come before a module it depends on.
    for i, m1 in enumerate(modulelist):
        for m2 in modulelist[i+1:]:
            if m2 in flatdeps and m1 in flatdeps[m2]:
                raise RuntimeError("BAD ORDER %s BEFORE %s" % (m1,m2))
    modules = [ getattr(process, p) for p in modulelist ]

    task = cms.Task()
    for mod in modules:
        task.add(mod)
    return task, cms.Sequence(task)
0357 return task,cms.Sequence(task)
0358
def addKeepStatement(process, oldKeep, newKeeps, verbose=False):
    """Add new keep statements to any PoolOutputModule of the process that has the old keep statements

    Every output module whose type is 'PoolOutputModule', that has an
    ``outputCommands`` attribute, and whose ``outputCommands`` contains
    *oldKeep* gets *newKeeps* appended in place (via ``+=``). With
    *verbose*, the added statements are printed per module.
    """
    for moduleName, outputModule in process.outputModules.items():
        isPool = outputModule.type_() == 'PoolOutputModule'
        if not (isPool and hasattr(outputModule, "outputCommands")):
            continue
        if oldKeep not in outputModule.outputCommands:
            continue
        outputModule.outputCommands += newKeeps
        if not verbose:
            continue
        print("Adding the following keep statements to output module %s: " % moduleName)
        for keep in newKeeps:
            print("\t'%s'," % keep)
0368
0369
# Self-tests, run when this file is executed directly.
# NOTE(review): unittest.main() exits the interpreter by default, so any
# definitions placed after this block in the file are never executed when
# the file is run as a script. Consider moving this block to the end.
if __name__=="__main__":
    import unittest
    def _lineDiff(newString, oldString):
        """Return the non-empty lines of newString not matched by oldString.

        Walks newString line by line. A line equal to the next unmatched
        non-empty line of oldString consumes that line; every other line is
        collected. The collected lines are joined with newlines.
        """
        newString = ( x for x in newString.split('\n') if len(x) > 0)
        oldString = [ x for x in oldString.split('\n') if len(x) > 0]
        diff = []
        oldStringLine = 0
        for l in newString:
            # oldString exhausted: everything left is new.
            if oldStringLine >= len(oldString):
                diff.append(l)
                continue
            if l == oldString[oldStringLine]:
                oldStringLine +=1
                continue
            diff.append(l)
        return "\n".join( diff )

    class TestModuleCommand(unittest.TestCase):
        def setUp(self):
            """Nothing to do """
            pass
        def testCloning(self):
            # Clone a sequence containing a repeated module; the clone must
            # reuse aNew and rewrite input tags to the cloned labels.
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("b", src=cms.InputTag("a"))
            p.c = cms.EDProducer("c", src=cms.InputTag("b","instance"))
            p.s = cms.Sequence(p.a*p.b*p.c *p.a)
            cloneProcessingSnippet(p, p.s, "New", addToTask = True)
            # Compare only what this test added relative to an empty process.
            self.assertEqual(_lineDiff(p.dumpPython(), cms.Process("test").dumpPython()),
"""process.a = cms.EDProducer("a",
    src = cms.InputTag("gen")
)
process.aNew = cms.EDProducer("a",
    src = cms.InputTag("gen")
)
process.b = cms.EDProducer("b",
    src = cms.InputTag("a")
)
process.bNew = cms.EDProducer("b",
    src = cms.InputTag("aNew")
)
process.c = cms.EDProducer("c",
    src = cms.InputTag("b","instance")
)
process.cNew = cms.EDProducer("c",
    src = cms.InputTag("bNew","instance")
)
process.patAlgosToolsTask = cms.Task(process.aNew, process.bNew, process.cNew)
process.s = cms.Sequence(process.a+process.b+process.c+process.a)
process.sNew = cms.Sequence(process.aNew+process.bNew+process.cNew+process.aNew)""")
        def testContains(self):
            # contains() is a substring test on the sequence's string form.
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s1 = cms.Sequence(p.a*p.b*p.c)
            p.s2 = cms.Sequence(p.b*p.c)
            self.assertTrue( contains(p.s1, "a") )
            self.assertTrue( not contains(p.s2, "a") )
        def testJetCollectionString(self):
            self.assertEqual(jetCollectionString(algo = 'Foo', type = 'Bar'), 'patJetsFooBar')
            self.assertEqual(jetCollectionString(prefix = 'prefix', algo = 'Foo', type = 'Bar'), 'prefixPatJetsFooBar')
        def testListModules(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s = cms.Sequence(p.a*p.b*p.c)
            self.assertEqual([p.a,p.b,p.c], listModules(p.s))

    unittest.main()
0441
class CloneTaskVisitor(object):
    """Visitor that travels within a cms.Task and builds a cloned copy of it.

    A new, empty ``cms.Task`` is attached to the process as
    ``_newLabel(label)`` when the visitor is created. Every ``cms._Module``
    reached during the visit is cloned and attached under its new label,
    then added to the new task. Modules listed in *noClones* are reused
    unchanged instead of being cloned. Call ``clonedTask()`` after the visit
    to rewrite input tags between cloned modules and obtain the result.

    Args:
        process: the cms.Process the clones are attached to.
        label: label of the task being cloned.
        postfix: string appended to every new label.
        removePostfix: if non-empty, this suffix must be present on every
            cloned label and is stripped before *postfix* is appended.
        noClones: labels of modules to reuse instead of cloning
            (``None`` means none).
        verbose: forwarded to ``massSearchReplaceAnyInputTag``.
    """
    def __init__(self, process, label, postfix, removePostfix="", noClones = None, verbose = False):
        self._process = process
        self._postfix = postfix
        self._removePostfix = removePostfix
        # None default avoids sharing a mutable default list between calls.
        self._noClones = [] if noClones is None else noClones
        self._verbose = verbose
        # Original labels of the modules cloned so far.
        self._moduleLabels = []
        self._clonedTask = cms.Task()
        setattr(process, self._newLabel(label), self._clonedTask)

    def enter(self, visitee):
        """Clone (or reuse) *visitee* if it is a module and add it to the new task."""
        if isinstance(visitee, cms._Module):
            label = visitee.label()
            newModule = None
            if label in self._noClones:
                # Keep the original module unchanged.
                newModule = getattr(self._process, label)
            elif label in self._moduleLabels:
                # Already cloned earlier in this visit: reuse that clone.
                newModule = getattr(self._process, self._newLabel(label))
            else:
                self._moduleLabels.append(label)
                newModule = visitee.clone()
                setattr(self._process, self._newLabel(label), newModule)
            self.__appendToTopTask(newModule)

    def leave(self, visitee):
        pass

    def clonedTask(self):
        """Point input tags inside the copy at the cloned modules and return it.

        The record of cloned labels is cleared afterwards.
        """
        for label in self._moduleLabels:
            massSearchReplaceAnyInputTag(self._clonedTask, label, self._newLabel(label), moduleLabelOnly=True, verbose=self._verbose)
        self._moduleLabels = []
        return self._clonedTask

    def _newLabel(self, label):
        """Return *label* with removePostfix stripped (if set) and postfix appended.

        Raises:
            Exception: if removePostfix is set but *label* does not end with it.
        """
        if self._removePostfix != "":
            if label.endswith(self._removePostfix):
                label = label[0:-len(self._removePostfix)]
            else:
                raise Exception("Tried to remove postfix %s from label %s, but it wasn't there" % (self._removePostfix, label))
        return label + self._postfix

    def __appendToTopTask(self, visitee):
        self._clonedTask.add(visitee)
0488
def cloneProcessingSnippetTask(process, task, postfix, removePostfix="", noClones = None, verbose = False):
    """
    ------------------------------------------------------------------
    copy a task plus the modules therein (including modules in subtasks).
    The new task and every cloned module are renamed by appending a
    postfix, and input tags between cloned modules are automatically
    adjusted.

    process       : the cms.Process the clones are attached to
    task          : the task to copy
    postfix       : string appended to every new label; if empty,
                    nothing is cloned and 'task' itself is returned
    removePostfix : suffix stripped from each label before 'postfix'
                    is appended (an exception is raised if it is absent)
    noClones      : labels of modules to reuse instead of cloning
                    (None means none)
    verbose       : verbose input-tag replacement
    ------------------------------------------------------------------
    """
    if postfix == "":
        return task
    # None default avoids sharing a mutable default list between calls.
    if noClones is None:
        noClones = []
    visitor = CloneTaskVisitor(process, task.label(), postfix, removePostfix, noClones, verbose)
    task.visit(visitor)
    return visitor.clonedTask()
0502 return result