Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:12:53

0001 from __future__ import print_function
0002 import FWCore.ParameterSet.Config as cms
0003 
class MassSearchReplaceAnyInputTagVisitor(object):
    """Visitor that travels within a cms.Sequence, looks for an InputTag and replaces its value.

    It climbs down within PSets, VPSets and VInputTags, and into EDProducer /
    EDAlias parameters (the cases of a SwitchProducer), to find its target.

    Constructor arguments:
      paramSearch     -- InputTag (or anything cms.InputTag accepts, e.g. a string) to look for
      paramReplace    -- InputTag (or anything cms.InputTag accepts) to put in its place
      verbose         -- print one line for every replacement performed
      moduleLabelOnly -- additionally rewrite tags whose module label equals the module
                         label of paramSearch; only their module label is changed
      skipLabelTest   -- do not query the visited module for its label; a fixed
                         placeholder is used as the path prefix in printed messages
    Untracked parameters stay untracked after replacement.
    """
    def __init__(self,paramSearch,paramReplace,verbose=False,moduleLabelOnly=False,skipLabelTest=False):
        self._paramSearch  = self.standardizeInputTagFmt(paramSearch)
        self._paramReplace = self.standardizeInputTagFmt(paramReplace)
        self._moduleName   = ''
        self._verbose=verbose
        self._moduleLabelOnly=moduleLabelOnly
        self._skipLabelTest=skipLabelTest

    def _untrackedReplacement(self):
        """Return a new untracked InputTag carrying the labels of the replacement tag."""
        return cms.untracked.InputTag(self._paramReplace.getModuleLabel(),
                                      self._paramReplace.getProductInstanceLabel(),
                                      self._paramReplace.getProcessName())

    def doIt(self,pset,base):
        """Replace matching InputTags among the parameters of pset, recursing into nested containers.

        pset -- object to inspect; anything that is not a cms._Parameterizable is ignored
        base -- dotted path of pset, used only to build the verbose messages
        """
        if not isinstance(pset, cms._Parameterizable):
            return
        for name in pset.parameterNames_():
            # if I use pset.parameters_().items() I get copies of the parameter values
            # so I can't modify the nested pset
            value = getattr(pset,name)
            if isinstance(value, (cms.PSet, cms.EDProducer, cms.EDAlias)):
                # EDProducer and EDAlias to support SwitchProducer
                self.doIt(value,base+"."+name)
            elif value.isCompatibleCMSType(cms.VPSet):
                for (i,ps) in enumerate(value):
                    self.doIt(ps, "%s.%s[%d]"%(base,name,i))
            elif value.isCompatibleCMSType(cms.VInputTag) and value:
                self._replaceInVInputTag(value, base, name)
            elif value.isCompatibleCMSType(cms.InputTag) and value:
                self._replaceInputTag(pset, name, value, base)

    def _replaceInVInputTag(self, value, base, name):
        """Replace, in place, the entries of the VInputTag value that match the search tag."""
        from copy import deepcopy
        for (i,n) in enumerate(value):
            # VInputTag can be declared as a list of strings, so ensure that n is formatted correctly
            n = self.standardizeInputTagFmt(n)
            if n == self._paramSearch:
                if self._verbose:print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, self._paramSearch, self._paramReplace))
                if not value.isTracked():
                    value[i] = self._untrackedReplacement()
                else:
                    # store a copy: sharing one object between entries (and with the
                    # caller's tag) would let a later in-place edit change all of them
                    value[i] = deepcopy(self._paramReplace)
            elif self._moduleLabelOnly and n.moduleLabel == self._paramSearch.moduleLabel:
                # copy first: editing n directly would modify the tag already stored
                # in value[i] (and anything sharing it), and the message below would
                # show the new label on both sides
                nrep = deepcopy(n)
                nrep.moduleLabel = self._paramReplace.moduleLabel
                if self._verbose:print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, n, nrep))
                value[i] = nrep

    def _replaceInputTag(self, pset, name, value, base):
        """Replace the single InputTag parameter pset.<name> (current value: value) if it matches."""
        from copy import deepcopy
        if value == self._paramSearch:
            if self._verbose:print("Replace %s.%s %s ==> %s " % (base, name, self._paramSearch, self._paramReplace))
            if not value.isTracked():
                # the existing value should stay untracked even if the given parameter is tracked
                setattr(pset, name, self._untrackedReplacement())
            else:
                setattr(pset, name, deepcopy(self._paramReplace))
        elif self._moduleLabelOnly and value.moduleLabel == self._paramSearch.moduleLabel:
            # deepcopy keeps the tracked/untracked state and the other labels of the original tag
            repl = deepcopy(value)
            repl.moduleLabel = self._paramReplace.moduleLabel
            # print before setattr so the old value is reported as it was
            if self._verbose:print("Replace %s.%s %s ==> %s " % (base, name, value, repl))
            setattr(pset, name, repl)

    @staticmethod
    def standardizeInputTagFmt(inputTag):
        ''' helper function to ensure that the InputTag is defined as cms.InputTag(str) and not as a plain str '''
        if not isinstance(inputTag, cms.InputTag):
            return cms.InputTag(inputTag)
        return inputTag

    def enter(self,visitee):
        """Visitor hook: replace InputTags in visitee, using its label as the message prefix."""
        if self._skipLabelTest:
            label = '<Module label not tested>'
        elif hasattr(visitee,"hasLabel_") and visitee.hasLabel_():
            label = visitee.label_()
        else:
            label = '<Module not in a Process>'
        self.doIt(visitee, label)

    def leave(self,visitee):
        """Visitor hook: nothing to do when leaving a node."""
        pass
