Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2021-08-12 23:11:47

0001 from __future__ import print_function
0002 import FWCore.ParameterSet.Config as cms
0003 import sys
0004 
0005 ## Helpers to perform some technically boring tasks like looking for all modules with a given parameter
0006 ## and replacing that to a given value
0007 
0008 # Next two lines are for backward compatibility, the imported functions and
0009 # classes used to be defined in this file.
0010 from FWCore.ParameterSet.MassReplace import massSearchReplaceAnyInputTag, MassSearchReplaceAnyInputTagVisitor
0011 from FWCore.ParameterSet.MassReplace import massSearchReplaceParam, MassSearchParamVisitor, MassSearchReplaceParamVisitor
0012 
def getPatAlgosToolsTask(process):
    """Return the process's ``patAlgosToolsTask``, attaching an empty one first if absent.

    Raises an Exception if the process already has an attribute with that
    name which is not a cms.Task.
    """
    taskName = "patAlgosToolsTask"
    if not hasattr(process, taskName):
        # First request: attach a fresh, empty Task and hand back what the
        # process now stores under that name.
        setattr(process, taskName, cms.Task())
        return getattr(process, taskName)
    existing = getattr(process, taskName)
    if not isinstance(existing, cms.Task):
        raise Exception("patAlgosToolsTask does not have type Task")
    return existing
0023 
def associatePatAlgosToolsTask(process):
    """Associate the process's patAlgosToolsTask (created on demand) with ``process.schedule``."""
    # Fetch (and possibly create) the Task before touching the schedule.
    patTask = getPatAlgosToolsTask(process)
    process.schedule.associate(patTask)
0027 
def addToProcessAndTask(label, module, process, task):
    """Attach *module* to *process* under *label*, then add it to *task*.

    The object handed to ``task.add`` is re-read from the process, i.e. it is
    whatever the process actually stores under *label* after the assignment.
    """
    setattr(process, label, module)
    attached = getattr(process, label)
    task.add(attached)
0031 
def addESProducers(process,config):
    """Attach to *process* the ESProducer-like objects defined in a configuration module.

    config : module path, dotted ('A.B.c_cff') or slash-separated
             ('A/B/c_cff'); slashes are converted to dots before importing.

    Every public (not '_'-prefixed) top-level object of the imported module
    that is labelable, is not a Sequence-like object or a PSet, is not named
    'source', 'looper' or 'subProcess', and whose ``type_()`` string contains
    'ESProducer' is set on the process under its module-level name.
    """
    import importlib
    config = config.replace("/", ".")
    # import_module returns the (sub)module itself, equivalent to
    # __import__(config) followed by sys.modules[config].
    configModule = importlib.import_module(config)
    for name in dir(configModule):
        item = getattr(configModule, name)
        if name.startswith('_') or name in ("source", "looper", "subProcess"):
            continue
        if not isinstance(item, cms._Labelable):
            continue
        if isinstance(item, (cms._ModuleSequenceType, cms.PSet)):
            continue
        # Some labelable objects have no type_() method (presumably e.g.
        # cms.Task or cms.VPSet -- TODO confirm); they cannot be ESProducers,
        # so skip them instead of failing with AttributeError.
        typeGetter = getattr(item, "type_", None)
        if typeGetter is None:
            continue
        if 'ESProducer' in typeGetter():
            setattr(process, name, item)
0042 
def loadWithPrefix(process,moduleName,prefix='',loadedProducersAndFilters=None):
    """Thin wrapper around loadWithPrePostfix with an empty postfix."""
    loadWithPrePostfix(process, moduleName,
                       prefix=prefix,
                       postfix='',
                       loadedProducersAndFilters=loadedProducersAndFilters)
0045 
def loadWithPostfix(process,moduleName,postfix='',loadedProducersAndFilters=None):
    """Thin wrapper around loadWithPrePostfix with an empty prefix."""
    loadWithPrePostfix(process, moduleName,
                       prefix='',
                       postfix=postfix,
                       loadedProducersAndFilters=loadedProducersAndFilters)
