File indexing completed on 2024-12-01 23:40:19
0001 import FWCore.ParameterSet.Config as cms
0002
class MassSearchReplaceAnyInputTagVisitor(object):
    """Visitor that travels within a cms.Sequence, looks for an InputTag and replaces its value.

    It will climb down within PSets, VPSets and VInputTags to find its target.

    Parameters:
        paramSearch: InputTag (or plain str) to look for.
        paramReplace: InputTag (or plain str) to put in its place.
        verbose: if True, print one line for every replacement made.
        moduleLabelOnly: if True, a tag that is not equal to paramSearch but has
            the same module label gets only its module label replaced.
        skipLabelTest: if True, the visited module is not asked for its label;
            a fixed placeholder is used as the base of the verbose messages.
    """

    def __init__(self, paramSearch, paramReplace, verbose=False, moduleLabelOnly=False, skipLabelTest=False):
        self._paramSearch = self.standardizeInputTagFmt(paramSearch)
        self._paramReplace = self.standardizeInputTagFmt(paramReplace)
        self._moduleName = ''
        self._verbose = verbose
        self._moduleLabelOnly = moduleLabelOnly
        self._skipLabelTest = skipLabelTest

    def doIt(self, pset, base):
        """Recursively replace matching InputTags among the parameters of pset.

        base is the dotted path of pset and is only used in verbose messages.
        Objects that are not cms._Parameterizable are silently ignored.
        """
        if not isinstance(pset, cms._Parameterizable):
            return
        for name in pset.parameterNames_():
            value = getattr(pset, name)
            if isinstance(value, (cms.PSet, cms.EDProducer, cms.EDAlias)):
                # Nested parameter holder: descend, extending the dotted path.
                self.doIt(value, base + "." + name)
            elif value.isCompatibleCMSType(cms.VPSet):
                for (i, ps) in enumerate(value):
                    self.doIt(ps, "%s.%s[%d]" % (base, name, i))
            elif value.isCompatibleCMSType(cms.VInputTag) and value:
                self._replaceInVInputTag(value, base, name)
            elif value.isCompatibleCMSType(cms.InputTag) and value:
                self._replaceInputTag(pset, name, value, base)

    def _untrackedReplacement(self):
        """Build an untracked InputTag carrying the three labels of the replacement tag."""
        return cms.untracked.InputTag(self._paramReplace.getModuleLabel(),
                                      self._paramReplace.getProductInstanceLabel(),
                                      self._paramReplace.getProcessName())

    def _replaceInVInputTag(self, value, base, name):
        """Replace the matching elements of the VInputTag value, element by element."""
        from copy import deepcopy
        for (i, n) in enumerate(value):
            # Elements may be given as plain strings; compare them as InputTags.
            n = self.standardizeInputTagFmt(n)
            if n == self._paramSearch:
                if self._verbose:
                    print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, self._paramSearch, self._paramReplace))
                if not value.isTracked():
                    # An untracked vector keeps untracked elements even if the
                    # replacement tag itself is tracked.
                    value[i] = self._untrackedReplacement()
                else:
                    # Store a copy (as the scalar InputTag case does) so the
                    # replaced entries do not all share one object.
                    value[i] = deepcopy(self._paramReplace)
            elif self._moduleLabelOnly and n.moduleLabel == self._paramSearch.moduleLabel:
                # Work on a copy: modifying n itself would change the element
                # in place and make the message below show the new tag twice.
                nrep = deepcopy(n)
                nrep.moduleLabel = self._paramReplace.moduleLabel
                if self._verbose:
                    print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, n, nrep))
                value[i] = nrep

    def _replaceInputTag(self, pset, name, value, base):
        """Replace the single InputTag parameter pset.name (current content: value) if it matches."""
        from copy import deepcopy
        if value == self._paramSearch:
            if self._verbose:
                print("Replace %s.%s %s ==> %s " % (base, name, self._paramSearch, self._paramReplace))
            if not value.isTracked():
                # The parameter stays untracked even if the replacement tag is tracked.
                setattr(pset, name, self._untrackedReplacement())
            else:
                setattr(pset, name, deepcopy(self._paramReplace))
        elif self._moduleLabelOnly and value.moduleLabel == self._paramSearch.moduleLabel:
            repl = deepcopy(getattr(pset, name))
            repl.moduleLabel = self._paramReplace.moduleLabel
            setattr(pset, name, repl)
            if self._verbose:
                print("Replace %s.%s %s ==> %s " % (base, name, value, repl))

    @staticmethod
    def standardizeInputTagFmt(inputTag):
        ''' helper function to ensure that the InputTag is defined as cms.InputTag(str) and not as a plain str '''
        if not isinstance(inputTag, cms.InputTag):
            return cms.InputTag(inputTag)
        return inputTag

    def enter(self, visitee):
        """Process one visited node, using its label (or a placeholder) as message base."""
        if not self._skipLabelTest:
            if hasattr(visitee, "hasLabel_") and visitee.hasLabel_():
                label = visitee.label_()
            else:
                label = '<Module not in a Process>'
        else:
            label = '<Module label not tested>'
        self.doIt(visitee, label)

    def leave(self, visitee):
        """Nothing to do when leaving a node."""
        pass
