Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:23:57

0001 from __future__ import print_function
0002 import FWCore.ParameterSet.Config as cms
0003 import sys
0004 
0005 ## Helpers for technically boring tasks, such as finding all modules that have a given parameter
0006 ## and replacing its value with a given one
0007 
0008 # Next two lines are for backward compatibility, the imported functions and
0009 # classes used to be defined in this file.
0010 from FWCore.ParameterSet.MassReplace import massSearchReplaceAnyInputTag, MassSearchReplaceAnyInputTagVisitor
0011 from FWCore.ParameterSet.MassReplace import massSearchReplaceParam, MassSearchParamVisitor, MassSearchReplaceParamVisitor
0012 
def getPatAlgosToolsTask(process):
    """Return the ``patAlgosToolsTask`` of ``process``, creating it on first use.

    If the process has no attribute of that name, an empty cms.Task is
    attached and returned.

    Raises Exception if the attribute exists but is not a cms.Task.
    """
    taskName = "patAlgosToolsTask"
    if not hasattr(process, taskName):
        # First request: attach a fresh Task and hand back what the process stores.
        setattr(process, taskName, cms.Task())
        return getattr(process, taskName)
    existing = getattr(process, taskName)
    if not isinstance(existing, cms.Task):
        raise Exception("patAlgosToolsTask does not have type Task")
    return existing
0023 
def associatePatAlgosToolsTask(process):
    """Associate the process' patAlgosToolsTask with ``process.schedule``.

    The task is fetched (and created if missing) through
    getPatAlgosToolsTask before the schedule is accessed.
    """
    patTask = getPatAlgosToolsTask(process)
    schedule = process.schedule
    schedule.associate(patTask)
0027 
def addToProcessAndTask(label, module, process, task):
    """Attach ``module`` to ``process`` under ``label`` and add it to ``task``.

    The object handed to ``task.add`` is read back from the process, so the
    task receives whatever the process actually stored under ``label``.
    """
    setattr(process, label, module)
    attached = getattr(process, label)
    task.add(attached)
0031 
def addTaskToProcess(process, label, task):
    """Make ``task`` available on ``process`` under ``label``.

    If the process already has an attribute ``label``, ``task`` is added to
    that existing object via its ``add`` method; otherwise ``task`` itself is
    attached to the process.
    """
    if hasattr(process, label):
        getattr(process, label).add(task)
        return
    setattr(process, label, task)
0037 
def addESProducers(process,config):
    """Import the configuration module ``config`` and attach its ESProducers to ``process``.

    config : module path; '/' separators are converted to '.'

    An object from the imported module is attached when it is a
    cms._Labelable that is neither a cms._ModuleSequenceType nor a cms.PSet,
    its name is public and not one of 'source', 'looper' or 'subProcess',
    and its type_() contains 'ESProducer'.
    """
    config = config.replace("/",".")
    # __import__ returns the top-level package; the requested module itself
    # is taken from sys.modules.
    __import__(config)
    loaded = sys.modules[config]
    reservedNames = ("source", "looper", "subProcess")
    for name in dir(loaded):
        item = getattr(loaded, name)
        if not isinstance(item, cms._Labelable):
            continue
        if isinstance(item, cms._ModuleSequenceType):
            continue
        if name.startswith('_') or name in reservedNames:
            continue
        if isinstance(item, cms.PSet):
            continue
        if 'ESProducer' in item.type_():
            setattr(process, name, item)
0048 
def loadWithPrefix(process,moduleName,prefix='',loadedProducersAndFilters=None):
    """Load ``moduleName`` into ``process`` with ``prefix`` prepended to labels.

    Shorthand for loadWithPrePostfix with an empty postfix.
    """
    loadWithPrePostfix(process, moduleName, prefix=prefix, postfix='',
                       loadedProducersAndFilters=loadedProducersAndFilters)
0051 
def loadWithPostfix(process,moduleName,postfix='',loadedProducersAndFilters=None):
    """Load ``moduleName`` into ``process`` with ``postfix`` appended to labels.

    Shorthand for loadWithPrePostfix with an empty prefix.
    """
    loadWithPrePostfix(process, moduleName, prefix='', postfix=postfix,
                       loadedProducersAndFilters=loadedProducersAndFilters)
