Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2021-02-14 13:28:53

0001 from __future__ import print_function
0002 import FWCore.ParameterSet.Config as cms
0003 
class MassSearchReplaceAnyInputTagVisitor(object):
    """Visitor that travels within a cms.Sequence, looks for an InputTag and replaces its value.

    It climbs down into PSets, VPSets and VInputTags, and into EDProducer/EDAlias
    valued parameters (as found in SwitchProducer cases), to find its target.

    Arguments:
      paramSearch     -- InputTag (or plain string) to look for.
      paramReplace    -- InputTag (or plain string) to put in its place.
      verbose         -- print one line per replacement.
      moduleLabelOnly -- additionally, for tags that do not match paramSearch exactly
                         but share its module label, replace only the module label.
      skipLabelTest   -- do not query the visited module's label; a placeholder is
                         used as the path prefix in verbose messages.

    On an exact match, an untracked InputTag parameter (or an entry of an untracked
    VInputTag) stays untracked.  Replacement values are always fresh objects, so no
    existing InputTag object is mutated in place.
    """
    def __init__(self,paramSearch,paramReplace,verbose=False,moduleLabelOnly=False,skipLabelTest=False):
        self._paramSearch  = self.standardizeInputTagFmt(paramSearch)
        self._paramReplace = self.standardizeInputTagFmt(paramReplace)
        self._moduleName   = ''
        self._verbose=verbose
        self._moduleLabelOnly=moduleLabelOnly
        self._skipLabelTest=skipLabelTest

    def doIt(self,pset,base):
        """Recursively replace matching InputTags among the parameters of ``pset``.

        ``base`` is the dotted path of ``pset``; it is only used in verbose messages.
        Anything that is not a cms._Parameterizable is ignored.
        """
        if not isinstance(pset, cms._Parameterizable):
            return
        for name in pset.parameterNames_():
            # if I use pset.parameters_().items() I get copies of the parameter values
            # so I can't modify the nested pset
            value = getattr(pset,name)
            if isinstance(value, (cms.PSet, cms.EDProducer, cms.EDAlias)):
                # EDProducer and EDAlias to support SwitchProducer
                self.doIt(value, base+"."+name)
            elif value.isCompatibleCMSType(cms.VPSet):
                for (i,ps) in enumerate(value):
                    self.doIt(ps, "%s.%s[%d]"%(base,name,i))
            elif value.isCompatibleCMSType(cms.VInputTag) and value:
                self._replaceInVInputTag(value, base, name)
            elif value.isCompatibleCMSType(cms.InputTag) and value:
                self._replaceInInputTag(pset, name, value, base)

    def _untrackedReplacement(self):
        """Return a new untracked InputTag with the module, instance and process of paramReplace."""
        return cms.untracked.InputTag(self._paramReplace.getModuleLabel(),
                                      self._paramReplace.getProductInstanceLabel(),
                                      self._paramReplace.getProcessName())

    def _replaceInVInputTag(self, value, base, name):
        """Replace matching entries of the VInputTag ``value`` (parameter ``base``.``name``) in place."""
        from copy import deepcopy
        for (i,n) in enumerate(value):
            # VInputTag can be declared as a list of strings, so ensure that n is formatted correctly
            n = self.standardizeInputTagFmt(n)
            if n == self._paramSearch:
                if self._verbose:print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, self._paramSearch, self._paramReplace))
                if not value.isTracked():
                    value[i] = self._untrackedReplacement()
                else:
                    # copy (as in the scalar InputTag case) so the caller's paramReplace
                    # object is never shared with the configuration
                    value[i] = deepcopy(self._paramReplace)
            elif self._moduleLabelOnly and n.moduleLabel == self._paramSearch.moduleLabel:
                # n may be the very object stored in the vector (and possibly referenced
                # elsewhere), so work on a copy: mutating n in place would leak the change
                # to other holders and make the verbose message print the new value twice
                nrep = deepcopy(n)
                nrep.moduleLabel = self._paramReplace.moduleLabel
                if self._verbose:print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, n, nrep))
                value[i] = nrep

    def _replaceInInputTag(self, pset, name, value, base):
        """Replace the InputTag parameter ``pset.name`` (current value ``value``) if it matches."""
        from copy import deepcopy
        if value == self._paramSearch:
            if self._verbose:print("Replace %s.%s %s ==> %s " % (base, name, self._paramSearch, self._paramReplace))
            if not value.isTracked():
                # the existing value should stay untracked even if the given parameter is tracked
                setattr(pset, name, self._untrackedReplacement())
            else:
                setattr(pset, name, deepcopy(self._paramReplace))
        elif self._moduleLabelOnly and value.moduleLabel == self._paramSearch.moduleLabel:
            repl = deepcopy(value)
            repl.moduleLabel = self._paramReplace.moduleLabel
            setattr(pset, name, repl)
            if self._verbose:print("Replace %s.%s %s ==> %s " % (base, name, value, repl))

    @staticmethod
    def standardizeInputTagFmt(inputTag):
        ''' helper function to ensure that the InputTag is defined as cms.InputTag(str) and not as a plain str '''
        if not isinstance(inputTag, cms.InputTag):
            return cms.InputTag(inputTag)
        return inputTag

    def enter(self,visitee):
        """Visitor hook: process ``visitee``, using its label (or a placeholder) as the path prefix."""
        label = ''
        if (not self._skipLabelTest):
            if hasattr(visitee,"hasLabel_") and visitee.hasLabel_():
                label = visitee.label_()
            else: label = '<Module not in a Process>'
        else:
            label = '<Module label not tested>'
        self.doIt(visitee, label)

    def leave(self,visitee):
        """Visitor hook: nothing to do on leave."""
        pass