0078 
def massSearchReplaceAnyInputTag(sequence, oldInputTag, newInputTag,verbose=False,moduleLabelOnly=False,skipLabelTest=False) :
    """Replace InputTag oldInputTag with newInputTag, at any level of nesting within PSets, VPSets, VInputTags...

    sequence is anything with a visit(visitor) method (Sequence, Path, Task, ...);
    the remaining options are forwarded unchanged to MassSearchReplaceAnyInputTagVisitor.
    """
    replacer = MassSearchReplaceAnyInputTagVisitor(oldInputTag,
                                                   newInputTag,
                                                   verbose=verbose,
                                                   moduleLabelOnly=moduleLabelOnly,
                                                   skipLabelTest=skipLabelTest)
    sequence.visit(replacer)
0082 
def massReplaceInputTag(process,old="rawDataCollector",new="rawDataRepacker",verbose=False,moduleLabelOnly=False,skipLabelTest=False):
    """Replace InputTag old with new throughout process.

    Every Path, every EndPath and every Task associated with the process schedule
    is visited with massSearchReplaceAnyInputTag. The process is modified in place
    and also returned.
    """
    forwarded = (verbose, moduleLabelOnly, skipLabelTest)
    # Paths first, then EndPaths, exactly as they are listed by the process
    for listing in (process.paths_, process.endpaths_):
        for label in listing().keys():
            massSearchReplaceAnyInputTag(getattr(process, label), old, new, *forwarded)
    schedule = process.schedule_()
    if schedule is not None:
        for task in schedule._tasks:
            massSearchReplaceAnyInputTag(task, old, new, *forwarded)
    return process
0092 
class MassSearchParamVisitor(object):
    """Visitor that travels within a cms.Sequence and collects the modules whose
    attribute paramName compares equal to paramSearch.

    The collected modules are returned, in visiting order, by modules().
    """
    def __init__(self,paramName,paramSearch):
        self._paramName   = paramName
        self._paramSearch = paramSearch
        self._modules = []

    def enter(self,visitee):
        """Record visitee when it has the searched attribute with the searched value."""
        if not hasattr(visitee, self._paramName):
            return
        current = getattr(visitee, self._paramName)
        if current == self._paramSearch:
            self._modules.append(visitee)

    def leave(self,visitee):
        """Visitor hook: nothing to do when leaving a node."""

    def modules(self):
        """Return the (internal, live) list of matching modules."""
        return self._modules