0048 
def loadWithPrePostfix(process,moduleName,prefix='',postfix='',loadedProducersAndFilters=None):
    """Import a configuration module and merge its contents into *process*.

    moduleName : dotted or slash-separated module path; slashes become dots.
    prefix, postfix : forwarded to extendWithPrePostfix, which uses them to
                      build the labels of the cloned objects.
    loadedProducersAndFilters : name of a cms.Task attached to the process, or
                                None; forwarded to extendWithPrePostfix.
    """
    import importlib
    moduleName = moduleName.replace("/", ".")
    # import_module returns the (sub)module itself, equivalent to
    # __import__(moduleName) followed by sys.modules[moduleName].
    loadedModule = importlib.import_module(moduleName)
    extendWithPrePostfix(process, loadedModule, prefix, postfix, loadedProducersAndFilters)
0054 
def addToTask(loadedProducersAndFilters, module):
    """Add *module* to *loadedProducersAndFilters* if it is an EDProducer or EDFilter.

    loadedProducersAndFilters : the cms.Task to extend, or None to do nothing.
    module : any configuration object; only EDProducer and EDFilter instances
             are added, everything else is ignored.
    """
    # Explicit None check: a container-like Task must not be skipped merely
    # because it happens to be empty (truthiness would do that if the Task
    # type defines __len__/__bool__).
    if loadedProducersAndFilters is None:
        return
    if isinstance(module, (cms.EDProducer, cms.EDFilter)):
        loadedProducersAndFilters.add(module)
0059 
def extendWithPrePostfix(process,other,prefix,postfix,loadedProducersAndFilters=None):
    """Look in other and find types which we can use, and attach them to process.

    process : the cms.Process to extend.
    other : object scanned with dir() (e.g. an imported configuration
            module). Names starting with '_', the names 'source', 'looper'
            and 'subProcess', and Sequence-like objects, Tasks, Schedules,
            PSets and VPSets are skipped; every other labelable object is
            used (and given its attribute name as label if it has none).
    prefix, postfix : if both are empty, each object is attached to the
            process unchanged under its own name. Otherwise each object is
            cloned and the clone is attached as prefix+name+postfix
            (ESProducer clones keep the original name); afterwards InputTags
            among the renamed clones that refer to the original labels are
            rewritten to the new labels.
    loadedProducersAndFilters : name (string) of a cms.Task attached to the
            process, or None. Attached EDProducers/EDFilters are also added
            to that Task (via addToTask).

    Raises Exception if loadedProducersAndFilters names an attribute that is
    not a cms.Task.
    """
    # enable explicit check to avoid overwriting of existing objects
    # (disabled: the statement below is commented out)
    #__dict__['_Process__InExtendCall'] = True

    # Resolve the optional Task by name; task stays None when no name is given.
    if loadedProducersAndFilters:
        task = getattr(process, loadedProducersAndFilters)
        if not isinstance(task, cms.Task):
            raise Exception("extendWithPrePostfix argument must be name of Task type object attached to the process or None")
    else:
        task = None

    # Scratch sequence, never attached to the process: it only collects the
    # renamed clones so the InputTag replacement below can be run over them.
    sequence = cms.Sequence()
    sequence._moduleLabels = []
    for name in dir(other):
        #'from XX import *' ignores these, and so should we.
        if name.startswith('_'):
            continue
        item = getattr(other,name)
        if name == "source" or name == "looper" or name == "subProcess":
            continue
        elif isinstance(item,cms._ModuleSequenceType):
            continue
        elif isinstance(item,cms.Task):
            continue
        elif isinstance(item,cms.Schedule):
            continue
        elif isinstance(item,cms.VPSet) or isinstance(item,cms.PSet):
            continue
        elif isinstance(item,cms._Labelable):
            # Note: this labels the object living in *other* itself.
            if not item.hasLabel_():
                item.setLabel(name)
            if prefix != '' or postfix != '':
                newModule = item.clone()
                if isinstance(item,cms.ESProducer):
                    # ESProducers are not renamed.
                    newName =name
                else:
                    # Objects whose name contains 'TauDiscrimination' are
                    # additionally attached unrenamed (presumably because
                    # other configuration refers to them by their original
                    # label -- TODO confirm).
                    if 'TauDiscrimination' in name:
                        process.__setattr__(name,item)
                        addToTask(task, item)
                    newName = prefix+name+postfix
                process.__setattr__(newName,newModule)
                addToTask(task, newModule)
                # Only renamed, sequenceable clones take part in the InputTag
                # replacement.
                # NOTE(review): the label recorded is item.label(), which can
                # differ from `name` if the object already had another label;
                # the replacement then targets prefix+label+postfix rather than
                # newName -- confirm this is intended.
                if isinstance(newModule, cms._Sequenceable) and not newName == name:
                    sequence +=getattr(process,newName)
                    sequence._moduleLabels.append(item.label())
            else:
                process.__setattr__(name,item)
                addToTask(task, item)

    # Point InputTags inside the renamed clones at the renamed labels.
    if prefix != '' or postfix != '':
        for label in sequence._moduleLabels:
            massSearchReplaceAnyInputTag(sequence, label, prefix+label+postfix,verbose=False,moduleLabelOnly=True)