0078 
def massSearchReplaceAnyInputTag(sequence, oldInputTag, newInputTag,verbose=False,moduleLabelOnly=False,skipLabelTest=False) :
    """Replace InputTag oldInputTag with newInputTag, at any level of nesting within PSets, VPSets, VInputTags...

    ``sequence`` is anything exposing ``visit(visitor)`` (e.g. a Sequence, Path or Task);
    the remaining flags are forwarded unchanged to MassSearchReplaceAnyInputTagVisitor.
    """
    replacer = MassSearchReplaceAnyInputTagVisitor(oldInputTag,
                                                   newInputTag,
                                                   verbose=verbose,
                                                   moduleLabelOnly=moduleLabelOnly,
                                                   skipLabelTest=skipLabelTest)
    sequence.visit(replacer)
0082 
def massReplaceInputTag(process,old="rawDataCollector",new="rawDataRepacker",verbose=False,moduleLabelOnly=False,skipLabelTest=False):
    """Replace InputTag ``old`` with ``new`` in every path, endpath and scheduled task of ``process``.

    The process is modified in place and also returned, for chaining.
    """
    # paths first, then endpaths: same visiting order as before
    for getLabels in (process.paths_, process.endpaths_):
        for label in getLabels():
            massSearchReplaceAnyInputTag(getattr(process, label), old, new, verbose, moduleLabelOnly, skipLabelTest)
    schedule = process.schedule_()
    if schedule is not None:
        # tasks associated with the schedule are not reached through any path
        for task in schedule._tasks:
            massSearchReplaceAnyInputTag(task, old, new, verbose, moduleLabelOnly, skipLabelTest)
    return process
0092 
class MassSearchParamVisitor(object):
    """Visitor that travels within a cms.Sequence, looks for a parameter and returns a list of modules that have it

    A visited object is collected when it has an attribute named ``paramName``
    whose value compares equal to ``paramSearch``.  After the visit, ``modules()``
    returns the collected objects in the order they were entered.
    """
    def __init__(self, paramName, paramSearch):
        self._paramName = paramName
        self._paramSearch = paramSearch
        self._modules = []

    def enter(self, visitee):
        wanted = self._paramName
        if hasattr(visitee, wanted) and getattr(visitee, wanted) == self._paramSearch:
            self._modules.append(visitee)

    def leave(self, visitee):
        # nothing to undo when the visitor leaves a node
        pass

    def modules(self):
        return self._modules