0077
def massSearchReplaceAnyInputTag(sequence, oldInputTag, newInputTag,verbose=False,moduleLabelOnly=False,skipLabelTest=False) :
    """Replace InputTag oldInputTag with newInputTag, at any level of nesting within PSets, VPSets, VInputTags...

    The work is done by a MassSearchReplaceAnyInputTagVisitor that is handed to
    sequence.visit(); nothing is returned.
    """
    replacer = MassSearchReplaceAnyInputTagVisitor(
        oldInputTag,
        newInputTag,
        verbose=verbose,
        moduleLabelOnly=moduleLabelOnly,
        skipLabelTest=skipLabelTest,
    )
    sequence.visit(replacer)
0081
def massReplaceInputTag(process,old="rawDataCollector",new="rawDataRepacker",verbose=False,moduleLabelOnly=False,skipLabelTest=False):
    """Apply massSearchReplaceAnyInputTag to every path, every endpath and,
    when the process has a schedule, every task of that schedule.

    Returns the same process object that was passed in.
    """
    # Paths first, then endpaths; each collection is fetched only when its turn comes.
    for collection in (process.paths_, process.endpaths_):
        for label in collection().keys():
            massSearchReplaceAnyInputTag(getattr(process, label), old, new, verbose, moduleLabelOnly, skipLabelTest)
    if process.schedule_() is None:
        return process
    for task in process.schedule_()._tasks:
        massSearchReplaceAnyInputTag(task, old, new, verbose, moduleLabelOnly, skipLabelTest)
    return process
0091
class MassSearchParamVisitor(object):
    """Visitor that walks a cms.Sequence and collects every visited module whose
    parameter paramName compares equal to paramSearch."""

    def __init__(self,paramName,paramSearch):
        self._modules = []
        self._paramSearch = paramSearch
        self._paramName = paramName

    def enter(self,visitee):
        """Record visitee when it has the searched parameter with the searched value."""
        wanted = self._paramName
        if not hasattr(visitee, wanted):
            return
        if getattr(visitee, wanted) == self._paramSearch:
            self._modules.append(visitee)

    def leave(self,visitee):
        """Nothing to do when leaving a node."""
        return None

    def modules(self):
        """Return the list of modules recorded so far, in the order they were visited."""
        return self._modules
0106
class MassSearchReplaceParamVisitor(object):
    """Visitor that walks a cms.Sequence and, on every visited module, sets the
    parameter paramName to paramValue when its current value equals paramSearch.

    Only the parameter directly on the module is examined. For a
    cms.SwitchProducer each of its named parameters is treated as a module.
    """

    def __init__(self,paramName,paramSearch,paramValue,verbose=False):
        self._verbose = verbose
        self._paramSearch = paramSearch
        self._paramValue = paramValue
        self._paramName = paramName

    def enter(self,visitee):
        """Apply the replacement to visitee, or to each case of a SwitchProducer."""
        if not isinstance(visitee, cms.SwitchProducer):
            self.doIt(visitee, str(visitee))
            return
        for modName in visitee.parameterNames_():
            case = getattr(visitee, modName)
            label = "%s.%s"%(str(visitee), modName)
            self.doIt(case, label)

    def doIt(self, mod, name):
        """Replace the parameter on mod if present and equal to the searched value.

        name is only used to prefix the verbose message.
        """
        key = self._paramName
        if not hasattr(mod, key):
            return
        if not (getattr(mod, key) == self._paramSearch):
            return
        if self._verbose:
            print("Replaced %s.%s: %s => %s" % (name,key,getattr(mod,key),self._paramValue))
        setattr(mod, key, self._paramValue)

    def leave(self,visitee):
        """Nothing to do when leaving a node."""
        return None