0113 
def applyPostfix(process, label, postfix):
    """Return the object attached to *process* under ``label + postfix``.

    Raises ValueError if the process has no attribute of that name.
    """
    fullLabel = label + postfix
    if not hasattr(process, fullLabel):
        raise ValueError("Error in <applyPostfix>: No module of name = %s attached to process !!" % (fullLabel))
    return getattr(process, fullLabel)
0121 
def removeIfInSequence(process, target,  sequenceLabel, postfix=""):
    """Remove ``target+postfix`` from the sequence ``sequenceLabel+postfix`` if it is in it.

    *target* and *sequenceLabel* are given without the postfix; labels are
    compared with the postfix kept. Nothing happens if the label is absent.
    """
    fullTarget = target + postfix
    presentLabels = __labelsInSequence(process, sequenceLabel, postfix, True)
    if fullTarget not in presentLabels:
        return
    owningSequence = getattr(process, sequenceLabel + postfix)
    owningSequence.remove(getattr(process, fullTarget))
0128 
def __labelsInSequence(process, sequenceLabel, postfix="", keepPostFix=False):
    """Return the labels of the modules, then of the sequences, in ``sequenceLabel+postfix``.

    Unless *keepPostFix* is True, the last len(postfix) characters are cut from
    every label (labels are assumed, not checked, to end with *postfix*).
    """
    # With an empty postfix the labels must be kept whole: -len("") == 0 and
    # label[:0] would yield empty strings.
    if keepPostFix or postfix == "":
        position = None
    else:
        position = -len(postfix)
    sequence = getattr(process, sequenceLabel + postfix)
    result = [m.label()[:position] for m in listModules(sequence)]
    result.extend(m.label()[:position] for m in listSequences(sequence))
    return result
0140 
# FIXME: the class name is not generic enough now (any type can be gathered)
class GatherAllModulesVisitor(object):
    """Visitor that travels within a cms.Sequence and gathers objects of a given type.

    After ``sequence.visit(visitor)``, ``modules()`` returns, in visiting
    order, every entered object that is an instance of *gatheredInstance*
    (default: cms._Module).
    """
    def __init__(self, gatheredInstance=cms._Module):
        self._gatheredInstance = gatheredInstance
        self._modules = []

    def enter(self, visitee):
        # Record only objects of the requested type; ignore everything else.
        if not isinstance(visitee, self._gatheredInstance):
            return
        self._modules.append(visitee)

    def leave(self, visitee):
        # Nothing to do when leaving a node.
        pass

    def modules(self):
        """Return the list of gathered objects (the internal list, not a copy)."""
        return self._modules