0107 
class MassSearchReplaceParamVisitor(object):
    """Visitor that travels within a cms.Sequence, looks for a parameter and replaces its value

    For every visited module (or, for a cms.SwitchProducer, every one of its cases)
    whose attribute ``paramName`` equals ``paramSearch``, the attribute is set to
    ``paramValue``.  Only the top-level attribute is examined; nested PSets are not.
    """
    def __init__(self,paramName,paramSearch,paramValue,verbose=False):
        self._paramName   = paramName
        self._paramValue  = paramValue
        self._paramSearch = paramSearch
        self._verbose = verbose

    def enter(self,visitee):
        if not isinstance(visitee, cms.SwitchProducer):
            self.doIt(visitee, str(visitee))
            return
        # a SwitchProducer holds one module per case: treat each case separately
        for caseName in visitee.parameterNames_():
            self.doIt(getattr(visitee, caseName), "%s.%s"%(str(visitee), caseName))

    def doIt(self, mod, name):
        """Replace ``mod.<paramName>`` if it equals the searched value; ``name`` is used for logging."""
        param = self._paramName
        if not hasattr(mod, param):
            return
        current = getattr(mod, param)
        if current != self._paramSearch and not (current == self._paramSearch):
            return
        if self._verbose:
            print("Replaced %s.%s: %s => %s" % (name, param, current, self._paramValue))
        setattr(mod, param, self._paramValue)

    def leave(self,visitee):
        pass
0128 
def massSearchReplaceParam(sequence,paramName,paramOldValue,paramValue,verbose=False):
    """Set ``paramName`` to ``paramValue`` on every module visited from ``sequence`` whose value equals ``paramOldValue``.

    ``sequence`` is anything exposing ``visit(visitor)``; see MassSearchReplaceParamVisitor.
    """
    replacer = MassSearchReplaceParamVisitor(paramName, paramOldValue, paramValue, verbose)
    sequence.visit(replacer)
0131 
def massReplaceParameter(process,name="label",old="rawDataCollector",new="rawDataRepacker",verbose=False):
    """Replace parameter ``name`` equal to ``old`` by ``new`` in every path, endpath and scheduled task of ``process``.

    The process is modified in place and also returned, for chaining.
    """
    # paths first, then endpaths: same visiting order as before
    for getLabels in (process.paths_, process.endpaths_):
        for label in getLabels():
            massSearchReplaceParam(getattr(process, label), name, old, new, verbose)
    schedule = process.schedule_()
    if schedule is not None:
        # tasks associated with the schedule are not reached through any path
        for task in schedule._tasks:
            massSearchReplaceParam(task, name, old, new, verbose)
    return process