0054 
def loadWithPrePostfix(process,moduleName,prefix='',postfix='',loadedProducersAndFilters=None):
    """Import ``moduleName`` and extend ``process`` with its content.

    moduleName : module path; '/' separators are converted to '.'
    prefix, postfix, loadedProducersAndFilters : forwarded unchanged to
        extendWithPrePostfix
    """
    dottedName = moduleName.replace("/",".")
    # __import__ returns the top-level package; the requested module itself
    # is taken from sys.modules.
    __import__(dottedName)
    extendWithPrePostfix(process, sys.modules[dottedName], prefix, postfix, loadedProducersAndFilters)
0060 
def addToTask(loadedProducersAndFilters, module):
    """Add ``module`` to ``loadedProducersAndFilters`` if it is an EDProducer or EDFilter.

    Nothing happens when ``loadedProducersAndFilters`` is falsy (e.g. None)
    or when ``module`` has any other type.
    """
    if not loadedProducersAndFilters:
        return
    if isinstance(module, (cms.EDProducer, cms.EDFilter)):
        loadedProducersAndFilters.add(module)
0065 
def extendWithPrePostfix(process,other,prefix,postfix,loadedProducersAndFilters=None):
    """Attach the labelable objects found in ``other`` to ``process``.

    process : process to extend
    other   : object (typically an imported configuration module) whose
              public attributes are inspected
    prefix, postfix : when either is non-empty, objects are cloned and the
              clones are attached as prefix+name+postfix (ESProducers keep
              their name); otherwise the objects themselves are attached
    loadedProducersAndFilters : name of a cms.Task attached to the process
              that collects the attached EDProducers/EDFilters, or None

    Skipped: private names, 'source', 'looper', 'subProcess', and any
    sequence type, Task, Schedule, PSet or VPSet.

    Raises Exception if ``loadedProducersAndFilters`` names something that is
    not a cms.Task.
    """
    if loadedProducersAndFilters:
        task = getattr(process, loadedProducersAndFilters)
        if not isinstance(task, cms.Task):
            raise Exception("extendWithPrePostfix argument must be name of Task type object attached to the process or None")
    else:
        task = None

    renaming = prefix != '' or postfix != ''
    skippedNames = ("source", "looper", "subProcess")
    skippedTypes = (cms._ModuleSequenceType, cms.Task, cms.Schedule, cms.VPSet, cms.PSet)

    # Scratch sequence holding the renamed clones; the original labels are
    # recorded on it so that InputTags can be redirected afterwards.
    sequence = cms.Sequence()
    sequence._moduleLabels = []
    for name in dir(other):
        # 'from XX import *' ignores these, and so should we.
        if name.startswith('_'):
            continue
        item = getattr(other, name)
        if name in skippedNames or isinstance(item, skippedTypes):
            continue
        if not isinstance(item, cms._Labelable):
            continue
        if not item.hasLabel_():
            item.setLabel(name)
        if not renaming:
            setattr(process, name, item)
            addToTask(task, item)
            continue
        newModule = item.clone()
        if isinstance(item, cms.ESProducer):
            newName = name
        else:
            if 'TauDiscrimination' in name:
                # These are additionally attached under their original name.
                setattr(process, name, item)
                addToTask(task, item)
            newName = prefix + name + postfix
        setattr(process, newName, newModule)
        addToTask(task, newModule)
        if isinstance(newModule, cms._Sequenceable) and newName != name:
            sequence += getattr(process, newName)
            sequence._moduleLabels.append(item.label())

    if renaming:
        # Point the clones at each other instead of at the original modules.
        for label in sequence._moduleLabels:
            massSearchReplaceAnyInputTag(sequence, label, prefix+label+postfix,verbose=False,moduleLabelOnly=True)