0154 
class CloneSequenceVisitor(object):
    """Visitor that travels within a cms.Sequence and builds a cloned version of it.

    Every module entered during the visit is cloned (unless listed in
    *noClones*) and attached to the process under ``_newLabel(label)``. The
    resulting modules are appended, in visiting order, to a single new
    cms.Sequence, which the constructor attaches to the process under
    ``_newLabel`` of the visited sequence's label. enter() only handles
    cms._Module objects, so nested sequences are not recreated; their modules
    end up directly in the new sequence.

    process       : the cms.Process receiving the clones and the new sequence.
    label         : label of the sequence being cloned.
    postfix       : string appended to every new label.
    removePostfix : if non-empty, every label must end with it; it is stripped
                    before *postfix* is appended (an Exception is raised
                    otherwise).
    noClones      : labels of modules re-used as-is (looked up on the
                    process) instead of cloned; default: none.
    addToTask     : if True, every clone is also added to the process's
                    patAlgosToolsTask.
    verbose       : forwarded to the InputTag replacement in clonedSequence().
    """
    def __init__(self, process, label, postfix, removePostfix="", noClones=None, addToTask=False, verbose=False):
        self._process = process
        self._postfix = postfix
        self._removePostfix = removePostfix
        # None instead of a mutable [] default so no list is shared between calls.
        self._noClones = [] if noClones is None else noClones
        self._addToTask = addToTask
        self._verbose = verbose
        self._moduleLabels = []
        self._clonedSequence = cms.Sequence()
        setattr(process, self._newLabel(label), self._clonedSequence)
        if addToTask:
            self._patAlgosToolsTask = getPatAlgosToolsTask(process)

    def enter(self, visitee):
        """Clone (or re-use) a visited module and append it to the cloned sequence."""
        if not isinstance(visitee, cms._Module):
            return
        label = visitee.label()
        if label in self._noClones:
            # keep unchanged: re-use the module already attached to the process
            newModule = getattr(self._process, label)
        elif label in self._moduleLabels:
            # module already met in this visit: re-use its clone
            newModule = getattr(self._process, self._newLabel(label))
        else:
            self._moduleLabels.append(label)
            newModule = visitee.clone()
            newLabel = self._newLabel(label)
            setattr(self._process, newLabel, newModule)
            if self._addToTask:
                self._patAlgosToolsTask.add(getattr(self._process, newLabel))
        self.__appendToTopSequence(newModule)

    def leave(self, visitee):
        pass

    def clonedSequence(self):
        """Return the cloned sequence after pointing its InputTags at the clones.

        The list of cloned labels is cleared afterwards, so calling this
        method again does not repeat the InputTag replacement.
        """
        for label in self._moduleLabels:
            massSearchReplaceAnyInputTag(self._clonedSequence, label, self._newLabel(label), moduleLabelOnly=True, verbose=self._verbose)
        self._moduleLabels = []
        return self._clonedSequence

    def _newLabel(self, label):
        """Return *label* with removePostfix stripped (if set) and postfix appended.

        Raises Exception if removePostfix is set but *label* does not end with it.
        """
        if self._removePostfix != "":
            if not label.endswith(self._removePostfix):
                raise Exception("Tried to remove postfix %s from label %s, but it wasn't there" % (self._removePostfix, label))
            label = label[:-len(self._removePostfix)]
        return label + self._postfix

    def __appendToTopSequence(self, visitee):
        self._clonedSequence += visitee
0206 
def listModules(sequence):
    """Return, in visiting order, every cms._Module entered by ``sequence.visit``."""
    moduleGatherer = GatherAllModulesVisitor(gatheredInstance=cms._Module)
    sequence.visit(moduleGatherer)
    return moduleGatherer.modules()
0211 
def listSequences(sequence):
    """Return, in visiting order, every cms.Sequence entered by ``sequence.visit``."""
    sequenceGatherer = GatherAllModulesVisitor(gatheredInstance=cms.Sequence)
    sequence.visit(sequenceGatherer)
    return sequenceGatherer.modules()
