Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-12-01 23:40:19

0001 import FWCore.ParameterSet.Config as cms
0002 
class MassSearchReplaceAnyInputTagVisitor(object):
    """Visitor that travels within a cms.Sequence, looks for an InputTag and replaces its value.

    It climbs down within PSets, VPSets and VInputTags (and into EDProducer/EDAlias
    parameters, to support SwitchProducer cases) to find its target.

    Args:
        paramSearch: InputTag (or string convertible to one) to look for.
        paramReplace: InputTag (or string) substituted for every match.
        verbose: if True, print each replacement that is made.
        moduleLabelOnly: if True, a tag whose module label equals paramSearch's module
            label is also rewritten; only its module label is replaced, the remaining
            fields of that tag are kept.
        skipLabelTest: if True, visited modules are not asked for their label; a
            placeholder string is used as the base name in verbose messages.
    """
    def __init__(self,paramSearch,paramReplace,verbose=False,moduleLabelOnly=False,skipLabelTest=False):
        self._paramSearch  = self.standardizeInputTagFmt(paramSearch)
        self._paramReplace = self.standardizeInputTagFmt(paramReplace)
        # not read anywhere in this class; kept for backward compatibility
        self._moduleName   = ''
        self._verbose=verbose
        self._moduleLabelOnly=moduleLabelOnly
        self._skipLabelTest=skipLabelTest

    def doIt(self,pset,base):
        """Recursively replace matching InputTags inside ``pset``.

        Args:
            pset: object to search; anything that is not a cms._Parameterizable is ignored.
            base: dotted path of ``pset``, used only to build verbose messages.
        """
        if not isinstance(pset, cms._Parameterizable):
            return
        for name in pset.parameterNames_():
            # if I use pset.parameters_().items() I get copies of the parameter values
            # so I can't modify the nested pset
            value = getattr(pset,name)
            if isinstance(value, (cms.PSet, cms.EDProducer, cms.EDAlias)):
                # EDProducer and EDAlias to support SwitchProducer
                self.doIt(value,base+"."+name)
            elif value.isCompatibleCMSType(cms.VPSet):
                for (i,ps) in enumerate(value):
                    self.doIt(ps, "%s.%s[%d]"%(base,name,i))
            elif value.isCompatibleCMSType(cms.VInputTag) and value:
                self._replaceInVInputTag(value, base, name)
            elif value.isCompatibleCMSType(cms.InputTag) and value:
                self._replaceInInputTag(pset, name, value, base)

    def _untrackedReplacement(self):
        """Return a new untracked InputTag carrying the fields of the replacement tag."""
        return cms.untracked.InputTag(self._paramReplace.getModuleLabel(),
                                      self._paramReplace.getProductInstanceLabel(),
                                      self._paramReplace.getProcessName())

    def _replaceInVInputTag(self, value, base, name):
        """Replace matching entries of the VInputTag ``value`` (parameter ``base.name``) in place."""
        from copy import deepcopy
        for (i,n) in enumerate(value):
            # VInputTag can be declared as a list of strings, so ensure that n is formatted correctly
            n = self.standardizeInputTagFmt(n)
            if (n == self._paramSearch):
                if self._verbose:print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, self._paramSearch, self._paramReplace))
                if not value.isTracked():
                    value[i] = self._untrackedReplacement()
                else:
                    # store a copy so the list element does not alias self._paramReplace
                    # (consistent with the scalar InputTag case)
                    value[i] = deepcopy(self._paramReplace)
            elif self._moduleLabelOnly and n.moduleLabel == self._paramSearch.moduleLabel:
                # copy before editing: n may be the very object stored in the list (and
                # possibly shared elsewhere); also keeps the old value for the message
                nrep = deepcopy(n)
                nrep.moduleLabel = self._paramReplace.moduleLabel
                if self._verbose:print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, n, nrep))
                value[i] = nrep

    def _replaceInInputTag(self, pset, name, value, base):
        """Replace the InputTag parameter ``pset.<name>`` (current value ``value``) if it matches."""
        from copy import deepcopy
        if value == self._paramSearch:
            if self._verbose:print("Replace %s.%s %s ==> %s " % (base, name, self._paramSearch, self._paramReplace))
            if not value.isTracked():
                # the existing value should stay untracked even if the given parameter is tracked
                setattr(pset, name, self._untrackedReplacement())
            else:
                setattr(pset, name, deepcopy(self._paramReplace) )
        elif self._moduleLabelOnly and value.moduleLabel == self._paramSearch.moduleLabel:
            repl = deepcopy(value)
            repl.moduleLabel = self._paramReplace.moduleLabel
            setattr(pset, name, repl)
            if self._verbose:print("Replace %s.%s %s ==> %s " % (base, name, value, repl))


    @staticmethod
    def standardizeInputTagFmt(inputTag):
        ''' helper function to ensure that the InputTag is defined as cms.InputTag(str) and not as a plain str '''
        if not isinstance(inputTag, cms.InputTag):
            return cms.InputTag(inputTag)
        return inputTag

    def enter(self,visitee):
        """Visitor hook: search ``visitee``, using its label (or a placeholder) as the base name."""
        label = ''
        if (not self._skipLabelTest):
            if hasattr(visitee,"hasLabel_") and visitee.hasLabel_():
                label = visitee.label_()
            else: label = '<Module not in a Process>'
        else:
            label = '<Module label not tested>'
        self.doIt(visitee, label)
    def leave(self,visitee):
        """Visitor hook: nothing to do when leaving a node."""
        pass