0127
def massSearchReplaceParam(sequence,paramName,paramOldValue,paramValue,verbose=False):
    """Set parameter paramName to paramValue on every module of sequence whose
    current value equals paramOldValue.

    The work is done by a MassSearchReplaceParamVisitor handed to
    sequence.visit(); nothing is returned.
    """
    replacer = MassSearchReplaceParamVisitor(paramName, paramOldValue, paramValue, verbose)
    sequence.visit(replacer)
0130
def massReplaceParameter(process,name="label",old="rawDataCollector",new="rawDataRepacker",verbose=False):
    """Apply massSearchReplaceParam to every path, every endpath and, when the
    process has a schedule, every task of that schedule.

    Returns the same process object that was passed in.
    """
    # Paths first, then endpaths; each collection is fetched only when its turn comes.
    for collection in (process.paths_, process.endpaths_):
        for label in collection().keys():
            massSearchReplaceParam(getattr(process, label), name, old, new, verbose)
    if process.schedule_() is None:
        return process
    for task in process.schedule_()._tasks:
        massSearchReplaceParam(task, name, old, new, verbose)
    return process
0140
# Self-test: running this file as a script executes the unit tests below.
if __name__=="__main__":
    import unittest
    class SwitchProducerTest(cms.SwitchProducer):
        # SwitchProducer with four fixed case names (test1..test4). Each case is
        # registered with a callable returning a (bool, int) pair.
        # NOTE(review): presumably (enabled, priority) -- confirm against cms.SwitchProducer.
        def __init__(self, **kargs):
            super(SwitchProducerTest,self).__init__(
                dict(
                    test1 = lambda: (True, -10),
                    test2 = lambda: (True, -9),
                    test3 = lambda: (True, -8),
                    test4 = lambda: (True, -7)
                ), **kargs)

    class TestModuleCommand(unittest.TestCase):

        def testMassSearchReplaceAnyInputTag(self):
            # Replace InputTag "b" by "new" in sequence p.s and check every place
            # a tag can live: plain, untracked, nested PSet, (untracked) VPSet,
            # (untracked) VInputTag, SwitchProducer cases and optional parameters.
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"), usrc=cms.untracked.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("b"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("b")),
                                 nestedv = cms.VPSet(cms.PSet(src = cms.InputTag("b")), cms.PSet(src = cms.InputTag("d"))),
                                 unestedv = cms.untracked.VPSet(cms.untracked.PSet(src = cms.InputTag("b")), cms.untracked.PSet(src = cms.InputTag("d"))),
                                 vec = cms.VInputTag(cms.InputTag("a"), cms.InputTag("b"), cms.InputTag("c"), cms.InputTag("d")),
                                 uvec = cms.untracked.VInputTag(cms.untracked.InputTag("a"), cms.untracked.InputTag("b"), cms.untracked.InputTag("c"), cms.untracked.InputTag("d")),
                                 )
            p.sp = SwitchProducerTest(
                test1 = cms.EDProducer("a", src = cms.InputTag("b"),
                                       nested = cms.PSet(src = cms.InputTag("b"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("b"))
                                       ),
                test2 = cms.EDProducer("b", src = cms.InputTag("c"),
                                       nested = cms.PSet(src = cms.InputTag("b"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("b"))
                                       ),
            )
            # Optional parameters: "src"/"vsrc" get a value below, "unset"/"vunset" never do.
            p.op = cms.EDProducer("op", src = cms.optional.InputTag, unset = cms.optional.InputTag, vsrc = cms.optional.VInputTag, vunset = cms.optional.VInputTag)
            p.op2 = cms.EDProducer("op2", src = cms.optional.InputTag, unset = cms.optional.InputTag, vsrc = cms.optional.VInputTag, vunset = cms.optional.VInputTag)
            # p.op is filled from plain Python values, p.op2 from cms objects.
            p.op.src="b"
            p.op.vsrc = ["b"]
            p.op2.src=cms.InputTag("b")
            p.op2.vsrc = cms.VInputTag("b")
            p.s = cms.Sequence(p.a*p.b*p.c*p.sp*p.op*p.op2)
            massSearchReplaceAnyInputTag(p.s, cms.InputTag("b"), cms.InputTag("new"))
            # p.b.src was "a", so it must not have been touched.
            self.assertNotEqual(cms.InputTag("new"), p.b.src)
            self.assertEqual(cms.InputTag("new"), p.c.src)
            self.assertEqual(cms.InputTag("new"), p.c.usrc)
            self.assertEqual(cms.InputTag("new"), p.c.nested.src)
            self.assertEqual(cms.InputTag("new"), p.c.nested.usrc)
            # An untracked parameter must still be untracked after replacement.
            self.assertFalse(p.c.nested.usrc.isTracked())
            self.assertNotEqual(cms.InputTag("new"), p.c.nested.src2)
            self.assertEqual(cms.InputTag("new"), p.c.nestedv[0].src)
            self.assertNotEqual(cms.InputTag("new"), p.c.nestedv[1].src)
            self.assertEqual(cms.InputTag("new"), p.c.unestedv[0].src)
            self.assertNotEqual(cms.InputTag("new"), p.c.unestedv[1].src)
            # Only element 1 of each vector held "b".
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[0])
            self.assertEqual(cms.InputTag("new"), p.c.vec[1])
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[2])
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[3])
            self.assertNotEqual(cms.InputTag("new"), p.c.uvec[0])
            self.assertEqual(cms.InputTag("new"), p.c.uvec[1])
            self.assertNotEqual(cms.InputTag("new"), p.c.uvec[2])
            self.assertNotEqual(cms.InputTag("new"), p.c.uvec[3])
            # Every element of the untracked vector stays untracked, replaced or not.
            self.assertFalse(p.c.uvec[0].isTracked())
            self.assertFalse(p.c.uvec[1].isTracked())
            self.assertFalse(p.c.uvec[2].isTracked())
            self.assertFalse(p.c.uvec[3].isTracked())
            # SwitchProducer cases are searched like ordinary modules.
            self.assertEqual(cms.InputTag("new"), p.sp.test1.src)
            self.assertEqual(cms.InputTag("new"), p.sp.test1.nested.src)
            self.assertEqual(cms.InputTag("c"), p.sp.test1.nested.src2)
            self.assertEqual(cms.untracked.InputTag("new"), p.sp.test1.nested.usrc)
            self.assertEqual(cms.InputTag("c"), p.sp.test2.src)
            self.assertEqual(cms.InputTag("new"), p.sp.test2.nested.src)
            self.assertEqual(cms.InputTag("c"), p.sp.test2.nested.src2)
            self.assertEqual(cms.untracked.InputTag("new"), p.sp.test2.nested.usrc)
            # Optional parameters that were given a value are replaced as well.
            self.assertEqual(cms.InputTag("new"), p.op.src)
            self.assertEqual(cms.InputTag("new"), p.op.vsrc[0])
            self.assertEqual(cms.InputTag("new"), p.op2.src)
            self.assertEqual(cms.InputTag("new"), p.op2.vsrc[0])

        def testMassReplaceInputTag(self):
            # On an empty process the call must leave the python dump unchanged.
            process1 = cms.Process("test")
            massReplaceInputTag(process1, "a", "b", False, False, False)
            self.assertEqual(process1.dumpPython(), cms.Process('test').dumpPython())
            # Replace "a" by "b" process-wide: modules reachable through paths,
            # endpaths and tasks associated to the schedule must be updated.
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("a"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("a")),
                                 nestedv = cms.VPSet(cms.PSet(src = cms.InputTag("a")), cms.PSet(src = cms.InputTag("d"))),
                                 unestedv = cms.untracked.VPSet(cms.untracked.PSet(src = cms.InputTag("a")), cms.untracked.PSet(src = cms.InputTag("d"))),
                                 vec = cms.VInputTag(cms.InputTag("a"), cms.InputTag("b"), cms.InputTag("c"), cms.InputTag("d")),
                                 uvec = cms.untracked.VInputTag(cms.untracked.InputTag("a"), cms.untracked.InputTag("b"), cms.untracked.InputTag("c"), cms.InputTag("d")),
                                 )
            p.d = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.e = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.f = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.g = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.h = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.i = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.sp = SwitchProducerTest(
                test1 = cms.EDProducer("a", src = cms.InputTag("a"),
                                       nested = cms.PSet(src = cms.InputTag("a"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("a"))
                                       ),
                test2 = cms.EDProducer("b", src = cms.InputTag("c"),
                                       nested = cms.PSet(src = cms.InputTag("a"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("a"))
                                       ),
            )
            p.s1 = cms.Sequence(p.a*p.b*p.c*p.sp)
            p.path1 = cms.Path(p.s1)
            # s2 (holding p.d) is never put on a path, endpath or task.
            p.s2 = cms.Sequence(p.d)
            p.path2 = cms.Path(p.e)
            p.s3 = cms.Sequence(p.f)
            p.endpath1 = cms.EndPath(p.s3)
            p.endpath2 = cms.EndPath(p.g)
            p.t1 = cms.Task(p.h)
            p.t2 = cms.Task(p.i)
            p.schedule = cms.Schedule()
            p.schedule.associate(p.t1, p.t2)
            massReplaceInputTag(p, "a", "b", False, False, False)
            self.assertEqual(cms.InputTag("b"), p.b.src)
            self.assertEqual(cms.InputTag("b"), p.c.nested.src)
            self.assertEqual(cms.InputTag("b"), p.c.nested.usrc)
            self.assertFalse(p.c.nested.usrc.isTracked())
            self.assertEqual(cms.InputTag("b"), p.c.nestedv[0].src)
            self.assertEqual(cms.InputTag("b"), p.c.unestedv[0].src)
            self.assertEqual(cms.InputTag("b"), p.c.vec[0])
            self.assertEqual(cms.InputTag("c"), p.c.vec[2])
            self.assertEqual(cms.InputTag("b"), p.c.uvec[0])
            self.assertEqual(cms.InputTag("c"), p.c.uvec[2])
            self.assertFalse(p.c.uvec[0].isTracked())
            self.assertFalse(p.c.uvec[1].isTracked())
            self.assertFalse(p.c.uvec[2].isTracked())
            # p.d is only in the unused sequence s2, so it keeps its original tag.
            self.assertEqual(cms.InputTag("a"), p.d.src)
            # e: directly on a path; f, g: via endpaths; h, i: via scheduled tasks.
            self.assertEqual(cms.InputTag("b"), p.e.src)
            self.assertEqual(cms.InputTag("b"), p.f.src)
            self.assertEqual(cms.InputTag("b"), p.g.src)
            self.assertEqual(cms.InputTag("b"), p.h.src)
            self.assertEqual(cms.InputTag("b"), p.i.src)
            self.assertEqual(cms.InputTag("b"), p.sp.test1.src)
            self.assertEqual(cms.InputTag("b"), p.sp.test1.nested.src)
            self.assertEqual(cms.InputTag("c"), p.sp.test1.nested.src2)
            self.assertEqual(cms.untracked.InputTag("b"), p.sp.test1.nested.usrc)
            self.assertEqual(cms.InputTag("c"), p.sp.test2.src)
            self.assertEqual(cms.InputTag("b"), p.sp.test2.nested.src)
            self.assertEqual(cms.InputTag("c"), p.sp.test2.nested.src2)
            self.assertEqual(cms.untracked.InputTag("b"), p.sp.test2.nested.usrc)

        def testMassSearchReplaceParam(self):
            # Set the top-level parameter "src" to "a" wherever it equals
            # InputTag "b"; parameters inside nested PSets must stay as they were.
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("c"), src2 = cms.InputTag("b"))
                                )
            p.d = cms.EDProducer("ac", src=cms.untracked.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("c"), src2 = cms.InputTag("b"))
                                )
            p.sp = SwitchProducerTest(
                test1 = cms.EDProducer("a", src = cms.InputTag("b"),
                                       nested = cms.PSet(src = cms.InputTag("b"))
                                       ),
                test2 = cms.EDProducer("b", src = cms.InputTag("b")),
            )
            p.s = cms.Sequence(p.a*p.b*p.c*p.d*p.sp)
            massSearchReplaceParam(p.s,"src",cms.InputTag("b"),"a")
            self.assertEqual(cms.InputTag("a"),p.c.src)
            self.assertEqual(cms.InputTag("c"),p.c.nested.src)
            self.assertEqual(cms.InputTag("b"),p.c.nested.src2)
            # The untracked "src" of p.d is replaced and compares equal to an untracked tag.
            self.assertEqual(cms.untracked.InputTag("a"),p.d.src)
            self.assertEqual(cms.InputTag("c"),p.d.nested.src)
            self.assertEqual(cms.InputTag("b"),p.d.nested.src2)
            # SwitchProducer cases: top-level "src" replaced, nested one untouched.
            self.assertEqual(cms.InputTag("a"),p.sp.test1.src)
            self.assertEqual(cms.InputTag("b"),p.sp.test1.nested.src)
            self.assertEqual(cms.InputTag("a"),p.sp.test2.src)

        def testMassReplaceParam(self):
            # On an empty process the call must leave the python dump unchanged.
            process1 = cms.Process("test")
            massReplaceParameter(process1, "src", cms.InputTag("a"), "b", False)
            self.assertEqual(process1.dumpPython(), cms.Process("test").dumpPython())
            # Process-wide replacement of top-level "src" == InputTag "a" by "b".
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("a"), src2 = cms.InputTag("c")),
                                 nestedv = cms.VPSet(cms.PSet(src = cms.InputTag("a")), cms.PSet(src = cms.InputTag("d"))),
                                 vec = cms.VInputTag(cms.InputTag("a"), cms.InputTag("b"), cms.InputTag("c"), cms.InputTag("d"))
                                )
            p.d = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.e = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.f = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.g = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.h = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.i = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.j = cms.EDProducer("ab", src=cms.untracked.InputTag("a"))
            p.sp = SwitchProducerTest(
                test1 = cms.EDProducer("a", src = cms.InputTag("a"),
                                       nested = cms.PSet(src = cms.InputTag("a"))
                                       ),
                test2 = cms.EDProducer("b", src = cms.InputTag("a")),
            )
            p.s1 = cms.Sequence(p.a*p.b*p.c*p.sp)
            p.path1 = cms.Path(p.s1)
            # s2 (holding p.d) is never put on a path, endpath or task.
            p.s2 = cms.Sequence(p.d)
            p.path2 = cms.Path(p.e)
            p.s3 = cms.Sequence(p.f)
            p.endpath1 = cms.EndPath(p.s3)
            p.endpath2 = cms.EndPath(p.g)
            p.t1 = cms.Task(p.h)
            p.t2 = cms.Task(p.i, p.j)
            p.schedule = cms.Schedule()
            p.schedule.associate(p.t1, p.t2)
            massReplaceParameter(p, "src",cms.InputTag("a"), "b", False)
            self.assertEqual(cms.InputTag("gen"), p.a.src)
            self.assertEqual(cms.InputTag("b"), p.b.src)
            # Vector elements are not named "src", so they are left alone.
            self.assertEqual(cms.InputTag("a"), p.c.vec[0])
            self.assertEqual(cms.InputTag("c"), p.c.vec[2])
            # p.d is only in the unused sequence s2, so it keeps its original tag.
            self.assertEqual(cms.InputTag("a"), p.d.src)
            self.assertEqual(cms.InputTag("b"), p.e.src)
            self.assertEqual(cms.InputTag("b"), p.f.src)
            self.assertEqual(cms.InputTag("b"), p.g.src)
            self.assertEqual(cms.InputTag("b"), p.h.src)
            self.assertEqual(cms.InputTag("b"), p.i.src)
            self.assertEqual(cms.untracked.InputTag("b"), p.j.src)
            self.assertEqual(cms.InputTag("b"),p.sp.test1.src)
            self.assertEqual(cms.InputTag("a"),p.sp.test1.nested.src)
            self.assertEqual(cms.InputTag("b"),p.sp.test2.src)
    unittest.main()