0216 
def jetCollectionString(prefix='', algo='', type=''):
    """
    ------------------------------------------------------------------
    return the label of the pat jet collection module built from the
    input values:

        ('pat' if prefix is empty, else prefix + 'Pat') + 'Jets' + algo + type

    With all defaults the result is 'patJets'; e.g. prefix='selected',
    algo='AK5', type='Calo' gives 'selectedPatJetsAK5Calo'.

    algo   : indicating the algorithm type of the jet [expected are
             'AK5', 'IC5', 'SC7', ...]
    type   : indicating the type of constituents of the jet [expec-
             ted are 'Calo', 'PFlow', 'JPT', ...] (the name shadows the
             builtin but is kept for keyword-argument compatibility)
    prefix : prefix indicating the type of pat collection module (ex-
             pected are '', 'selected', 'clean').
    ------------------------------------------------------------------
    """
    head = 'pat' if prefix == '' else prefix + 'Pat'
    return head + 'Jets' + algo + type
0240 
def contains(sequence, moduleName):
    """
    ------------------------------------------------------------------
    return True if the string 'moduleName' occurs in the string
    representation of 'sequence' and False otherwise. This is a plain
    substring test, so it also returns True when moduleName is only
    part of the name of a contained module (e.g. 'a' matches 'ab').

    sequence   : sequence [e.g. process.patDefaultSequence]
    moduleName : module name as a string
    ------------------------------------------------------------------
    """
    return moduleName in str(sequence)
0254 
0255 
0256 
def cloneProcessingSnippet(process, sequence, postfix, removePostfix="", noClones = [], addToTask = False, verbose = False):
    """
    ------------------------------------------------------------------
    copy a sequence plus the modules therein; the copies are attached
    to the process under labels with 'postfix' appended (after
    stripping 'removePostfix', if given) and input tags among them
    are automatically adjusted to the new labels. With an empty
    postfix nothing is copied and 'sequence' itself is returned.
    ------------------------------------------------------------------
    """
    if postfix == "":
        return sequence
    cloner = CloneSequenceVisitor(process, sequence.label(), postfix, removePostfix, noClones, addToTask, verbose)
    sequence.visit(cloner)
    return cloner.clonedSequence()