0077 
def massSearchReplaceAnyInputTag(sequence, oldInputTag, newInputTag,verbose=False,moduleLabelOnly=False,skipLabelTest=False) :
    """Replace InputTag oldInputTag with newInputTag, at any level of nesting within PSets, VPSets, VInputTags...

    The work is done by a MassSearchReplaceAnyInputTagVisitor that is passed to
    ``sequence.visit``; the options are forwarded to it unchanged.
    """
    replacer = MassSearchReplaceAnyInputTagVisitor(
        oldInputTag,
        newInputTag,
        verbose=verbose,
        moduleLabelOnly=moduleLabelOnly,
        skipLabelTest=skipLabelTest,
    )
    sequence.visit(replacer)
0081 
def massReplaceInputTag(process,old="rawDataCollector",new="rawDataRepacker",verbose=False,moduleLabelOnly=False,skipLabelTest=False):
    """Replace InputTag ``old`` with ``new`` in all Paths, EndPaths and scheduled Tasks of ``process``.

    Each container is handed to massSearchReplaceAnyInputTag, so nested PSets,
    VPSets and VInputTags are searched as well. Returns ``process``.
    """
    def _replaceIn(container):
        # forward every option unchanged
        massSearchReplaceAnyInputTag(container, old, new, verbose, moduleLabelOnly, skipLabelTest)

    for labels in (process.paths_(), process.endpaths_()):
        for label in labels.keys():
            _replaceIn(getattr(process, label))

    schedule = process.schedule_()
    if schedule is not None:
        for task in schedule._tasks:
            _replaceIn(task)
    return process
0091 
class MassSearchParamVisitor(object):
    """Visitor that collects every visited object whose attribute ``paramName`` equals ``paramSearch``.

    Only the attribute directly on the visited object is compared (no nested search).
    After the visit, modules() gives the matches in the order they were entered.
    """
    def __init__(self,paramName,paramSearch):
        self._paramName   = paramName
        self._paramSearch = paramSearch
        self._modules = []
    def enter(self,visitee):
        # objects without the parameter cannot match
        if not hasattr(visitee, self._paramName):
            return
        if getattr(visitee, self._paramName) == self._paramSearch:
            self._modules.append(visitee)
    def leave(self,visitee):
        """Visitor hook: nothing to do when leaving a node."""
        pass
    def modules(self):
        """Return the list of matching objects (the internal list itself, not a copy)."""
        return self._modules