0119 
def applyPostfix(process, label, postfix):
    """Return the attribute ``label + postfix`` of ``process``.

    Raises ValueError if the process has no attribute of that name.
    """
    fullName = label + postfix
    if not hasattr(process, fullName):
        raise ValueError("Error in <applyPostfix>: No module of name = %s attached to process !!" % fullName)
    return getattr(process, fullName)
0127 
def removeIfInSequence(process, target,  sequenceLabel, postfix=""):
    """Remove ``target + postfix`` from the sequence ``sequenceLabel + postfix``.

    Nothing is done when the sequence does not list a module or sub-sequence
    with that label.
    """
    presentLabels = __labelsInSequence(process, sequenceLabel, postfix, True)
    fullTarget = target + postfix
    if fullTarget not in presentLabels:
        return
    holder = getattr(process, sequenceLabel + postfix)
    holder.remove(getattr(process, fullTarget))
0134 
def __labelsInSequence(process, sequenceLabel, postfix="", keepPostFix=False):
    """Return the labels of all modules and sequences in a sequence of ``process``.

    process       : object holding the sequence as attribute
    sequenceLabel : label of the sequence, without postfix
    postfix       : postfix appended to ``sequenceLabel`` to find the sequence
    keepPostFix   : if False, the last len(postfix) characters are cut from
                    every returned label; if True labels are returned in full

    Module labels come first, followed by the labels of the sequences.
    """
    # A slice end of None keeps the whole label. That is also what an empty
    # postfix needs: -len("") == 0 would cut every label down to "".
    if keepPostFix or postfix == "":
        position = None
    else:
        position = -len(postfix)

    sequence = getattr(process, sequenceLabel+postfix)
    members = listModules(sequence) + listSequences(sequence)
    return [ m.label()[:position] for m in members ]
0146 
0147 #FIXME name is not generic enough now
class GatherAllModulesVisitor(object):
    """Visitor that collects every visited object that is an instance of a given type.

    gatheredInstance : type (or tuple of types) handed to isinstance; objects
                       matching it are recorded in visiting order
    """
    def __init__(self, gatheredInstance=cms._Module):
        self._gatheredInstance = gatheredInstance
        self._modules = []

    def enter(self,visitee):
        """Record ``visitee`` if it has the requested type."""
        if not isinstance(visitee, self._gatheredInstance):
            return
        self._modules.append(visitee)

    def leave(self,visitee):
        """Nothing to do when a node is left."""

    def modules(self):
        """Return the list of objects gathered so far."""
        return self._modules
0160 
class CloneSequenceVisitor(object):
    """Visitor that travels within a cms.Sequence and builds a cloned version of it.

    Every visited module is cloned once, attached to the process under a new
    label and appended to a new cms.Sequence, which is itself attached to the
    process under the new label of the original sequence. Only modules are
    handled; the clone is therefore a flat list of modules.

    process       : process the clones are attached to
    label         : label of the sequence being cloned
    postfix       : appended to every label to form the new label
    removePostfix : if non-empty, stripped from the end of a label before
                    ``postfix`` is appended
    noClones      : labels of modules that are reused instead of cloned
    addToTask     : if True, clones are also added to the patAlgosToolsTask
    verbose       : forwarded to massSearchReplaceAnyInputTag
    """
    def __init__(self, process, label, postfix, removePostfix="", noClones=None, addToTask=False, verbose=False):
        self._process = process
        self._postfix = postfix
        self._removePostfix = removePostfix
        # None instead of a shared mutable default list.
        self._noClones = [] if noClones is None else noClones
        self._addToTask = addToTask
        self._verbose = verbose
        self._moduleLabels = []
        self._clonedSequence = cms.Sequence()
        setattr(process, self._newLabel(label), self._clonedSequence)
        if addToTask:
            self._patAlgosToolsTask = getPatAlgosToolsTask(process)

    def enter(self, visitee):
        """Clone (or reuse) a visited module and append it to the cloned sequence."""
        if not isinstance(visitee, cms._Module):
            return
        label = visitee.label()
        if label in self._noClones: #keep unchanged
            newModule = getattr(self._process, label)
        elif label in self._moduleLabels: # has the module already been cloned ?
            newModule = getattr(self._process, self._newLabel(label))
        else:
            self._moduleLabels.append(label)
            newLabel = self._newLabel(label)
            newModule = visitee.clone()
            setattr(self._process, newLabel, newModule)
            if self._addToTask:
                self._patAlgosToolsTask.add(getattr(self._process, newLabel))
        self.__appendToTopSequence(newModule)

    def leave(self, visitee):
        """Nothing to do when a node is left."""
        pass

    def clonedSequence(self):
        """Return the cloned sequence after redirecting InputTags to the clones."""
        for label in self._moduleLabels:
            massSearchReplaceAnyInputTag(self._clonedSequence, label, self._newLabel(label), moduleLabelOnly=True, verbose=self._verbose)
        self._moduleLabels = [] # prevent the InputTag replacement next time the 'clonedSequence' function is called.
        return self._clonedSequence

    def _newLabel(self, label):
        """Return ``label`` with removePostfix stripped and postfix appended.

        Raises Exception if removePostfix is set but ``label`` does not end with it.
        """
        if self._removePostfix != "":
            if not label.endswith(self._removePostfix):
                raise Exception("Tried to remove postfix %s from label %s, but it wasn't there" % (self._removePostfix, label))
            label = label[:-len(self._removePostfix)]
        return label + self._postfix

    def __appendToTopSequence(self, visitee):
        self._clonedSequence += visitee
