Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-11-27 03:17:59

0001 import FWCore.ParameterSet.Config as cms
0002 import sys
0003 
## Helpers that perform some technically boring tasks, like looking for all modules
## with a given parameter and replacing its value with a given one
0006 
0007 # Next two lines are for backward compatibility, the imported functions and
0008 # classes used to be defined in this file.
0009 from FWCore.ParameterSet.MassReplace import massSearchReplaceAnyInputTag, MassSearchReplaceAnyInputTagVisitor
0010 from FWCore.ParameterSet.MassReplace import massSearchReplaceParam, MassSearchParamVisitor, MassSearchReplaceParamVisitor
0011 
def getPatAlgosToolsTask(process):
    """Return the cms.Task stored as ``process.patAlgosToolsTask``.

    If the process has no such attribute yet, an empty cms.Task is attached
    under that name first and the attached object is returned.

    Raises Exception if the attribute exists but is not a cms.Task.
    """
    taskName = "patAlgosToolsTask"
    if not hasattr(process, taskName):
        setattr(process, taskName, cms.Task())
        return getattr(process, taskName)
    existing = getattr(process, taskName)
    if isinstance(existing, cms.Task):
        return existing
    raise Exception("patAlgosToolsTask does not have type Task")
0022 
def associatePatAlgosToolsTask(process):
    """Associate the process's patAlgosToolsTask (created if missing) with ``process.schedule``."""
    patTask = getPatAlgosToolsTask(process)
    process.schedule.associate(patTask)
0026 
def addToProcessAndTask(label, module, process, task):
    """Attach *module* to *process* as *label* and add it to *task*.

    The object added to the task is the one read back from the process
    after assignment (in case the process stores something other than the
    passed object).
    """
    setattr(process, label, module)
    attached = getattr(process, label)
    task.add(attached)
0030 
def addTaskToProcess(process, label, task):
    """Attach *task* to *process* as *label*, or, if *label* already exists
    on the process, add *task* to the existing object via its ``add`` method."""
    if hasattr(process, label):
        getattr(process, label).add(task)
        return
    setattr(process, label, task)
0036 
def addESProducers(process,config):
    """Attach the ESProducers defined in a configuration fragment to *process*.

    config : module path of the fragment, with '/' or '.' separators; it is
             imported and its top-level names are scanned.

    Every public (not starting with '_'), cms._Labelable object that is not
    a sequence-like object (cms._ModuleSequenceType) or a cms.PSet, whose
    ``type_()`` contains 'ESProducer', is attached to *process* under the
    name it has in the fragment. The entries 'source', 'looper' and
    'subProcess' are always skipped.
    """
    config = config.replace("/",".")
    # __import__ returns the top-level package; the fully imported leaf
    # module is fetched from sys.modules instead.
    __import__(config)
    fragment = sys.modules[config]
    for name in dir(fragment):
        if name.startswith('_') or name in ("source", "looper", "subProcess"):
            continue
        item = getattr(fragment, name)
        if (isinstance(item, cms._Labelable)
                and not isinstance(item, cms._ModuleSequenceType)
                and not isinstance(item, cms.PSet)
                and 'ESProducer' in item.type_()):
            setattr(process, name, item)
0047 
def loadWithPrefix(process,moduleName,prefix='',loadedProducersAndFilters=None):
    """Shortcut for loadWithPrePostfix with an empty postfix."""
    loadWithPrePostfix(process, moduleName, prefix=prefix, postfix='',
                       loadedProducersAndFilters=loadedProducersAndFilters)
0050 
def loadWithPostfix(process,moduleName,postfix='',loadedProducersAndFilters=None):
    """Shortcut for loadWithPrePostfix with an empty prefix."""
    loadWithPrePostfix(process, moduleName, prefix='', postfix=postfix,
                       loadedProducersAndFilters=loadedProducersAndFilters)