0107 
class MassSearchReplaceParamVisitor(object):
    """Visitor that travels within a cms.Sequence and, for every visited module
    whose attribute paramName equals paramSearch, sets that attribute to paramValue.

    Only the top-level attribute of each module is examined; for a SwitchProducer
    each of its case modules is examined instead.
    """
    def __init__(self,paramName,paramSearch,paramValue,verbose=False):
        self._paramName   = paramName
        self._paramValue  = paramValue
        self._paramSearch = paramSearch
        self._verbose = verbose

    def enter(self,visitee):
        """Visitor hook: apply doIt to visitee, or to each case of a SwitchProducer."""
        if not isinstance(visitee, cms.SwitchProducer):
            self.doIt(visitee, str(visitee))
            return
        # the alternatives of a SwitchProducer are stored as its parameters
        for caseName in visitee.parameterNames_():
            self.doIt(getattr(visitee, caseName), "%s.%s"%(str(visitee), caseName))

    def doIt(self, mod, name):
        """Set mod.<paramName> to paramValue if it currently equals paramSearch.

        name is only used in the verbose message.
        """
        if not hasattr(mod, self._paramName):
            return
        current = getattr(mod, self._paramName)
        if not (current == self._paramSearch):
            return
        if self._verbose:
            print("Replaced %s.%s: %s => %s" % (name, self._paramName, current, self._paramValue))
        setattr(mod, self._paramName, self._paramValue)

    def leave(self,visitee):
        """Visitor hook: nothing to do when leaving a node."""
        pass
0128 
def massSearchReplaceParam(sequence,paramName,paramOldValue,paramValue,verbose=False):
    """Set paramName to paramValue on every module visited by sequence whose
    paramName currently equals paramOldValue (see MassSearchReplaceParamVisitor)."""
    replacer = MassSearchReplaceParamVisitor(paramName, paramOldValue, paramValue, verbose)
    sequence.visit(replacer)
0131 
def massReplaceParameter(process,name="label",old="rawDataCollector",new="rawDataRepacker",verbose=False):
    """Replace parameter name equal to old with new throughout process.

    Every Path, every EndPath and every Task associated with the process schedule
    is visited with massSearchReplaceParam. The process is modified in place and
    also returned.
    """
    # Paths first, then EndPaths, exactly as they are listed by the process
    for listing in (process.paths_, process.endpaths_):
        for label in listing().keys():
            massSearchReplaceParam(getattr(process, label), name, old, new, verbose)
    schedule = process.schedule_()
    if schedule is not None:
        for task in schedule._tasks:
            massSearchReplaceParam(task, name, old, new, verbose)
    return process