0271 
def listDependencyChain(process, module, sources, verbose=False):
    """
    Walk up the dependencies of a module to find any that depend on any of the listed sources

    process : the cms.Process; modules are looked up on it by label.
    module  : the module whose InputTag dependencies are followed
              (transitively, through modules attached to the process).
    sources : labels of interest; the result contains the modules of that
              dependency graph which consume, directly or indirectly, any
              of these labels.
    verbose : if True, print every dependency found and the parameter
              path it comes from.

    Returns None if none of those modules is attached to the process;
    otherwise a tuple (task, sequence): a cms.Task holding the modules and
    a cms.Sequence wrapping that Task.

    Raises RuntimeError if the computed ordering places a module before a
    module it depends on (internal consistency check).
    """
    def allDirectInputModules(moduleOrPSet,moduleName,attrName):
        """Return the labels referenced by InputTag/VInputTag parameters of moduleOrPSet.

        Only tags with an empty process name or the current process name are
        kept; PSets and VPSets are searched recursively. attrName is the
        parameter path used in verbose messages.
        """
        ret = set()
        for name,value in moduleOrPSet.parameters_().items():
            type = value.pythonTypeName()
            if type == 'cms.PSet':
                # NOTE(review): the path passed down is built from moduleName,
                # not attrName, so for PSets nested more than one level deep the
                # verbose message omits the intermediate levels -- confirm.
                ret.update(allDirectInputModules(value,moduleName,moduleName+"."+name))
            elif type == 'cms.VPSet':
                for (i,ps) in enumerate(value):
                    ret.update(allDirectInputModules(ps,moduleName,"%s.%s[%d]"%(moduleName,name,i)))
            elif type == 'cms.VInputTag':
                inputs = [ MassSearchReplaceAnyInputTagVisitor.standardizeInputTagFmt(it) for it in value ]
                inputLabels = [ tag.moduleLabel for tag in inputs if tag.processName == '' or tag.processName == process.name_() ]
                ret.update(inputLabels)
                if verbose and inputLabels: print("%s depends on %s via %s" % (moduleName, inputLabels, attrName+"."+name))
            elif type.endswith('.InputTag'):
                if value.processName == '' or value.processName == process.name_():
                    ret.add(value.moduleLabel)
                    if verbose: print("%s depends on %s via %s" % (moduleName, value.moduleLabel, attrName+"."+name))
        # Empty labels are not dependencies.
        ret.discard("")
        return ret
    def fillDirectDepGraphs(root,fwdepgraph,revdepgraph):
        """Fill fwdepgraph (label -> direct inputs) and revdepgraph (label -> direct consumers).

        Recurses into every input that is attached to the process. Returns
        None when root was already visited (recursive calls ignore the result).
        """
        if root.label_() in fwdepgraph: return
        deps = allDirectInputModules(root,root.label_(),root.label_())
        fwdepgraph[root.label_()] = []
        for d in deps:        
            fwdepgraph[root.label_()].append(d)
            if d not in revdepgraph: revdepgraph[d] = []
            revdepgraph[d].append(root.label_())
            # Inputs not attached to the process (or falsy) are not followed.
            depmodule = getattr(process,d,None)
            if depmodule:
                fillDirectDepGraphs(depmodule,fwdepgraph,revdepgraph)
        return (fwdepgraph,revdepgraph)
    # Only revdepgraph is used below.
    fwdepgraph, revdepgraph = fillDirectDepGraphs(module, {}, {})
    def flattenRevDeps(flatgraph, revdepgraph, tip):
        """Make a graph that for each module lists all the ones that depend on it, directly or indirectly"""
        # NOTE(review): flatgraph[tip] is only set after the recursion, so a
        # cycle in revdepgraph would recurse without bound -- confirm cycles
        # cannot occur.
        # don't do it multiple times for the same module
        if tip in flatgraph: return 
        # if nobody depends on this module, there's nothing to do
        if tip not in revdepgraph: return
        # assemble my dependencies, in a depth-first approach
        mydeps = set()
        # start taking the direct dependencies of this module
        for d in revdepgraph[tip]:
            # process them
            flattenRevDeps(flatgraph, revdepgraph, d)
            # then add them and their processed dependencies to our deps
            mydeps.add(d)
            if d in flatgraph: 
                mydeps.update(flatgraph[d])
        flatgraph[tip] = mydeps
    # Union of all modules that (transitively) consume any of the sources.
    flatdeps = {}
    allmodules = set()
    for s in sources: 
        flattenRevDeps(flatdeps, revdepgraph, s)
        if s in flatdeps: allmodules.update(f for f in flatdeps[s])
    livemodules = [ a for a in allmodules if hasattr(process,a) ]
    if not livemodules: return None
    # Order the modules so that each comes before the modules that consume it:
    # a module is inserted before the first listed module depending on it.
    # (`module` is rebound here; the parameter is no longer needed.)
    modulelist = [livemodules.pop()]
    for module in livemodules:
        for i,m in enumerate(modulelist):
            if module in flatdeps and m in flatdeps[module]:
                modulelist.insert(i, module)
                break
        if module not in modulelist:
            modulelist.append(module)
    # Validate
    for i,m1 in enumerate(modulelist):
        for j,m2 in enumerate(modulelist):
            if j <= i: continue
            if m2 in flatdeps and m1 in flatdeps[m2]:
                raise RuntimeError("BAD ORDER %s BEFORE %s" % (m1,m2))
    modules = [ getattr(process,p) for p in modulelist ]
    #return cms.Sequence(sum(modules[1:],modules[0]))
    task = cms.Task()
    for mod in modules: 
        task.add(mod)
    return task,cms.Sequence(task)