0212 
def listModules(sequence):
    """Return the list of cms._Module objects met while visiting ``sequence``."""
    collector = GatherAllModulesVisitor(gatheredInstance=cms._Module)
    sequence.visit(collector)
    return collector.modules()
0217 
def listSequences(sequence):
    """Return the list of cms.Sequence objects met while visiting ``sequence``."""
    collector = GatherAllModulesVisitor(gatheredInstance=cms.Sequence)
    sequence.visit(collector)
    return collector.modules()
0222 
def jetCollectionString(prefix='', algo='', type=''):
    """
    ------------------------------------------------------------------
    return the label of a pat jet collection module built from the
    input values: '<prefix>PatJets<algo><type>', or 'patJets<algo><type>'
    when no prefix is given. The default return value is 'patJets'.

    prefix : prefix indicating the type of pat collection module (ex-
             pected are '', 'selected', 'clean').
    algo   : indicating the algorithm type of the jet [expected are
             'AK5', 'IC5', 'SC7', ...]
    type   : indicating the type of constituents of the jet [expec-
             ted are 'Calo', 'PFlow', 'JPT', ...]
    ------------------------------------------------------------------
    """
    head = 'pat' if prefix == '' else prefix + 'Pat'
    return head + 'Jets' + algo + type
0246 
def contains(sequence, moduleName):
    """
    ------------------------------------------------------------------
    return True if 'moduleName' occurs in the string representation
    of 'sequence' and False otherwise. This is a plain substring
    search, so it also returns True for any substring of the name of
    a contained module.

    sequence   : sequence [e.g. process.patDefaultSequence]
    moduleName : module name as a string
    ------------------------------------------------------------------
    """
    return moduleName in str(sequence)
0260 
0261 
0262 
def cloneProcessingSnippet(process, sequence, postfix, removePostfix="", noClones=None, addToTask=False, verbose=False):
    """
    ------------------------------------------------------------------
    copy a sequence plus the modules therein
    both are renamed by getting a postfix
    input tags are automatically adjusted

    process       : process the copies are attached to
    sequence      : sequence to copy
    postfix       : postfix for the new labels; if empty, nothing is
                    copied and 'sequence' itself is returned
    removePostfix : postfix stripped from the old labels first
    noClones      : labels of modules to reuse instead of copying
    addToTask     : also add the copies to the patAlgosToolsTask
    verbose       : report the input tag replacements
    ------------------------------------------------------------------
    """
    if postfix == "":
        return sequence
    # None instead of a shared mutable default; the visitor gets a real list.
    if noClones is None:
        noClones = []
    visitor = CloneSequenceVisitor(process, sequence.label(), postfix, removePostfix, noClones, addToTask, verbose)
    sequence.visit(visitor)
    return visitor.clonedSequence()