0053 
def loadWithPrePostfix(process,moduleName,prefix='',postfix='',loadedProducersAndFilters=None):
    """Import a configuration fragment and attach its content to *process*
    through extendWithPrePostfix, with *prefix*/*postfix* applied to labels.

    moduleName                : module path of the fragment, with '/' or '.'
                                separators
    prefix, postfix           : strings put around the labels of the
                                attached objects
    loadedProducersAndFilters : optional name of a cms.Task attribute of
                                *process*; forwarded to extendWithPrePostfix
    """
    moduleName = moduleName.replace("/",".")
    # __import__ returns the top-level package; the leaf module is taken
    # from sys.modules.
    __import__(moduleName)
    extendWithPrePostfix(process,sys.modules[moduleName],prefix,postfix,loadedProducersAndFilters)
0059 
def addToTask(loadedProducersAndFilters, module):
    """Add *module* to a task-like collection if it is an EDProducer or EDFilter.

    loadedProducersAndFilters : the collection to add to (anything with an
                                ``add`` method, e.g. a cms.Task object, not
                                its label), or None to do nothing
    module                    : any configuration object; only cms.EDProducer
                                and cms.EDFilter instances are added
    """
    # Explicit None check: an empty collector must still receive modules, so
    # its truthiness must not decide whether anything is added.
    if loadedProducersAndFilters is not None:
        if isinstance(module, (cms.EDProducer, cms.EDFilter)):
            loadedProducersAndFilters.add(module)
0064 
def extendWithPrePostfix(process,other,prefix,postfix,loadedProducersAndFilters=None):
    """Attach the usable top-level objects of *other* to *process*.

    *other* is any object whose attributes are scanned with dir(), e.g. an
    imported configuration module. Names starting with '_', the entries
    'source'/'looper'/'subProcess', sequence-like objects
    (cms._ModuleSequenceType), Tasks, Schedules, PSets and VPSets are
    skipped. Every remaining cms._Labelable object is given its attribute
    name as label if it has none yet (this mutates the objects of *other*).

    If *prefix* and *postfix* are both empty, each object is attached to
    *process* unchanged under its own name. Otherwise each object is cloned
    and the clone is attached under ``prefix + name + postfix`` (ESProducer
    clones keep the original name). Afterwards, InputTags in the renamed
    clones that refer to the original labels of other renamed clones are
    redirected to the new labels.

    loadedProducersAndFilters : optional name (a string) of a cms.Task
        attribute of *process*; every EDProducer/EDFilter attached here is
        also added to that Task.

    Raises Exception if loadedProducersAndFilters names an attribute that is
    not a cms.Task.
    """
    # enable explicit check to avoid overwriting of existing objects
    # (disabled: the statement below is kept commented out)
    #__dict__['_Process__InExtendCall'] = True

    # Resolve the optional Task name into the Task object itself.
    if loadedProducersAndFilters:
        task = getattr(process, loadedProducersAndFilters)
        if not isinstance(task, cms.Task):
            raise Exception("extendWithPrePostfix argument must be name of Task type object attached to the process or None")
    else:
        task = None

    # Scratch sequence, never attached to the process: it only collects the
    # renamed clones so the InputTag replacement at the end can run over
    # them. _moduleLabels is an ad-hoc attribute holding the original labels.
    sequence = cms.Sequence()
    sequence._moduleLabels = []
    for name in dir(other):
        #'from XX import *' ignores these, and so should we.
        if name.startswith('_'):
            continue
        item = getattr(other,name)
        if name == "source" or name == "looper" or name == "subProcess":
            continue
        elif isinstance(item,cms._ModuleSequenceType):
            continue
        elif isinstance(item,cms.Task):
            continue
        elif isinstance(item,cms.Schedule):
            continue
        elif isinstance(item,cms.VPSet) or isinstance(item,cms.PSet):
            continue
        elif isinstance(item,cms._Labelable):
            if not item.hasLabel_():
                item.setLabel(name)
            if prefix != '' or postfix != '':
                newModule = item.clone()
                if isinstance(item,cms.ESProducer):
                    # ESProducer clones are attached under the unchanged name.
                    newName =name
                else:
                    # Special case: for these the un-renamed original is
                    # attached as well (presumably because other config
                    # refers to the plain labels -- TODO confirm).
                    if 'TauDiscrimination' in name:
                        process.__setattr__(name,item)
                        addToTask(task, item)
                    newName = prefix+name+postfix
                process.__setattr__(newName,newModule)
                addToTask(task, newModule)
                if isinstance(newModule, cms._Sequenceable) and not newName == name:
                    sequence +=getattr(process,newName)
                    # NOTE(review): item.label() equals `name` only if the
                    # object had no label before or already used this name;
                    # otherwise the replacement below targets a label that
                    # differs from newName.
                    sequence._moduleLabels.append(item.label())
            else:
                process.__setattr__(name,item)
                addToTask(task, item)

    # Redirect InputTags (module label part only) of the renamed clones from
    # the original labels to the prefixed/postfixed ones.
    if prefix != '' or postfix != '':
        for label in sequence._moduleLabels:
            massSearchReplaceAnyInputTag(sequence, label, prefix+label+postfix,verbose=False,moduleLabelOnly=True)