0141 
if __name__=="__main__":
    # Self-test: running this file directly executes the unit tests below.
    import unittest
    # Minimal SwitchProducer subclass used by the tests. Each case is a lambda
    # returning (True, <int>); presumably (enabled, priority) -- TODO confirm
    # against cms.SwitchProducer.
    class SwitchProducerTest(cms.SwitchProducer):
        def __init__(self, **kargs):
            super(SwitchProducerTest,self).__init__(
                dict(
                    test1 = lambda: (True, -10),
                    test2 = lambda: (True, -9),
                    test3 = lambda: (True, -8),
                    test4 = lambda: (True, -7)
                ), **kargs)

    class TestModuleCommand(unittest.TestCase):

        # massSearchReplaceAnyInputTag on one Sequence: replaces InputTag "b" by
        # "new" in plain, untracked, nested PSet/VPSet, VInputTag, SwitchProducer
        # case and cms.optional parameters.
        def testMassSearchReplaceAnyInputTag(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"), usrc=cms.untracked.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("b"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("b")),
                                 nestedv = cms.VPSet(cms.PSet(src = cms.InputTag("b")), cms.PSet(src = cms.InputTag("d"))),
                                 unestedv = cms.untracked.VPSet(cms.untracked.PSet(src = cms.InputTag("b")), cms.untracked.PSet(src = cms.InputTag("d"))),
                                 vec = cms.VInputTag(cms.InputTag("a"), cms.InputTag("b"), cms.InputTag("c"), cms.InputTag("d")),
                                 uvec = cms.untracked.VInputTag(cms.untracked.InputTag("a"), cms.untracked.InputTag("b"), cms.untracked.InputTag("c"), cms.untracked.InputTag("d")),
                                )
            p.sp = SwitchProducerTest(
                test1 = cms.EDProducer("a", src = cms.InputTag("b"),
                                       nested = cms.PSet(src = cms.InputTag("b"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("b"))
                                       ),
                test2 = cms.EDProducer("b", src = cms.InputTag("c"),
                                       nested = cms.PSet(src = cms.InputTag("b"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("b"))
                                       ),
            )
            # optional parameters: p.op is set from plain strings / lists of
            # strings, p.op2 from explicit cms types; the *unset ones stay unset
            p.op = cms.EDProducer("op", src = cms.optional.InputTag, unset = cms.optional.InputTag, vsrc = cms.optional.VInputTag, vunset = cms.optional.VInputTag)
            p.op2 = cms.EDProducer("op2", src = cms.optional.InputTag, unset = cms.optional.InputTag, vsrc = cms.optional.VInputTag, vunset = cms.optional.VInputTag)
            p.op.src="b"
            p.op.vsrc = ["b"]
            p.op2.src=cms.InputTag("b")
            p.op2.vsrc = cms.VInputTag("b")
            p.s = cms.Sequence(p.a*p.b*p.c*p.sp*p.op*p.op2)
            massSearchReplaceAnyInputTag(p.s, cms.InputTag("b"), cms.InputTag("new"))
            self.assertNotEqual(cms.InputTag("new"), p.b.src)
            self.assertEqual(cms.InputTag("new"), p.c.src)
            self.assertEqual(cms.InputTag("new"), p.c.usrc)
            self.assertEqual(cms.InputTag("new"), p.c.nested.src)
            self.assertEqual(cms.InputTag("new"), p.c.nested.usrc)
            # untracked parameters must stay untracked after replacement
            self.assertFalse(p.c.nested.usrc.isTracked())
            self.assertNotEqual(cms.InputTag("new"), p.c.nested.src2)
            self.assertEqual(cms.InputTag("new"), p.c.nestedv[0].src)
            self.assertNotEqual(cms.InputTag("new"), p.c.nestedv[1].src)
            self.assertEqual(cms.InputTag("new"), p.c.unestedv[0].src)
            self.assertNotEqual(cms.InputTag("new"), p.c.unestedv[1].src)
            # only the matching VInputTag entry (index 1) is replaced
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[0])
            self.assertEqual(cms.InputTag("new"), p.c.vec[1])
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[2])
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[3])
            self.assertNotEqual(cms.InputTag("new"), p.c.uvec[0])
            self.assertEqual(cms.InputTag("new"), p.c.uvec[1])
            self.assertNotEqual(cms.InputTag("new"), p.c.uvec[2])
            self.assertNotEqual(cms.InputTag("new"), p.c.uvec[3])
            self.assertFalse(p.c.uvec[0].isTracked())
            self.assertFalse(p.c.uvec[1].isTracked())
            self.assertFalse(p.c.uvec[2].isTracked())
            self.assertFalse(p.c.uvec[3].isTracked())
            # SwitchProducer cases are searched too
            self.assertEqual(cms.InputTag("new"), p.sp.test1.src)
            self.assertEqual(cms.InputTag("new"), p.sp.test1.nested.src)
            self.assertEqual(cms.InputTag("c"), p.sp.test1.nested.src2)
            self.assertEqual(cms.untracked.InputTag("new"), p.sp.test1.nested.usrc)
            self.assertEqual(cms.InputTag("c"), p.sp.test2.src)
            self.assertEqual(cms.InputTag("new"), p.sp.test2.nested.src)
            self.assertEqual(cms.InputTag("c"), p.sp.test2.nested.src2)
            self.assertEqual(cms.untracked.InputTag("new"), p.sp.test2.nested.usrc)
            self.assertEqual(cms.InputTag("new"), p.op.src)
            self.assertEqual(cms.InputTag("new"), p.op.vsrc[0])
            self.assertEqual(cms.InputTag("new"), p.op2.src)
            self.assertEqual(cms.InputTag("new"), p.op2.vsrc[0])

        # massReplaceInputTag over a whole Process: modules reached from Paths,
        # EndPaths (directly or through a Sequence) and Tasks associated with the
        # schedule are updated.
        def testMassReplaceInputTag(self):
            # on an empty process the call must leave the configuration unchanged
            process1 = cms.Process("test")
            massReplaceInputTag(process1, "a", "b", False, False, False)
            self.assertEqual(process1.dumpPython(), cms.Process('test').dumpPython())
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("a"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("a")),
                                 nestedv = cms.VPSet(cms.PSet(src = cms.InputTag("a")), cms.PSet(src = cms.InputTag("d"))),
                                 unestedv = cms.untracked.VPSet(cms.untracked.PSet(src = cms.InputTag("a")), cms.untracked.PSet(src = cms.InputTag("d"))),
                                 vec = cms.VInputTag(cms.InputTag("a"), cms.InputTag("b"), cms.InputTag("c"), cms.InputTag("d")),
                                 uvec = cms.untracked.VInputTag(cms.untracked.InputTag("a"), cms.untracked.InputTag("b"), cms.untracked.InputTag("c"), cms.InputTag("d")),
                                )
            p.d = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.e = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.f = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.g = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.h = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.i = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.sp = SwitchProducerTest(
                test1 = cms.EDProducer("a", src = cms.InputTag("a"),
                                       nested = cms.PSet(src = cms.InputTag("a"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("a"))
                                       ),
                test2 = cms.EDProducer("b", src = cms.InputTag("c"),
                                       nested = cms.PSet(src = cms.InputTag("a"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("a"))
                                       ),
            )
            p.s1 = cms.Sequence(p.a*p.b*p.c*p.sp)
            p.path1 = cms.Path(p.s1)
            # s2 (holding p.d) is not placed on any Path, so p.d must stay untouched
            p.s2 = cms.Sequence(p.d)
            p.path2 = cms.Path(p.e)
            p.s3 = cms.Sequence(p.f)
            p.endpath1 = cms.EndPath(p.s3)
            p.endpath2 = cms.EndPath(p.g)
            p.t1 = cms.Task(p.h)
            p.t2 = cms.Task(p.i)
            p.schedule = cms.Schedule()
            p.schedule.associate(p.t1, p.t2)
            massReplaceInputTag(p, "a", "b", False, False, False)
            self.assertEqual(cms.InputTag("b"), p.b.src)
            self.assertEqual(cms.InputTag("b"), p.c.nested.src)
            self.assertEqual(cms.InputTag("b"), p.c.nested.usrc)
            self.assertFalse(p.c.nested.usrc.isTracked())
            self.assertEqual(cms.InputTag("b"), p.c.nestedv[0].src)
            self.assertEqual(cms.InputTag("b"), p.c.unestedv[0].src)
            self.assertEqual(cms.InputTag("b"), p.c.vec[0])
            self.assertEqual(cms.InputTag("c"), p.c.vec[2])
            self.assertEqual(cms.InputTag("b"), p.c.uvec[0])
            self.assertEqual(cms.InputTag("c"), p.c.uvec[2])
            self.assertFalse(p.c.uvec[0].isTracked())
            self.assertFalse(p.c.uvec[1].isTracked())
            self.assertFalse(p.c.uvec[2].isTracked())
            self.assertEqual(cms.InputTag("a"), p.d.src)
            self.assertEqual(cms.InputTag("b"), p.e.src)
            self.assertEqual(cms.InputTag("b"), p.f.src)
            self.assertEqual(cms.InputTag("b"), p.g.src)
            # p.h / p.i are reached only through the Tasks associated with the schedule
            self.assertEqual(cms.InputTag("b"), p.h.src)
            self.assertEqual(cms.InputTag("b"), p.i.src)
            self.assertEqual(cms.InputTag("b"), p.sp.test1.src)
            self.assertEqual(cms.InputTag("b"), p.sp.test1.nested.src)
            self.assertEqual(cms.InputTag("c"), p.sp.test1.nested.src2)
            self.assertEqual(cms.untracked.InputTag("b"), p.sp.test1.nested.usrc)
            self.assertEqual(cms.InputTag("c"), p.sp.test2.src)
            self.assertEqual(cms.InputTag("b"), p.sp.test2.nested.src)
            self.assertEqual(cms.InputTag("c"), p.sp.test2.nested.src2)
            self.assertEqual(cms.untracked.InputTag("b"), p.sp.test2.nested.usrc)

        # massSearchReplaceParam only looks at the top-level parameter of each
        # visited module (and of each SwitchProducer case): nested PSets keep
        # their old values.
        def testMassSearchReplaceParam(self):
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("c"), src2 = cms.InputTag("b"))
                                )
            p.d = cms.EDProducer("ac", src=cms.untracked.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("c"), src2 = cms.InputTag("b"))
                                )
            p.sp = SwitchProducerTest(
                test1 = cms.EDProducer("a", src = cms.InputTag("b"),
                                       nested = cms.PSet(src = cms.InputTag("b"))
                                       ),
                test2 = cms.EDProducer("b", src = cms.InputTag("b")),
            )
            p.s = cms.Sequence(p.a*p.b*p.c*p.d*p.sp)
            massSearchReplaceParam(p.s,"src",cms.InputTag("b"),"a")
            self.assertEqual(cms.InputTag("a"),p.c.src)
            self.assertEqual(cms.InputTag("c"),p.c.nested.src)
            self.assertEqual(cms.InputTag("b"),p.c.nested.src2)
            # assigning a plain string keeps the untracked-ness of p.d.src
            self.assertEqual(cms.untracked.InputTag("a"),p.d.src)
            self.assertEqual(cms.InputTag("c"),p.d.nested.src)
            self.assertEqual(cms.InputTag("b"),p.d.nested.src2)
            self.assertEqual(cms.InputTag("a"),p.sp.test1.src)
            self.assertEqual(cms.InputTag("b"),p.sp.test1.nested.src)
            self.assertEqual(cms.InputTag("a"),p.sp.test2.src)

        # massReplaceParameter over a whole Process: same reach as
        # massReplaceInputTag (Paths, EndPaths, schedule Tasks), but VInputTag
        # entries and nested PSets are not touched.
        def testMassReplaceParam(self):
            # on an empty process the call must leave the configuration unchanged
            process1 = cms.Process("test")
            massReplaceParameter(process1, "src", cms.InputTag("a"), "b", False)
            self.assertEqual(process1.dumpPython(), cms.Process("test").dumpPython())
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("a"), src2 = cms.InputTag("c")),
                                 nestedv = cms.VPSet(cms.PSet(src = cms.InputTag("a")), cms.PSet(src = cms.InputTag("d"))),
                                 vec = cms.VInputTag(cms.InputTag("a"), cms.InputTag("b"), cms.InputTag("c"), cms.InputTag("d"))
                                )
            p.d = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.e = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.f = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.g = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.h = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.i = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.j = cms.EDProducer("ab", src=cms.untracked.InputTag("a"))
            p.sp = SwitchProducerTest(
                test1 = cms.EDProducer("a", src = cms.InputTag("a"),
                                       nested = cms.PSet(src = cms.InputTag("a"))
                                       ),
                test2 = cms.EDProducer("b", src = cms.InputTag("a")),
            )
            p.s1 = cms.Sequence(p.a*p.b*p.c*p.sp)
            p.path1 = cms.Path(p.s1)
            # s2 (holding p.d) is not placed on any Path, so p.d must stay untouched
            p.s2 = cms.Sequence(p.d)
            p.path2 = cms.Path(p.e)
            p.s3 = cms.Sequence(p.f)
            p.endpath1 = cms.EndPath(p.s3)
            p.endpath2 = cms.EndPath(p.g)
            p.t1 = cms.Task(p.h)
            p.t2 = cms.Task(p.i, p.j)
            p.schedule = cms.Schedule()
            p.schedule.associate(p.t1, p.t2)
            massReplaceParameter(p, "src",cms.InputTag("a"), "b", False)
            self.assertEqual(cms.InputTag("gen"), p.a.src)
            self.assertEqual(cms.InputTag("b"), p.b.src)
            self.assertEqual(cms.InputTag("a"), p.c.vec[0])
            self.assertEqual(cms.InputTag("c"), p.c.vec[2])
            self.assertEqual(cms.InputTag("a"), p.d.src)
            self.assertEqual(cms.InputTag("b"), p.e.src)
            self.assertEqual(cms.InputTag("b"), p.f.src)
            self.assertEqual(cms.InputTag("b"), p.g.src)
            self.assertEqual(cms.InputTag("b"), p.h.src)
            self.assertEqual(cms.InputTag("b"), p.i.src)
            self.assertEqual(cms.untracked.InputTag("b"), p.j.src)
            self.assertEqual(cms.InputTag("b"),p.sp.test1.src)
            self.assertEqual(cms.InputTag("a"),p.sp.test1.nested.src)
            self.assertEqual(cms.InputTag("b"),p.sp.test2.src)
    unittest.main()