0353 
def addKeepStatement(process, oldKeep, newKeeps, verbose=False):
    """Add new keep statements to any PoolOutputModule of the process that has the old keep statement.

    process  : object exposing ``outputModules``, a dict of label -> output module.
    oldKeep  : a single output command; only PoolOutputModules whose
               outputCommands already contain it are extended.
    newKeeps : sequence of output commands appended to those modules.
    verbose  : if True, report every output module that was actually extended.
    """
    for name, out in process.outputModules.items():
        if out.type_() != 'PoolOutputModule' or not hasattr(out, "outputCommands"):
            continue
        if oldKeep not in out.outputCommands:
            continue
        out.outputCommands += newKeeps
        # Report only modules that were really modified.
        if verbose:
            print("Adding the following keep statements to output module %s: " % name)
            for k in newKeeps:
                print("\t'%s'," % k)
0363 
0364 
if __name__=="__main__":
    # Self-test: running this file directly executes the unit tests below.
    # NOTE: unittest.main() exits the interpreter by default, so definitions
    # placed after this block in the file are never executed in that mode
    # (the tests below do not use them).
    import unittest
    def _lineDiff(newString, oldString):
        """Return the non-empty lines of newString not matched, in order, by those of oldString.

        oldString's non-empty lines are consumed from the start: a line of
        newString equal to the next unconsumed oldString line is dropped,
        every other line is kept. The kept lines are joined with newlines.
        """
        newString = ( x for x in newString.split('\n') if len(x) > 0)
        oldString = [ x for x in oldString.split('\n') if len(x) > 0]
        diff = []
        oldStringLine = 0
        for l in newString:
            if oldStringLine >= len(oldString):
                diff.append(l)
                continue
            if l == oldString[oldStringLine]:
                oldStringLine +=1
                continue
            diff.append(l)
        return "\n".join( diff )

    class TestModuleCommand(unittest.TestCase):
        def setUp(self):
            """Nothing to do """
            pass
        def testCloning(self):
            # Clone a 3-module sequence with postfix "New" and compare the
            # part of dumpPython() that differs from an empty process.
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("b", src=cms.InputTag("a"))
            p.c = cms.EDProducer("c", src=cms.InputTag("b","instance"))
            p.s = cms.Sequence(p.a*p.b*p.c *p.a)
            cloneProcessingSnippet(p, p.s, "New", addToTask = True)
            self.assertEqual(_lineDiff(p.dumpPython(), cms.Process("test").dumpPython()),
 """process.a = cms.EDProducer("a",
    src = cms.InputTag("gen")
)
process.aNew = cms.EDProducer("a",
    src = cms.InputTag("gen")
)
process.b = cms.EDProducer("b",
    src = cms.InputTag("a")
)
process.bNew = cms.EDProducer("b",
    src = cms.InputTag("aNew")
)
process.c = cms.EDProducer("c",
    src = cms.InputTag("b","instance")
)
process.cNew = cms.EDProducer("c",
    src = cms.InputTag("bNew","instance")
)
process.patAlgosToolsTask = cms.Task(process.aNew, process.bNew, process.cNew)
process.s = cms.Sequence(process.a+process.b+process.c+process.a)
process.sNew = cms.Sequence(process.aNew+process.bNew+process.cNew+process.aNew)""")
        def testContains(self):
            # Module types ("ab", "ac") must not count as matches, only labels.
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s1 = cms.Sequence(p.a*p.b*p.c)
            p.s2 = cms.Sequence(p.b*p.c)
            self.assertTrue( contains(p.s1, "a") )
            self.assertTrue( not contains(p.s2, "a") )
        def testJetCollectionString(self):
            self.assertEqual(jetCollectionString(algo = 'Foo', type = 'Bar'), 'patJetsFooBar')
            self.assertEqual(jetCollectionString(prefix = 'prefix', algo = 'Foo', type = 'Bar'), 'prefixPatJetsFooBar')
        def testListModules(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"))
            p.s = cms.Sequence(p.a*p.b*p.c)
            self.assertEqual([p.a,p.b,p.c], listModules(p.s))

    unittest.main()
0436 
class CloneTaskVisitor(object):
    """Visitor that travels within a cms.Task and builds a cloned version of it.

    Every module entered during the visit is cloned (unless listed in
    *noClones*) and attached to the process under ``_newLabel(label)``. The
    resulting modules are added to a new cms.Task, which the constructor
    attaches to the process under ``_newLabel`` of the visited Task's label.

    process       : the cms.Process receiving the clones and the new Task.
    label         : label of the Task being cloned.
    postfix       : string appended to every new label.
    removePostfix : if non-empty, every label must end with it; it is stripped
                    before *postfix* is appended (an Exception is raised
                    otherwise).
    noClones      : labels of modules re-used as-is (looked up on the
                    process) instead of cloned; default: none.
    verbose       : forwarded to the InputTag replacement in clonedTask().
    """
    def __init__(self, process, label, postfix, removePostfix="", noClones=None, verbose=False):
        self._process = process
        self._postfix = postfix
        self._removePostfix = removePostfix
        # None instead of a mutable [] default so no list is shared between calls.
        self._noClones = [] if noClones is None else noClones
        self._verbose = verbose
        self._moduleLabels = []
        self._clonedTask = cms.Task()
        setattr(process, self._newLabel(label), self._clonedTask)

    def enter(self, visitee):
        """Clone (or re-use) a visited module and add it to the cloned Task."""
        if not isinstance(visitee, cms._Module):
            return
        label = visitee.label()
        if label in self._noClones:
            # keep unchanged: re-use the module already attached to the process
            newModule = getattr(self._process, label)
        elif label in self._moduleLabels:
            # module already met in this visit: re-use its clone
            newModule = getattr(self._process, self._newLabel(label))
        else:
            self._moduleLabels.append(label)
            newModule = visitee.clone()
            setattr(self._process, self._newLabel(label), newModule)
        self.__appendToTopTask(newModule)

    def leave(self, visitee):
        pass

    def clonedTask(self):
        """Return the cloned Task after pointing its InputTags at the clones.

        The list of cloned labels is cleared afterwards, so calling this
        method again does not repeat the InputTag replacement.
        """
        # FIXME (from the original author): confirm that the InputTag
        # replacement below works for a Task as it does for a Sequence.
        for label in self._moduleLabels:
            massSearchReplaceAnyInputTag(self._clonedTask, label, self._newLabel(label), moduleLabelOnly=True, verbose=self._verbose)
        self._moduleLabels = []
        return self._clonedTask

    def _newLabel(self, label):
        """Return *label* with removePostfix stripped (if set) and postfix appended.

        Raises Exception if removePostfix is set but *label* does not end with it.
        """
        if self._removePostfix != "":
            if not label.endswith(self._removePostfix):
                raise Exception("Tried to remove postfix %s from label %s, but it wasn't there" % (self._removePostfix, label))
            label = label[:-len(self._removePostfix)]
        return label + self._postfix

    def __appendToTopTask(self, visitee):
        self._clonedTask.add(visitee)
0483 
def cloneProcessingSnippetTask(process, task, postfix, removePostfix="", noClones = [], verbose = False):
    """
    ------------------------------------------------------------------
    copy a task plus the modules therein (including modules in subtasks,
    as reached by task.visit); the copies are attached to the process
    under labels with 'postfix' appended (after stripping
    'removePostfix', if given) and input tags among them are
    automatically adjusted. With an empty postfix nothing is copied
    and 'task' itself is returned.
    ------------------------------------------------------------------
    """
    if postfix == "":
        return task
    cloner = CloneTaskVisitor(process, task.label(), postfix, removePostfix, noClones, verbose)
    task.visit(cloner)
    return cloner.clonedTask()