0106 
class MassSearchReplaceParamVisitor(object):
    """Visitor that sets ``paramName`` to ``paramValue`` wherever it currently equals ``paramSearch``.

    Only the attribute directly on each visited object is examined; nested PSets are
    not searched. For a cms.SwitchProducer, each of its cases is examined instead.
    """
    def __init__(self,paramName,paramSearch,paramValue,verbose=False):
        self._paramName   = paramName
        self._paramValue  = paramValue
        self._paramSearch = paramSearch
        self._verbose = verbose
    def enter(self,visitee):
        if not isinstance(visitee, cms.SwitchProducer):
            self.doIt(visitee, str(visitee))
            return
        # SwitchProducer: examine each case module, named "<switch>.<case>"
        prefix = str(visitee)
        for caseName in visitee.parameterNames_():
            self.doIt(getattr(visitee, caseName), "%s.%s"%(prefix, caseName))
    def doIt(self, mod, name):
        """Replace the parameter on ``mod`` if it matches; ``name`` is used only in the verbose message."""
        if not hasattr(mod, self._paramName):
            return
        current = getattr(mod, self._paramName)
        if current == self._paramSearch:
            if self._verbose:
                print("Replaced %s.%s: %s => %s" % (name,self._paramName,current,self._paramValue))
            setattr(mod, self._paramName, self._paramValue)
    def leave(self,visitee):
        """Visitor hook: nothing to do when leaving a node."""
        pass
0127 
def massSearchReplaceParam(sequence,paramName,paramOldValue,paramValue,verbose=False):
    """Set ``paramName`` to ``paramValue`` on each module of ``sequence`` where it equals ``paramOldValue``.

    Only top-level parameters are affected; the work is done by a
    MassSearchReplaceParamVisitor passed to ``sequence.visit``.
    """
    replacer = MassSearchReplaceParamVisitor(paramName, paramOldValue, paramValue, verbose)
    sequence.visit(replacer)
0130 
def massReplaceParameter(process,name="label",old="rawDataCollector",new="rawDataRepacker",verbose=False):
    """Replace parameter ``name`` equal to ``old`` by ``new`` in all Paths, EndPaths and scheduled Tasks.

    Each container is handed to massSearchReplaceParam (top-level parameters only).
    Returns ``process``.
    """
    def _replaceIn(container):
        massSearchReplaceParam(container, name, old, new, verbose)

    for labels in (process.paths_(), process.endpaths_()):
        for label in labels.keys():
            _replaceIn(getattr(process, label))

    schedule = process.schedule_()
    if schedule is not None:
        for task in schedule._tasks:
            _replaceIn(task)
    return process