0277 
def listDependencyChain(process, module, sources, verbose=False):
    """
    Walk up the dependencies of a module to find any that depend on any of the listed sources

    process : object the modules are attached to; modules are looked up by label
    module  : module whose InputTag parameters are followed recursively
    sources : iterable of module labels
    verbose : print every dependency that is found, with the parameter path

    Returns None when no module attached to the process depends on one of the
    sources; otherwise a tuple (task, sequence) with a cms.Task containing
    those modules and a cms.Sequence built from that task.

    Raises RuntimeError if the internal ordering check fails.
    """
    def allDirectInputModules(moduleOrPSet,moduleName,attrName):
        """Return the labels referenced by InputTags below moduleOrPSet; attrName is the path so far."""
        ret = set()
        for name,value in moduleOrPSet.parameters_().items():
            typeName = value.pythonTypeName()
            # Extend the path of the enclosing parameter, so that nested
            # PSets are reported with their complete path.
            path = attrName+"."+name
            if typeName == 'cms.PSet':
                ret.update(allDirectInputModules(value,moduleName,path))
            elif typeName == 'cms.VPSet':
                for (i,ps) in enumerate(value):
                    ret.update(allDirectInputModules(ps,moduleName,"%s[%d]"%(path,i)))
            elif typeName == 'cms.VInputTag':
                inputs = [ MassSearchReplaceAnyInputTagVisitor.standardizeInputTagFmt(it) for it in value ]
                # only tags without process name or from this process count
                inputLabels = [ tag.moduleLabel for tag in inputs if tag.processName == '' or tag.processName == process.name_() ]
                ret.update(inputLabels)
                if verbose and inputLabels: print("%s depends on %s via %s" % (moduleName, inputLabels, path))
            elif typeName.endswith('.InputTag'):
                if value.processName == '' or value.processName == process.name_():
                    ret.add(value.moduleLabel)
                    if verbose: print("%s depends on %s via %s" % (moduleName, value.moduleLabel, path))
        ret.discard("")
        return ret
    def fillDirectDepGraphs(root,fwdepgraph,revdepgraph):
        """Fill label -> inputs (fwdepgraph) and label -> direct users (revdepgraph), starting at root."""
        if root.label_() in fwdepgraph: return
        deps = allDirectInputModules(root,root.label_(),root.label_())
        fwdepgraph[root.label_()] = []
        for d in deps:
            fwdepgraph[root.label_()].append(d)
            if d not in revdepgraph: revdepgraph[d] = []
            revdepgraph[d].append(root.label_())
            depmodule = getattr(process,d,None)
            if depmodule:
                fillDirectDepGraphs(depmodule,fwdepgraph,revdepgraph)
        return (fwdepgraph,revdepgraph)
    fwdepgraph, revdepgraph = fillDirectDepGraphs(module, {}, {})
    def flattenRevDeps(flatgraph, revdepgraph, tip):
        """Make a graph that for each module lists all the ones that depend on it, directly or indirectly"""
        # don't do it multiple times for the same module
        if tip in flatgraph: return
        # if nobody depends on this module, there's nothing to do
        if tip not in revdepgraph: return
        # assemble my dependencies, in a depth-first approach
        mydeps = set()
        # start taking the direct dependencies of this module
        for d in revdepgraph[tip]:
            # process them
            flattenRevDeps(flatgraph, revdepgraph, d)
            # then add them and their processed dependencies to our deps
            mydeps.add(d)
            if d in flatgraph:
                mydeps.update(flatgraph[d])
        flatgraph[tip] = mydeps
    flatdeps = {}
    allmodules = set()
    for s in sources:
        flattenRevDeps(flatdeps, revdepgraph, s)
        if s in flatdeps: allmodules.update(flatdeps[s])
    livemodules = [ a for a in allmodules if hasattr(process,a) ]
    if not livemodules: return None
    # Order the labels so that a module comes before the modules that depend on it.
    modulelist = [livemodules.pop()]
    for candidate in livemodules:
        for i,m in enumerate(modulelist):
            if candidate in flatdeps and m in flatdeps[candidate]:
                modulelist.insert(i, candidate)
                break
        if candidate not in modulelist:
            modulelist.append(candidate)
    # Validate
    for i,m1 in enumerate(modulelist):
        for j,m2 in enumerate(modulelist):
            if j <= i: continue
            if m2 in flatdeps and m1 in flatdeps[m2]:
                raise RuntimeError("BAD ORDER %s BEFORE %s" % (m1,m2))
    modules = [ getattr(process,p) for p in modulelist ]
    task = cms.Task()
    for mod in modules:
        task.add(mod)
    return task,cms.Sequence(task)