0106                 process.__setattr__(newName,newModule)
0107                 addToTask(task, newModule)
0108                 if isinstance(newModule, cms._Sequenceable) and not newName == name:
0109                     sequence +=getattr(process,newName)
0110                     sequence._moduleLabels.append(item.label())
0111             else:
0112                 process.__setattr__(name,item)
0113                 addToTask(task, item)
0114 
0115     if prefix != '' or postfix != '':
0116         for label in sequence._moduleLabels:
0117             massSearchReplaceAnyInputTag(sequence, label, prefix+label+postfix,verbose=False,moduleLabelOnly=True)
0118 
0119 def applyPostfix(process, label, postfix):
0120     result = None
0121     if hasattr(process, label+postfix):
0122         result = getattr(process, label + postfix)
0123     else:
0124         raise ValueError("Error in <applyPostfix>: No module of name = %s attached to process !!" % (label + postfix))
0125     return result
0126 
0127 def removeIfInSequence(process, target,  sequenceLabel, postfix=""):
0128     labels = __labelsInSequence(process, sequenceLabel, postfix, True)
0129     if target+postfix in labels:
0130         getattr(process, sequenceLabel+postfix).remove(
0131             getattr(process, target+postfix)
0132             )
0133 
0134 def __labelsInSequence(process, sequenceLabel, postfix="", keepPostFix=False):
0135     position = -len(postfix)
0136     if keepPostFix: 
0137         position = None
0138 
0139     result = [ m.label()[:position] for m in listModules( getattr(process,sequenceLabel+postfix))]
0140     result.extend([ m.label()[:position] for m in listSequences( getattr(process,sequenceLabel+postfix))]  )
0141     if postfix == "":
0142         result = [ m.label() for m in listModules( getattr(process,sequenceLabel+postfix))]
0143         result.extend([ m.label() for m in listSequences( getattr(process,sequenceLabel+postfix))]  )
0144     return result
0145 
0146 #FIXME name is not generic enough now
0147 class GatherAllModulesVisitor(object):
0148     """Visitor that travels within a cms.Sequence, and returns a list of objects of type gatheredInance(e.g. modules) that have it"""
0149     def __init__(self, gatheredInstance=cms._Module):
0150         self._modules = []
0151         self._gatheredInstance= gatheredInstance
0152     def enter(self,visitee):
0153         if isinstance(visitee,self._gatheredInstance):
0154             self._modules.append(visitee)
0155     def leave(self,visitee):
0156         pass
0157     def modules(self):
0158         return self._modules
0159 
0160 class CloneSequenceVisitor(object):
0161     """Visitor that travels within a cms.Sequence, and returns a cloned version of the Sequence.
0162     All modules and sequences are cloned and a postfix is added"""
0163     def __init__(self, process, label, postfix, removePostfix="", noClones = [], addToTask = False, verbose = False):
0164         self._process = process
0165         self._postfix = postfix
0166         self._removePostfix = removePostfix
0167         self._noClones = noClones
0168         self._addToTask = addToTask
0169         self._verbose = verbose
0170         self._moduleLabels = []
0171         self._clonedSequence = cms.Sequence()
0172         setattr(process, self._newLabel(label), self._clonedSequence)
0173         if addToTask:
0174             self._patAlgosToolsTask = getPatAlgosToolsTask(process)
0175 
0176     def enter(self, visitee):
0177         if isinstance(visitee, cms._Module):
0178             label = visitee.label()
0179             newModule = None
0180             if label in self._noClones: #keep unchanged
0181                 newModule = getattr(self._process, label)
0182             elif label in self._moduleLabels: # has the module already been cloned ?
0183                 newModule = getattr(self._process, self._newLabel(label))
0184             else:
0185                 self._moduleLabels.append(label)
0186                 newModule = visitee.clone()
0187                 setattr(self._process, self._newLabel(label), newModule)
0188                 if self._addToTask:
0189                     self._patAlgosToolsTask.add(getattr(self._process, self._newLabel(label)))
0190             self.__appendToTopSequence(newModule)
0191 
0192     def leave(self, visitee):
0193         pass
0194 
0195     def clonedSequence(self):
0196         for label in self._moduleLabels:
0197             massSearchReplaceAnyInputTag(self._clonedSequence, label, self._newLabel(label), moduleLabelOnly=True, verbose=self._verbose)
0198         self._moduleLabels = [] # prevent the InputTag replacement next time the 'clonedSequence' function is called.
0199         return self._clonedSequence
0200 
0201     def _newLabel(self, label):
0202         if self._removePostfix != "":
0203             if label[-len(self._removePostfix):] == self._removePostfix:
0204                 label = label[0:-len(self._removePostfix)]
0205             else:
0206                 raise Exception("Tried to remove postfix %s from label %s, but it wasn't there" % (self._removePostfix, label))
0207         return label + self._postfix
0208 
0209     def __appendToTopSequence(self, visitee):
0210         self._clonedSequence += visitee
0211 
0212 def listModules(sequence):
0213     visitor = GatherAllModulesVisitor(gatheredInstance=cms._Module)
0214     sequence.visit(visitor)
0215     return visitor.modules()
0216 
0217 def listSequences(sequence):
0218     visitor = GatherAllModulesVisitor(gatheredInstance=cms.Sequence)
0219     sequence.visit(visitor)
0220     return visitor.modules()
0221 
0222 def jetCollectionString(prefix='', algo='', type=''):
0223     """
0224     ------------------------------------------------------------------
0225     return the string of the jet collection module depending on the
0226     input vaules. The default return value will be 'patAK5CaloJets'.
0227 
0228     algo   : indicating the algorithm type of the jet [expected are
0229              'AK5', 'IC5', 'SC7', ...]
0230     type   : indicating the type of constituents of the jet [expec-
0231              ted are 'Calo', 'PFlow', 'JPT', ...]
0232     prefix : prefix indicating the type of pat collection module (ex-
0233              pected are '', 'selected', 'clean').
0234     ------------------------------------------------------------------
0235     """
0236     if(prefix==''):
0237         jetCollectionString ='pat'
0238     else:
0239         jetCollectionString =prefix
0240         jetCollectionString+='Pat'
0241     jetCollectionString+='Jets'
0242     jetCollectionString+=algo
0243     jetCollectionString+=type
0244     return jetCollectionString
0245 
0246 def contains(sequence, moduleName):
0247     """
0248     ------------------------------------------------------------------
0249     return True if a module with name 'module' is contained in the
0250     sequence with name 'sequence' and False otherwise. This version
0251     is not so nice as it also returns True for any substr of the name
0252     of a contained module.
0253 
0254     sequence : sequence [e.g. process.patDefaultSequence]
0255     module   : module name as a string
0256     ------------------------------------------------------------------
0257     """
0258     return not sequence.__str__().find(moduleName)==-1
0259 
0260 
0261 
0262 def cloneProcessingSnippet(process, sequence, postfix, removePostfix="", noClones = [], addToTask = False, verbose = False):
0263     """
0264     ------------------------------------------------------------------
0265     copy a sequence plus the modules and sequences therein
0266     both are renamed by getting a postfix
0267     input tags are automatically adjusted
0268     ------------------------------------------------------------------
0269     """
0270     result = sequence
0271     if not postfix == "":
0272         visitor = CloneSequenceVisitor(process, sequence.label(), postfix, removePostfix, noClones, addToTask, verbose)
0273         sequence.visit(visitor)
0274         result = visitor.clonedSequence()
0275     return result
0276 
0277 def listDependencyChain(process, module, sources, verbose=False):
0278     """
0279     Walk up the dependencies of a module to find any that depend on any of the listed sources
0280     """
0281     def allDirectInputModules(moduleOrPSet,moduleName,attrName):
0282         ret = set()
0283         for name,value in moduleOrPSet.parameters_().items():
0284             type = value.pythonTypeName()
0285             if type == 'cms.PSet':
0286                 ret.update(allDirectInputModules(value,moduleName,moduleName+"."+name))
0287             elif type == 'cms.VPSet':
0288                 for (i,ps) in enumerate(value):
0289                     ret.update(allDirectInputModules(ps,moduleName,"%s.%s[%d]"%(moduleName,name,i)))
0290             elif type == 'cms.VInputTag':
0291                 inputs = [ MassSearchReplaceAnyInputTagVisitor.standardizeInputTagFmt(it) for it in value ]
0292                 inputLabels = [ tag.moduleLabel for tag in inputs if tag.processName == '' or tag.processName == process.name_() ]
0293                 ret.update(inputLabels)
0294                 if verbose and inputLabels: print("%s depends on %s via %s" % (moduleName, inputLabels, attrName+"."+name))
0295             elif type.endswith('.InputTag'):
0296                 if value.processName == '' or value.processName == process.name_():
0297                     ret.add(value.moduleLabel)
0298                     if verbose: print("%s depends on %s via %s" % (moduleName, value.moduleLabel, attrName+"."+name))
0299         ret.discard("")
0300         return ret
0301     def fillDirectDepGraphs(root,fwdepgraph,revdepgraph):
0302         if root.label_() in fwdepgraph: return
0303         deps = allDirectInputModules(root,root.label_(),root.label_())
0304         fwdepgraph[root.label_()] = []
0305         for d in deps:        
0306             fwdepgraph[root.label_()].append(d)
0307             if d not in revdepgraph: revdepgraph[d] = []
0308             revdepgraph[d].append(root.label_())
0309             depmodule = getattr(process,d,None)
0310             if depmodule:
0311                 fillDirectDepGraphs(depmodule,fwdepgraph,revdepgraph)
0312         return (fwdepgraph,revdepgraph)
0313     fwdepgraph, revdepgraph = fillDirectDepGraphs(module, {}, {})
0314     def flattenRevDeps(flatgraph, revdepgraph, tip):
0315         """Make a graph that for each module lists all the ones that depend on it, directly or indirectly"""
0316         # don't do it multiple times for the same module
0317         if tip in flatgraph: return 
0318         # if nobody depends on this module, there's nothing to do
0319         if tip not in revdepgraph: return
0320         # assemble my dependencies, in a depth-first approach
0321         mydeps = set()
0322         # start taking the direct dependencies of this module
0323         for d in revdepgraph[tip]:
0324             # process them
0325             flattenRevDeps(flatgraph, revdepgraph, d)
0326             # then add them and their processed dependencies to our deps
0327             mydeps.add(d)
0328             if d in flatgraph: 
0329                 mydeps.update(flatgraph[d])
0330         flatgraph[tip] = mydeps
0331     flatdeps = {}
0332     allmodules = set()
0333     for s in sources: 
0334         flattenRevDeps(flatdeps, revdepgraph, s)
0335         if s in flatdeps: allmodules.update(f for f in flatdeps[s])
0336     livemodules = [ a for a in allmodules if hasattr(process,a) ]
0337     if not livemodules: return None
0338     modulelist = [livemodules.pop()]
0339     for module in livemodules:
0340         for i,m in enumerate(modulelist):
0341             if module in flatdeps and m in flatdeps[module]:
0342                 modulelist.insert(i, module)
0343                 break
0344         if module not in modulelist:
0345             modulelist.append(module)
0346     # Validate
0347     for i,m1 in enumerate(modulelist):
0348         for j,m2 in enumerate(modulelist):
0349             if j <= i: continue
0350             if m2 in flatdeps and m1 in flatdeps[m2]:
0351                 raise RuntimeError("BAD ORDER %s BEFORE %s" % (m1,m2))
0352     modules = [ getattr(process,p) for p in modulelist ]
0353     #return cms.Sequence(sum(modules[1:],modules[0]))
0354     task = cms.Task()
0355     for mod in modules: 
0356         task.add(mod)
0357     return task,cms.Sequence(task)
0358 
0359 def addKeepStatement(process, oldKeep, newKeeps, verbose=False):
0360     """Add new keep statements to any PoolOutputModule of the process that has the old keep statements"""
0361     for name,out in process.outputModules.items():
0362         if out.type_() == 'PoolOutputModule' and hasattr(out, "outputCommands"):
0363             if oldKeep in out.outputCommands:
0364                 out.outputCommands += newKeeps
0365             if verbose:
0366                 print("Adding the following keep statements to output module %s: " % name)
0367                 for k in newKeeps: print("\t'%s'," % k)
0368 
0369 
0370 if __name__=="__main__":
0371     import unittest
0372     def _lineDiff(newString, oldString):
0373         newString = ( x for x in newString.split('\n') if len(x) > 0)
0374         oldString = [ x for x in oldString.split('\n') if len(x) > 0]
0375         diff = []
0376         oldStringLine = 0
0377         for l in newString:
0378             if oldStringLine >= len(oldString):
0379                 diff.append(l)
0380                 continue
0381             if l == oldString[oldStringLine]:
0382                 oldStringLine +=1
0383                 continue
0384             diff.append(l)
0385         return "\n".join( diff )
0386 
0387     class TestModuleCommand(unittest.TestCase):
0388         def setUp(self):
0389             """Nothing to do """
0390             pass
0391         def testCloning(self):
0392             p = cms.Process("test")
0393             p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
0394             p.b = cms.EDProducer("b", src=cms.InputTag("a"))
0395             p.c = cms.EDProducer("c", src=cms.InputTag("b","instance"))
0396             p.s = cms.Sequence(p.a*p.b*p.c *p.a)
0397             cloneProcessingSnippet(p, p.s, "New", addToTask = True)
0398             self.assertEqual(_lineDiff(p.dumpPython(), cms.Process("test").dumpPython()),
0399  """process.a = cms.EDProducer("a",
0400     src = cms.InputTag("gen")
0401 )
0402 process.aNew = cms.EDProducer("a",
0403     src = cms.InputTag("gen")
0404 )
0405 process.b = cms.EDProducer("b",
0406     src = cms.InputTag("a")
0407 )
0408 process.bNew = cms.EDProducer("b",
0409     src = cms.InputTag("aNew")
0410 )
0411 process.c = cms.EDProducer("c",
0412     src = cms.InputTag("b","instance")
0413 )
0414 process.cNew = cms.EDProducer("c",
0415     src = cms.InputTag("bNew","instance")
0416 )
0417 process.patAlgosToolsTask = cms.Task(process.aNew, process.bNew, process.cNew)
0418 process.s = cms.Sequence(process.a+process.b+process.c+process.a)
0419 process.sNew = cms.Sequence(process.aNew+process.bNew+process.cNew+process.aNew)""")
0420         def testContains(self):
0421             p = cms.Process("test")
0422             p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
0423             p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
0424             p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
0425             p.s1 = cms.Sequence(p.a*p.b*p.c)
0426             p.s2 = cms.Sequence(p.b*p.c)
0427             self.assertTrue( contains(p.s1, "a") )
0428             self.assertTrue( not contains(p.s2, "a") )
0429         def testJetCollectionString(self):
0430             self.assertEqual(jetCollectionString(algo = 'Foo', type = 'Bar'), 'patJetsFooBar')
0431             self.assertEqual(jetCollectionString(prefix = 'prefix', algo = 'Foo', type = 'Bar'), 'prefixPatJetsFooBar')
0432         def testListModules(self):
0433             p = cms.Process("test")
0434             p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
0435             p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
0436             p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
0437             p.s = cms.Sequence(p.a*p.b*p.c)
0438             self.assertEqual([p.a,p.b,p.c], listModules(p.s))
0439 
0440     unittest.main()
0441 
0442 class CloneTaskVisitor(object):
0443     """Visitor that travels within a cms.Task, and returns a cloned version of the Task.
0444     All modules are cloned and a postfix is added"""
0445     def __init__(self, process, label, postfix, removePostfix="", noClones = [], verbose = False):
0446         self._process = process
0447         self._postfix = postfix
0448         self._removePostfix = removePostfix
0449         self._noClones = noClones
0450         self._verbose = verbose
0451         self._moduleLabels = []
0452         self._clonedTask = cms.Task()
0453         setattr(process, self._newLabel(label), self._clonedTask)
0454 
0455     def enter(self, visitee):
0456         if isinstance(visitee, cms._Module):
0457             label = visitee.label()
0458             newModule = None
0459             if label in self._noClones: #keep unchanged
0460                 newModule = getattr(self._process, label)
0461             elif label in self._moduleLabels: # has the module already been cloned ?
0462                 newModule = getattr(self._process, self._newLabel(label))
0463             else:
0464                 self._moduleLabels.append(label)
0465                 newModule = visitee.clone()
0466                 setattr(self._process, self._newLabel(label), newModule)
0467             self.__appendToTopTask(newModule)
0468 
0469     def leave(self, visitee):
0470         pass
0471 
0472     def clonedTask(self):#FIXME: can the following be used for Task?
0473         for label in self._moduleLabels:
0474             massSearchReplaceAnyInputTag(self._clonedTask, label, self._newLabel(label), moduleLabelOnly=True, verbose=self._verbose)
0475         self._moduleLabels = [] # prevent the InputTag replacement next time the 'clonedTask' function is called.
0476         return self._clonedTask
0477 
0478     def _newLabel(self, label):
0479         if self._removePostfix != "":
0480             if label[-len(self._removePostfix):] == self._removePostfix:
0481                 label = label[0:-len(self._removePostfix)]
0482             else:
0483                 raise Exception("Tried to remove postfix %s from label %s, but it wasn't there" % (self._removePostfix, label))
0484         return label + self._postfix
0485 
0486     def __appendToTopTask(self, visitee):
0487         self._clonedTask.add(visitee)
0488 
0489 def cloneProcessingSnippetTask(process, task, postfix, removePostfix="", noClones = [], verbose = False):
0490     """
0491     ------------------------------------------------------------------
0492     copy a task plus the modules therein (including modules in subtasks)
0493     both are renamed by getting a postfix
0494     input tags are automatically adjusted
0495     ------------------------------------------------------------------
0496     """
0497     result = task
0498     if not postfix == "":
0499         visitor = CloneTaskVisitor(process, task.label(), postfix, removePostfix, noClones, verbose)
0500         task.visit(visitor)
0501         result = visitor.clonedTask()
0502     return result