0140 
# Running this module directly executes the unit tests below.
if __name__=="__main__":
    import unittest
    # Minimal SwitchProducer with four cases (test1..test4). Each lambda returns a
    # (True, int) tuple -- presumably (enabled, priority); confirm against cms.SwitchProducer.
    class SwitchProducerTest(cms.SwitchProducer):
        def __init__(self, **kargs):
            super(SwitchProducerTest,self).__init__(
                dict(
                    test1 = lambda: (True, -10),
                    test2 = lambda: (True, -9),
                    test3 = lambda: (True, -8),
                    test4 = lambda: (True, -7)
                ), **kargs)

    class TestModuleCommand(unittest.TestCase):

        def testMassSearchReplaceAnyInputTag(self):
            # Replace InputTag "b" with "new" throughout sequence p.s: top-level, nested PSets,
            # (untracked) VPSets, (untracked) VInputTags, SwitchProducer cases and
            # cms.optional parameters.
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"), usrc=cms.untracked.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("b"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("b")),
                                 nestedv = cms.VPSet(cms.PSet(src = cms.InputTag("b")), cms.PSet(src = cms.InputTag("d"))),
                                 unestedv = cms.untracked.VPSet(cms.untracked.PSet(src = cms.InputTag("b")), cms.untracked.PSet(src = cms.InputTag("d"))),
                                 vec = cms.VInputTag(cms.InputTag("a"), cms.InputTag("b"), cms.InputTag("c"), cms.InputTag("d")),
                                 uvec = cms.untracked.VInputTag(cms.untracked.InputTag("a"), cms.untracked.InputTag("b"), cms.untracked.InputTag("c"), cms.untracked.InputTag("d")),
                                )
            p.sp = SwitchProducerTest(
                test1 = cms.EDProducer("a", src = cms.InputTag("b"),
                                       nested = cms.PSet(src = cms.InputTag("b"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("b"))
                                       ),
                test2 = cms.EDProducer("b", src = cms.InputTag("c"),
                                       nested = cms.PSet(src = cms.InputTag("b"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("b"))
                                       ),
            )
            # optional parameters assigned after construction, once as plain strings (op)
            # and once as explicit cms types (op2)
            p.op = cms.EDProducer("op", src = cms.optional.InputTag, unset = cms.optional.InputTag, vsrc = cms.optional.VInputTag, vunset = cms.optional.VInputTag)
            p.op2 = cms.EDProducer("op2", src = cms.optional.InputTag, unset = cms.optional.InputTag, vsrc = cms.optional.VInputTag, vunset = cms.optional.VInputTag)
            p.op.src="b"
            p.op.vsrc = ["b"]
            p.op2.src=cms.InputTag("b")
            p.op2.vsrc = cms.VInputTag("b")
            p.s = cms.Sequence(p.a*p.b*p.c*p.sp*p.op*p.op2)
            massSearchReplaceAnyInputTag(p.s, cms.InputTag("b"), cms.InputTag("new"))
            # p.b.src is "a", which does not match, so it is left alone
            self.assertNotEqual(cms.InputTag("new"), p.b.src)
            self.assertEqual(cms.InputTag("new"), p.c.src)
            self.assertEqual(cms.InputTag("new"), p.c.usrc)
            self.assertEqual(cms.InputTag("new"), p.c.nested.src)
            self.assertEqual(cms.InputTag("new"), p.c.nested.usrc)
            # a replaced untracked parameter must stay untracked
            self.assertFalse(p.c.nested.usrc.isTracked())
            self.assertNotEqual(cms.InputTag("new"), p.c.nested.src2)
            self.assertEqual(cms.InputTag("new"), p.c.nestedv[0].src)
            self.assertNotEqual(cms.InputTag("new"), p.c.nestedv[1].src)
            self.assertEqual(cms.InputTag("new"), p.c.unestedv[0].src)
            self.assertNotEqual(cms.InputTag("new"), p.c.unestedv[1].src)
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[0])
            self.assertEqual(cms.InputTag("new"), p.c.vec[1])
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[2])
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[3])
            self.assertNotEqual(cms.InputTag("new"), p.c.uvec[0])
            self.assertEqual(cms.InputTag("new"), p.c.uvec[1])
            self.assertNotEqual(cms.InputTag("new"), p.c.uvec[2])
            self.assertNotEqual(cms.InputTag("new"), p.c.uvec[3])
            # every entry of the untracked VInputTag, replaced or not, stays untracked
            self.assertFalse(p.c.uvec[0].isTracked())
            self.assertFalse(p.c.uvec[1].isTracked())
            self.assertFalse(p.c.uvec[2].isTracked())
            self.assertFalse(p.c.uvec[3].isTracked())
            # SwitchProducer cases are searched like ordinary modules
            self.assertEqual(cms.InputTag("new"), p.sp.test1.src)
            self.assertEqual(cms.InputTag("new"), p.sp.test1.nested.src)
            self.assertEqual(cms.InputTag("c"), p.sp.test1.nested.src2)
            self.assertEqual(cms.untracked.InputTag("new"), p.sp.test1.nested.usrc)
            self.assertEqual(cms.InputTag("c"), p.sp.test2.src)
            self.assertEqual(cms.InputTag("new"), p.sp.test2.nested.src)
            self.assertEqual(cms.InputTag("c"), p.sp.test2.nested.src2)
            self.assertEqual(cms.untracked.InputTag("new"), p.sp.test2.nested.usrc)
            self.assertEqual(cms.InputTag("new"), p.op.src)
            self.assertEqual(cms.InputTag("new"), p.op.vsrc[0])
            self.assertEqual(cms.InputTag("new"), p.op2.src)
            self.assertEqual(cms.InputTag("new"), p.op2.vsrc[0])

        def testMassReplaceInputTag(self):
            # an empty process must come through unchanged
            process1 = cms.Process("test")
            massReplaceInputTag(process1, "a", "b", False, False, False)
            self.assertEqual(process1.dumpPython(), cms.Process('test').dumpPython())
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("a"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("a")),
                                 nestedv = cms.VPSet(cms.PSet(src = cms.InputTag("a")), cms.PSet(src = cms.InputTag("d"))),
                                 unestedv = cms.untracked.VPSet(cms.untracked.PSet(src = cms.InputTag("a")), cms.untracked.PSet(src = cms.InputTag("d"))),
                                 vec = cms.VInputTag(cms.InputTag("a"), cms.InputTag("b"), cms.InputTag("c"), cms.InputTag("d")),
                                 uvec = cms.untracked.VInputTag(cms.untracked.InputTag("a"), cms.untracked.InputTag("b"), cms.untracked.InputTag("c"), cms.InputTag("d")),
                                )
            p.d = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.e = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.f = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.g = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.h = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.i = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.sp = SwitchProducerTest(
                test1 = cms.EDProducer("a", src = cms.InputTag("a"),
                                       nested = cms.PSet(src = cms.InputTag("a"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("a"))
                                       ),
                test2 = cms.EDProducer("b", src = cms.InputTag("c"),
                                       nested = cms.PSet(src = cms.InputTag("a"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("a"))
                                       ),
            )
            # modules reachable through a Path, an EndPath (directly or via a Sequence) and
            # through Tasks associated with the schedule; p.d sits only in the unused Sequence s2
            p.s1 = cms.Sequence(p.a*p.b*p.c*p.sp)
            p.path1 = cms.Path(p.s1)
            p.s2 = cms.Sequence(p.d)
            p.path2 = cms.Path(p.e)
            p.s3 = cms.Sequence(p.f)
            p.endpath1 = cms.EndPath(p.s3)
            p.endpath2 = cms.EndPath(p.g)
            p.t1 = cms.Task(p.h)
            p.t2 = cms.Task(p.i)
            p.schedule = cms.Schedule()
            p.schedule.associate(p.t1, p.t2)
            massReplaceInputTag(p, "a", "b", False, False, False)
            self.assertEqual(cms.InputTag("b"), p.b.src)
            self.assertEqual(cms.InputTag("b"), p.c.nested.src)
            self.assertEqual(cms.InputTag("b"), p.c.nested.usrc)
            self.assertFalse(p.c.nested.usrc.isTracked())
            self.assertEqual(cms.InputTag("b"), p.c.nestedv[0].src)
            self.assertEqual(cms.InputTag("b"), p.c.unestedv[0].src)
            self.assertEqual(cms.InputTag("b"), p.c.vec[0])
            self.assertEqual(cms.InputTag("c"), p.c.vec[2])
            self.assertEqual(cms.InputTag("b"), p.c.uvec[0])
            self.assertEqual(cms.InputTag("c"), p.c.uvec[2])
            self.assertFalse(p.c.uvec[0].isTracked())
            self.assertFalse(p.c.uvec[1].isTracked())
            self.assertFalse(p.c.uvec[2].isTracked())
            # p.d belongs only to Sequence s2, which no Path/EndPath/Task uses, so it is not visited
            self.assertEqual(cms.InputTag("a"), p.d.src)
            self.assertEqual(cms.InputTag("b"), p.e.src)
            self.assertEqual(cms.InputTag("b"), p.f.src)
            self.assertEqual(cms.InputTag("b"), p.g.src)
            self.assertEqual(cms.InputTag("b"), p.h.src)
            self.assertEqual(cms.InputTag("b"), p.i.src)
            self.assertEqual(cms.InputTag("b"), p.sp.test1.src)
            self.assertEqual(cms.InputTag("b"), p.sp.test1.nested.src)
            self.assertEqual(cms.InputTag("c"), p.sp.test1.nested.src2)
            self.assertEqual(cms.untracked.InputTag("b"), p.sp.test1.nested.usrc)
            self.assertEqual(cms.InputTag("c"), p.sp.test2.src)
            self.assertEqual(cms.InputTag("b"), p.sp.test2.nested.src)
            self.assertEqual(cms.InputTag("c"), p.sp.test2.nested.src2)
            self.assertEqual(cms.untracked.InputTag("b"), p.sp.test2.nested.usrc)

        def testMassSearchReplaceParam(self):
            # massSearchReplaceParam only inspects top-level parameters: entries inside
            # nested PSets keep their old value
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("c"), src2 = cms.InputTag("b"))
                                )
            p.d = cms.EDProducer("ac", src=cms.untracked.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("c"), src2 = cms.InputTag("b"))
                                )
            p.sp = SwitchProducerTest(
                test1 = cms.EDProducer("a", src = cms.InputTag("b"),
                                       nested = cms.PSet(src = cms.InputTag("b"))
                                       ),
                test2 = cms.EDProducer("b", src = cms.InputTag("b")),
            )
            p.s = cms.Sequence(p.a*p.b*p.c*p.d*p.sp)
            massSearchReplaceParam(p.s,"src",cms.InputTag("b"),"a")
            self.assertEqual(cms.InputTag("a"),p.c.src)
            self.assertEqual(cms.InputTag("c"),p.c.nested.src)
            self.assertEqual(cms.InputTag("b"),p.c.nested.src2)
            self.assertEqual(cms.untracked.InputTag("a"),p.d.src)
            self.assertEqual(cms.InputTag("c"),p.d.nested.src)
            self.assertEqual(cms.InputTag("b"),p.d.nested.src2)
            self.assertEqual(cms.InputTag("a"),p.sp.test1.src)
            self.assertEqual(cms.InputTag("b"),p.sp.test1.nested.src)
            self.assertEqual(cms.InputTag("a"),p.sp.test2.src)

        def testMassReplaceParam(self):
            # an empty process must come through unchanged
            process1 = cms.Process("test")
            massReplaceParameter(process1, "src", cms.InputTag("a"), "b", False)
            self.assertEqual(process1.dumpPython(), cms.Process("test").dumpPython())
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("a"), src2 = cms.InputTag("c")),
                                 nestedv = cms.VPSet(cms.PSet(src = cms.InputTag("a")), cms.PSet(src = cms.InputTag("d"))),
                                 vec = cms.VInputTag(cms.InputTag("a"), cms.InputTag("b"), cms.InputTag("c"), cms.InputTag("d"))
                                )
            p.d = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.e = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.f = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.g = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.h = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.i = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.j = cms.EDProducer("ab", src=cms.untracked.InputTag("a"))
            p.sp = SwitchProducerTest(
                test1 = cms.EDProducer("a", src = cms.InputTag("a"),
                                       nested = cms.PSet(src = cms.InputTag("a"))
                                       ),
                test2 = cms.EDProducer("b", src = cms.InputTag("a")),
            )
            p.s1 = cms.Sequence(p.a*p.b*p.c*p.sp)
            p.path1 = cms.Path(p.s1)
            p.s2 = cms.Sequence(p.d)
            p.path2 = cms.Path(p.e)
            p.s3 = cms.Sequence(p.f)
            p.endpath1 = cms.EndPath(p.s3)
            p.endpath2 = cms.EndPath(p.g)
            p.t1 = cms.Task(p.h)
            p.t2 = cms.Task(p.i, p.j)
            p.schedule = cms.Schedule()
            p.schedule.associate(p.t1, p.t2)
            massReplaceParameter(p, "src",cms.InputTag("a"), "b", False)
            self.assertEqual(cms.InputTag("gen"), p.a.src)
            self.assertEqual(cms.InputTag("b"), p.b.src)
            # VInputTag entries are not touched by the top-level-only replacement
            self.assertEqual(cms.InputTag("a"), p.c.vec[0])
            self.assertEqual(cms.InputTag("c"), p.c.vec[2])
            # p.d belongs only to Sequence s2, which no Path/EndPath/Task uses
            self.assertEqual(cms.InputTag("a"), p.d.src)
            self.assertEqual(cms.InputTag("b"), p.e.src)
            self.assertEqual(cms.InputTag("b"), p.f.src)
            self.assertEqual(cms.InputTag("b"), p.g.src)
            self.assertEqual(cms.InputTag("b"), p.h.src)
            self.assertEqual(cms.InputTag("b"), p.i.src)
            self.assertEqual(cms.untracked.InputTag("b"), p.j.src)
            self.assertEqual(cms.InputTag("b"),p.sp.test1.src)
            self.assertEqual(cms.InputTag("a"),p.sp.test1.nested.src)
            self.assertEqual(cms.InputTag("b"),p.sp.test2.src)
    unittest.main()