0359 
def addKeepStatement(process, oldKeep, newKeeps, verbose=False):
    """Add new keep statements to any PoolOutputModule of the process that has the old keep statements

    process  : object whose ``outputModules`` mapping (name -> module) is scanned
    oldKeep  : statement that must already be in a module's outputCommands
    newKeeps : statements appended to the outputCommands of matching modules
    verbose  : print the added statements for every module that was changed
    """
    for name,out in process.outputModules.items():
        if out.type_() != 'PoolOutputModule' or not hasattr(out, "outputCommands"):
            continue
        if oldKeep not in out.outputCommands:
            continue
        out.outputCommands += newKeeps
        # Report only the modules that actually received the statements.
        if verbose:
            print("Adding the following keep statements to output module %s: " % name)
            for k in newKeeps: print("\t'%s'," % k)
0369 
0370 
if __name__=="__main__":
    import unittest

    def _lineDiff(newString, oldString):
        """Return the non-empty lines of newString that do not match oldString.

        The non-empty lines of oldString are consumed in order: a line of
        newString equal to the current old line advances to the next old
        line, every other line of newString is part of the result.
        """
        newLines = ( x for x in newString.split('\n') if len(x) > 0)
        oldLines = [ x for x in oldString.split('\n') if len(x) > 0]
        diff = []
        oldIndex = 0
        for line in newLines:
            if oldIndex < len(oldLines) and line == oldLines[oldIndex]:
                oldIndex += 1
            else:
                diff.append(line)
        return "\n".join( diff )

    class TestModuleCommand(unittest.TestCase):
        def setUp(self):
            """Nothing to do """
            pass

        def testCloning(self):
            # Lines a cloned snippet adds to the dump of an empty process.
            expected = """process.a = cms.EDProducer("a",
    src = cms.InputTag("gen")
)
process.aNew = cms.EDProducer("a",
    src = cms.InputTag("gen")
)
process.b = cms.EDProducer("b",
    src = cms.InputTag("a")
)
process.bNew = cms.EDProducer("b",
    src = cms.InputTag("aNew")
)
process.c = cms.EDProducer("c",
    src = cms.InputTag("b","instance")
)
process.cNew = cms.EDProducer("c",
    src = cms.InputTag("bNew","instance")
)
process.patAlgosToolsTask = cms.Task(process.aNew, process.bNew, process.cNew)
process.s = cms.Sequence(process.a+process.b+process.c+process.a)
process.sNew = cms.Sequence(process.aNew+process.bNew+process.cNew+process.aNew)"""
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("b", src=cms.InputTag("a"))
            p.c = cms.EDProducer("c", src=cms.InputTag("b","instance"))
            p.s = cms.Sequence(p.a*p.b*p.c *p.a)
            cloneProcessingSnippet(p, p.s, "New", addToTask = True)
            emptyDump = cms.Process("test").dumpPython()
            self.assertEqual(_lineDiff(p.dumpPython(), emptyDump), expected)

        def testContains(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s1 = cms.Sequence(p.a*p.b*p.c)
            p.s2 = cms.Sequence(p.b*p.c)
            self.assertTrue(contains(p.s1, "a"))
            self.assertFalse(contains(p.s2, "a"))

        def testJetCollectionString(self):
            withoutPrefix = jetCollectionString(algo = 'Foo', type = 'Bar')
            self.assertEqual(withoutPrefix, 'patJetsFooBar')
            withPrefix = jetCollectionString(prefix = 'prefix', algo = 'Foo', type = 'Bar')
            self.assertEqual(withPrefix, 'prefixPatJetsFooBar')

        def testListModules(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s = cms.Sequence(p.a*p.b*p.c)
            self.assertEqual([p.a,p.b,p.c], listModules(p.s))

    unittest.main()
0442 
class CloneTaskVisitor(object):
    """Visitor that travels within a cms.Task and builds a cloned version of it.

    Every visited module is cloned once, attached to the process under a new
    label and added to a new cms.Task, which is itself attached to the
    process under the new label of the original task.

    process       : process the clones are attached to
    label         : label of the task being cloned
    postfix       : appended to every label to form the new label
    removePostfix : if non-empty, stripped from the end of a label before
                    ``postfix`` is appended
    noClones      : labels of modules that are reused instead of cloned
    verbose       : forwarded to massSearchReplaceAnyInputTag
    """
    def __init__(self, process, label, postfix, removePostfix="", noClones=None, verbose=False):
        self._process = process
        self._postfix = postfix
        self._removePostfix = removePostfix
        # None instead of a shared mutable default list.
        self._noClones = [] if noClones is None else noClones
        self._verbose = verbose
        self._moduleLabels = []
        self._clonedTask = cms.Task()
        setattr(process, self._newLabel(label), self._clonedTask)

    def enter(self, visitee):
        """Clone (or reuse) a visited module and add it to the cloned task."""
        if not isinstance(visitee, cms._Module):
            return
        label = visitee.label()
        if label in self._noClones: #keep unchanged
            newModule = getattr(self._process, label)
        elif label in self._moduleLabels: # has the module already been cloned ?
            newModule = getattr(self._process, self._newLabel(label))
        else:
            self._moduleLabels.append(label)
            newModule = visitee.clone()
            setattr(self._process, self._newLabel(label), newModule)
        self.__appendToTopTask(newModule)

    def leave(self, visitee):
        """Nothing to do when a node is left."""
        pass

    def clonedTask(self):#FIXME: can the following be used for Task?
        """Return the cloned task after redirecting InputTags to the clones."""
        for label in self._moduleLabels:
            massSearchReplaceAnyInputTag(self._clonedTask, label, self._newLabel(label), moduleLabelOnly=True, verbose=self._verbose)
        self._moduleLabels = [] # prevent the InputTag replacement next time the 'clonedTask' function is called.
        return self._clonedTask

    def _newLabel(self, label):
        """Return ``label`` with removePostfix stripped and postfix appended.

        Raises Exception if removePostfix is set but ``label`` does not end with it.
        """
        if self._removePostfix != "":
            if not label.endswith(self._removePostfix):
                raise Exception("Tried to remove postfix %s from label %s, but it wasn't there" % (self._removePostfix, label))
            label = label[:-len(self._removePostfix)]
        return label + self._postfix

    def __appendToTopTask(self, visitee):
        self._clonedTask.add(visitee)
0489 
def cloneProcessingSnippetTask(process, task, postfix, removePostfix="", noClones=None, verbose=False):
    """
    ------------------------------------------------------------------
    copy a task plus the modules therein (including modules in subtasks)
    both are renamed by getting a postfix
    input tags are automatically adjusted

    process       : process the copies are attached to
    task          : task to copy
    postfix       : postfix for the new labels; if empty, nothing is
                    copied and 'task' itself is returned
    removePostfix : postfix stripped from the old labels first
    noClones      : labels of modules to reuse instead of copying
    verbose       : report the input tag replacements
    ------------------------------------------------------------------
    """
    if postfix == "":
        return task
    # None instead of a shared mutable default; the visitor gets a real list.
    if noClones is None:
        noClones = []
    visitor = CloneTaskVisitor(process, task.label(), postfix, removePostfix, noClones, verbose)
    task.visit(visitor)
    return visitor.clonedTask()