0141 
if __name__=="__main__":
    # Self-tests: running this file directly executes the unittest suite below.
    import unittest
    # Minimal SwitchProducer for the tests, with four cases test1..test4.  Each case
    # function returns (True, <int>); presumably (enabled, priority) as expected by
    # cms.SwitchProducer -- NOTE(review): confirm against SwitchProducer's docs.
    class SwitchProducerTest(cms.SwitchProducer):
        def __init__(self, **kargs):
            super(SwitchProducerTest,self).__init__(
                dict(
                    test1 = lambda: (True, -10),
                    test2 = lambda: (True, -9),
                    test3 = lambda: (True, -8),
                    test4 = lambda: (True, -7)
                ), **kargs)

    class TestModuleCommand(unittest.TestCase):

        # massSearchReplaceAnyInputTag on a Sequence: top-level tracked/untracked InputTags,
        # tags nested in PSet/VPSet/VInputTag, tags inside SwitchProducer cases, and
        # cms.optional parameters.  Untracked parameters must stay untracked.
        def testMassSearchReplaceAnyInputTag(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"), usrc=cms.untracked.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("b"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("b")),
                                 nestedv = cms.VPSet(cms.PSet(src = cms.InputTag("b")), cms.PSet(src = cms.InputTag("d"))),
                                 unestedv = cms.untracked.VPSet(cms.untracked.PSet(src = cms.InputTag("b")), cms.untracked.PSet(src = cms.InputTag("d"))),
                                 vec = cms.VInputTag(cms.InputTag("a"), cms.InputTag("b"), cms.InputTag("c"), cms.InputTag("d")),
                                 uvec = cms.untracked.VInputTag(cms.untracked.InputTag("a"), cms.untracked.InputTag("b"), cms.untracked.InputTag("c"), cms.untracked.InputTag("d")),
                                )
            p.sp = SwitchProducerTest(
                test1 = cms.EDProducer("a", src = cms.InputTag("b"),
                                       nested = cms.PSet(src = cms.InputTag("b"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("b"))
                                       ),
                test2 = cms.EDProducer("b", src = cms.InputTag("c"),
                                       nested = cms.PSet(src = cms.InputTag("b"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("b"))
                                       ),
            )
            # optional parameters: only the ones that are set should be replaced
            p.op = cms.EDProducer("op", src = cms.optional.InputTag, unset = cms.optional.InputTag, vsrc = cms.optional.VInputTag, vunset = cms.optional.VInputTag)
            p.op.src="b"
            p.op.vsrc=cms.VInputTag("b")
            p.s = cms.Sequence(p.a*p.b*p.c*p.sp*p.op)
            massSearchReplaceAnyInputTag(p.s, cms.InputTag("b"), cms.InputTag("new"))
            self.assertNotEqual(cms.InputTag("new"), p.b.src)
            self.assertEqual(cms.InputTag("new"), p.c.src)
            self.assertEqual(cms.InputTag("new"), p.c.usrc)
            self.assertEqual(cms.InputTag("new"), p.c.nested.src)
            self.assertEqual(cms.InputTag("new"), p.c.nested.usrc)
            self.assertFalse(p.c.nested.usrc.isTracked())
            self.assertNotEqual(cms.InputTag("new"), p.c.nested.src2)
            self.assertEqual(cms.InputTag("new"), p.c.nestedv[0].src)
            self.assertNotEqual(cms.InputTag("new"), p.c.nestedv[1].src)
            self.assertEqual(cms.InputTag("new"), p.c.unestedv[0].src)
            self.assertNotEqual(cms.InputTag("new"), p.c.unestedv[1].src)
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[0])
            self.assertEqual(cms.InputTag("new"), p.c.vec[1])
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[2])
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[3])
            self.assertNotEqual(cms.InputTag("new"), p.c.uvec[0])
            self.assertEqual(cms.InputTag("new"), p.c.uvec[1])
            self.assertNotEqual(cms.InputTag("new"), p.c.uvec[2])
            self.assertNotEqual(cms.InputTag("new"), p.c.uvec[3])
            self.assertFalse(p.c.uvec[0].isTracked())
            self.assertFalse(p.c.uvec[1].isTracked())
            self.assertFalse(p.c.uvec[2].isTracked())
            self.assertFalse(p.c.uvec[3].isTracked())
            self.assertEqual(cms.InputTag("new"), p.sp.test1.src)
            self.assertEqual(cms.InputTag("new"), p.sp.test1.nested.src)
            self.assertEqual(cms.InputTag("c"), p.sp.test1.nested.src2)
            self.assertEqual(cms.untracked.InputTag("new"), p.sp.test1.nested.usrc)
            self.assertEqual(cms.InputTag("c"), p.sp.test2.src)
            self.assertEqual(cms.InputTag("new"), p.sp.test2.nested.src)
            self.assertEqual(cms.InputTag("c"), p.sp.test2.nested.src2)
            self.assertEqual(cms.untracked.InputTag("new"), p.sp.test2.nested.usrc)
            self.assertEqual(cms.InputTag("new"), p.op.src)
            self.assertEqual(cms.InputTag("new"), p.op.vsrc[0])

        # massReplaceInputTag on a whole Process: an empty process is left unchanged;
        # modules reached through paths, endpaths and schedule-associated tasks are
        # updated, while p.d (only in sequence s2, which no path uses) is not.
        def testMassReplaceInputTag(self):
            process1 = cms.Process("test")
            massReplaceInputTag(process1, "a", "b", False, False, False)
            self.assertEqual(process1.dumpPython(), cms.Process('test').dumpPython())
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("a"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("a")),
                                 nestedv = cms.VPSet(cms.PSet(src = cms.InputTag("a")), cms.PSet(src = cms.InputTag("d"))),
                                 unestedv = cms.untracked.VPSet(cms.untracked.PSet(src = cms.InputTag("a")), cms.untracked.PSet(src = cms.InputTag("d"))),
                                 vec = cms.VInputTag(cms.InputTag("a"), cms.InputTag("b"), cms.InputTag("c"), cms.InputTag("d")),
                                 uvec = cms.untracked.VInputTag(cms.untracked.InputTag("a"), cms.untracked.InputTag("b"), cms.untracked.InputTag("c"), cms.InputTag("d")),
                                )
            p.d = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.e = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.f = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.g = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.h = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.i = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.sp = SwitchProducerTest(
                test1 = cms.EDProducer("a", src = cms.InputTag("a"),
                                       nested = cms.PSet(src = cms.InputTag("a"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("a"))
                                       ),
                test2 = cms.EDProducer("b", src = cms.InputTag("c"),
                                       nested = cms.PSet(src = cms.InputTag("a"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("a"))
                                       ),
            )
            # every way a module can be reached: path via sequence, path directly,
            # endpath via sequence, endpath directly, and tasks associated with the schedule
            p.s1 = cms.Sequence(p.a*p.b*p.c*p.sp)
            p.path1 = cms.Path(p.s1)
            p.s2 = cms.Sequence(p.d)
            p.path2 = cms.Path(p.e)
            p.s3 = cms.Sequence(p.f)
            p.endpath1 = cms.EndPath(p.s3)
            p.endpath2 = cms.EndPath(p.g)
            p.t1 = cms.Task(p.h)
            p.t2 = cms.Task(p.i)
            p.schedule = cms.Schedule()
            p.schedule.associate(p.t1, p.t2)
            massReplaceInputTag(p, "a", "b", False, False, False)
            self.assertEqual(cms.InputTag("b"), p.b.src)
            self.assertEqual(cms.InputTag("b"), p.c.nested.src)
            self.assertEqual(cms.InputTag("b"), p.c.nested.usrc)
            self.assertFalse(p.c.nested.usrc.isTracked())
            self.assertEqual(cms.InputTag("b"), p.c.nestedv[0].src)
            self.assertEqual(cms.InputTag("b"), p.c.unestedv[0].src)
            self.assertEqual(cms.InputTag("b"), p.c.vec[0])
            self.assertEqual(cms.InputTag("c"), p.c.vec[2])
            self.assertEqual(cms.InputTag("b"), p.c.uvec[0])
            self.assertEqual(cms.InputTag("c"), p.c.uvec[2])
            self.assertFalse(p.c.uvec[0].isTracked())
            self.assertFalse(p.c.uvec[1].isTracked())
            self.assertFalse(p.c.uvec[2].isTracked())
            self.assertEqual(cms.InputTag("a"), p.d.src)
            self.assertEqual(cms.InputTag("b"), p.e.src)
            self.assertEqual(cms.InputTag("b"), p.f.src)
            self.assertEqual(cms.InputTag("b"), p.g.src)
            self.assertEqual(cms.InputTag("b"), p.h.src)
            self.assertEqual(cms.InputTag("b"), p.i.src)
            self.assertEqual(cms.InputTag("b"), p.sp.test1.src)
            self.assertEqual(cms.InputTag("b"), p.sp.test1.nested.src)
            self.assertEqual(cms.InputTag("c"), p.sp.test1.nested.src2)
            self.assertEqual(cms.untracked.InputTag("b"), p.sp.test1.nested.usrc)
            self.assertEqual(cms.InputTag("c"), p.sp.test2.src)
            self.assertEqual(cms.InputTag("b"), p.sp.test2.nested.src)
            self.assertEqual(cms.InputTag("c"), p.sp.test2.nested.src2)
            self.assertEqual(cms.untracked.InputTag("b"), p.sp.test2.nested.usrc)

        # massSearchReplaceParam replaces only the top-level "src" of each visited module
        # (and of each SwitchProducer case); nested PSets are untouched, and an untracked
        # parameter stays untracked when given a plain string.
        def testMassSearchReplaceParam(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("c"), src2 = cms.InputTag("b"))
                                )
            p.d = cms.EDProducer("ac", src=cms.untracked.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("c"), src2 = cms.InputTag("b"))
                                )
            p.sp = SwitchProducerTest(
                test1 = cms.EDProducer("a", src = cms.InputTag("b"),
                                       nested = cms.PSet(src = cms.InputTag("b"))
                                       ),
                test2 = cms.EDProducer("b", src = cms.InputTag("b")),
            )
            p.s = cms.Sequence(p.a*p.b*p.c*p.d*p.sp)
            massSearchReplaceParam(p.s,"src",cms.InputTag("b"),"a")
            self.assertEqual(cms.InputTag("a"),p.c.src)
            self.assertEqual(cms.InputTag("c"),p.c.nested.src)
            self.assertEqual(cms.InputTag("b"),p.c.nested.src2)
            self.assertEqual(cms.untracked.InputTag("a"),p.d.src)
            self.assertEqual(cms.InputTag("c"),p.d.nested.src)
            self.assertEqual(cms.InputTag("b"),p.d.nested.src2)
            self.assertEqual(cms.InputTag("a"),p.sp.test1.src)
            self.assertEqual(cms.InputTag("b"),p.sp.test1.nested.src)
            self.assertEqual(cms.InputTag("a"),p.sp.test2.src)

        # massReplaceParameter on a whole Process: same reachability rules as
        # testMassReplaceInputTag; VInputTag entries and nested PSets are not touched.
        def testMassReplaceParam(self):
            process1 = cms.Process("test")
            massReplaceParameter(process1, "src", cms.InputTag("a"), "b", False)
            self.assertEqual(process1.dumpPython(), cms.Process("test").dumpPython())
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("a"), src2 = cms.InputTag("c")),
                                 nestedv = cms.VPSet(cms.PSet(src = cms.InputTag("a")), cms.PSet(src = cms.InputTag("d"))),
                                 vec = cms.VInputTag(cms.InputTag("a"), cms.InputTag("b"), cms.InputTag("c"), cms.InputTag("d"))
                                )
            p.d = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.e = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.f = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.g = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.h = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.i = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.j = cms.EDProducer("ab", src=cms.untracked.InputTag("a"))
            p.sp = SwitchProducerTest(
                test1 = cms.EDProducer("a", src = cms.InputTag("a"),
                                       nested = cms.PSet(src = cms.InputTag("a"))
                                       ),
                test2 = cms.EDProducer("b", src = cms.InputTag("a")),
            )
            p.s1 = cms.Sequence(p.a*p.b*p.c*p.sp)
            p.path1 = cms.Path(p.s1)
            p.s2 = cms.Sequence(p.d)
            p.path2 = cms.Path(p.e)
            p.s3 = cms.Sequence(p.f)
            p.endpath1 = cms.EndPath(p.s3)
            p.endpath2 = cms.EndPath(p.g)
            p.t1 = cms.Task(p.h)
            p.t2 = cms.Task(p.i, p.j)
            p.schedule = cms.Schedule()
            p.schedule.associate(p.t1, p.t2)
            massReplaceParameter(p, "src",cms.InputTag("a"), "b", False)
            self.assertEqual(cms.InputTag("gen"), p.a.src)
            self.assertEqual(cms.InputTag("b"), p.b.src)
            self.assertEqual(cms.InputTag("a"), p.c.vec[0])
            self.assertEqual(cms.InputTag("c"), p.c.vec[2])
            self.assertEqual(cms.InputTag("a"), p.d.src)
            self.assertEqual(cms.InputTag("b"), p.e.src)
            self.assertEqual(cms.InputTag("b"), p.f.src)
            self.assertEqual(cms.InputTag("b"), p.g.src)
            self.assertEqual(cms.InputTag("b"), p.h.src)
            self.assertEqual(cms.InputTag("b"), p.i.src)
            self.assertEqual(cms.untracked.InputTag("b"), p.j.src)
            self.assertEqual(cms.InputTag("b"),p.sp.test1.src)
            self.assertEqual(cms.InputTag("a"),p.sp.test1.nested.src)
            self.assertEqual(cms.InputTag("b"),p.sp.test2.src)
    unittest.main()