File indexing completed on 2023-05-16 00:57:56
0001 from __future__ import print_function
0002 import FWCore.ParameterSet.Config as cms
0003
class MassSearchReplaceAnyInputTagVisitor(object):
    """Visitor that travels within a cms.Sequence, looks for an InputTag and replaces it.

    It climbs down within PSets, VPSets and VInputTags (and into the EDProducer /
    EDAlias cases of a SwitchProducer) to find its target.

    Arguments:
      paramSearch     -- tag to look for; a plain string is converted with cms.InputTag()
      paramReplace    -- replacement tag; a plain string is converted with cms.InputTag()
      verbose         -- print one line per replacement
      moduleLabelOnly -- additionally, for tags whose module label equals the module
                         label of paramSearch, replace only the module label (the
                         product-instance and process names are kept)
      skipLabelTest   -- do not query the visited module for its label; a fixed
                         placeholder is used in the printed messages instead

    Every replacement stores a fresh InputTag object, so two parameters never end up
    sharing (and silently co-mutating) the same instance. An untracked parameter
    stays untracked after the replacement.
    """
    def __init__(self,paramSearch,paramReplace,verbose=False,moduleLabelOnly=False,skipLabelTest=False):
        self._paramSearch = self.standardizeInputTagFmt(paramSearch)
        self._paramReplace = self.standardizeInputTagFmt(paramReplace)
        # not read within this class; kept in case external code relies on it
        self._moduleName = ''
        self._verbose = verbose
        self._moduleLabelOnly = moduleLabelOnly
        self._skipLabelTest = skipLabelTest

    def _freshReplacement(self, tracked):
        """Return a new InputTag object equal to the replacement tag.

        When `tracked` is False the result is an untracked InputTag built from the
        replacement's module/instance/process labels. Otherwise it is a deep copy
        of the replacement tag.
        """
        from copy import deepcopy
        if not tracked:
            return cms.untracked.InputTag(self._paramReplace.getModuleLabel(),
                                          self._paramReplace.getProductInstanceLabel(),
                                          self._paramReplace.getProcessName())
        return deepcopy(self._paramReplace)

    def doIt(self,pset,base):
        """Recursively replace matching InputTags among the parameters of `pset`.

        `base` is the dotted path of `pset`; it is used for recursion and verbose
        messages only. Objects that are not cms._Parameterizable are ignored.
        """
        if not isinstance(pset, cms._Parameterizable):
            return
        for name in pset.parameterNames_():
            # getattr yields the parameter object held by pset; the VInputTag
            # handling below modifies that object in place.
            value = getattr(pset,name)
            if isinstance(value, (cms.PSet, cms.EDProducer, cms.EDAlias)):
                # EDProducer and EDAlias are reached as the cases of a SwitchProducer
                self.doIt(value,base+"."+name)
            elif value.isCompatibleCMSType(cms.VPSet):
                for (i,ps) in enumerate(value):
                    self.doIt(ps, "%s.%s[%d]"%(base,name,i))
            elif value.isCompatibleCMSType(cms.VInputTag) and value:
                self._replaceInVInputTag(value, base, name)
            elif value.isCompatibleCMSType(cms.InputTag) and value:
                self._replaceInputTag(pset, value, base, name)

    def _replaceInVInputTag(self, value, base, name):
        """Replace matching entries of the VInputTag `value` (parameter `name`) in place."""
        from copy import deepcopy
        for (i,n) in enumerate(value):
            # entries may be plain strings, so compare them as InputTags
            n = self.standardizeInputTagFmt(n)
            if n == self._paramSearch:
                if self._verbose:
                    print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, self._paramSearch, self._paramReplace))
                value[i] = self._freshReplacement(value.isTracked())
            elif self._moduleLabelOnly and n.moduleLabel == self._paramSearch.moduleLabel:
                # Work on a copy: mutating n in place would also change every other
                # parameter sharing this InputTag object, and would make the message
                # below print the new value on both sides.
                nrep = deepcopy(n)
                nrep.moduleLabel = self._paramReplace.moduleLabel
                if self._verbose:
                    print("Replace %s.%s[%d] %s ==> %s " % (base, name, i, n, nrep))
                value[i] = nrep

    def _replaceInputTag(self, pset, value, base, name):
        """Replace the InputTag parameter `pset.<name>` (currently `value`) if it matches."""
        from copy import deepcopy
        if value == self._paramSearch:
            if self._verbose:
                print("Replace %s.%s %s ==> %s " % (base, name, self._paramSearch, self._paramReplace))
            setattr(pset, name, self._freshReplacement(value.isTracked()))
        elif self._moduleLabelOnly and value.moduleLabel == self._paramSearch.moduleLabel:
            repl = deepcopy(value)
            repl.moduleLabel = self._paramReplace.moduleLabel
            setattr(pset, name, repl)
            if self._verbose:
                print("Replace %s.%s %s ==> %s " % (base, name, value, repl))

    @staticmethod
    def standardizeInputTagFmt(inputTag):
        ''' helper function to ensure that the InputTag is defined as cms.InputTag(str) and not as a plain str '''
        if not isinstance(inputTag, cms.InputTag):
            return cms.InputTag(inputTag)
        return inputTag

    def enter(self,visitee):
        """Visit one node: determine a label for messages, then process its parameters."""
        label = ''
        if (not self._skipLabelTest):
            if hasattr(visitee,"hasLabel_") and visitee.hasLabel_():
                label = visitee.label_()
            else:
                label = '<Module not in a Process>'
        else:
            label = '<Module label not tested>'
        self.doIt(visitee, label)

    def leave(self,visitee):
        pass
0078
def massSearchReplaceAnyInputTag(sequence, oldInputTag, newInputTag,verbose=False,moduleLabelOnly=False,skipLabelTest=False):
    """Replace InputTag oldInputTag with newInputTag, at any level of nesting within PSets, VPSets, VInputTags...

    `sequence` is any object with a visit(visitor) method (a cms.Sequence, Path,
    EndPath or Task); the modules it reaches are modified in place. The keyword
    arguments are forwarded unchanged to MassSearchReplaceAnyInputTagVisitor.
    """
    replacer = MassSearchReplaceAnyInputTagVisitor(oldInputTag, newInputTag,
                                                   verbose=verbose,
                                                   moduleLabelOnly=moduleLabelOnly,
                                                   skipLabelTest=skipLabelTest)
    sequence.visit(replacer)
0082
def massReplaceInputTag(process,old="rawDataCollector",new="rawDataRepacker",verbose=False,moduleLabelOnly=False,skipLabelTest=False):
    """Apply massSearchReplaceAnyInputTag(old -> new) to the whole process.

    Visited, in this order: every Path, every EndPath, then every Task associated
    with the process Schedule (when a schedule is set). Modules that sit only in
    Sequences not placed on any of these are not touched.

    Returns `process`, which is modified in place.
    """
    def _replaceIn(container):
        # same options for every visited container
        massSearchReplaceAnyInputTag(container, old, new, verbose, moduleLabelOnly, skipLabelTest)

    for pathLabel in process.paths_().keys():
        _replaceIn(getattr(process, pathLabel))
    for endPathLabel in process.endpaths_().keys():
        _replaceIn(getattr(process, endPathLabel))
    schedule = process.schedule_()
    if schedule is not None:
        for task in schedule._tasks:
            _replaceIn(task)
    return process
0092
class MassSearchParamVisitor(object):
    """Visitor that travels within a cms.Sequence and collects the modules having a given parameter value.

    A visited node is collected when it has an attribute named `paramName` whose
    value compares equal to `paramSearch`. The collected nodes, in visiting order,
    are returned by modules().
    """
    def __init__(self,paramName,paramSearch):
        self._paramName = paramName
        self._paramSearch = paramSearch
        self._modules = []

    def enter(self,visitee):
        """Collect `visitee` if its `paramName` attribute equals `paramSearch`."""
        if not hasattr(visitee, self._paramName):
            return
        if getattr(visitee, self._paramName) == self._paramSearch:
            self._modules.append(visitee)

    def leave(self,visitee):
        """Nothing to do when leaving a node."""

    def modules(self):
        """Return the (live) list of collected nodes."""
        return self._modules
0107
class MassSearchReplaceParamVisitor(object):
    """Visitor that travels within a cms.Sequence, looks for a parameter and replaces its value.

    For every visited module that has an attribute named `paramName` equal to
    `paramSearch`, that attribute is set to `paramValue`. Only the module's own
    attributes are inspected; nested PSets are not searched. A cms.SwitchProducer
    is not inspected itself; each of its cases is inspected instead.
    """
    def __init__(self,paramName,paramSearch,paramValue,verbose=False):
        self._paramName = paramName
        self._paramValue = paramValue
        self._paramSearch = paramSearch
        self._verbose = verbose

    def enter(self,visitee):
        """Inspect `visitee`, or each case module if it is a SwitchProducer."""
        if not isinstance(visitee, cms.SwitchProducer):
            self.doIt(visitee, str(visitee))
            return
        for caseName in visitee.parameterNames_():
            caseModule = getattr(visitee, caseName)
            self.doIt(caseModule, "%s.%s" % (str(visitee), caseName))

    def doIt(self, mod, name):
        """Set mod.<paramName> to paramValue if it currently equals paramSearch.

        `name` is only used in the verbose message.
        """
        if not hasattr(mod, self._paramName):
            return
        current = getattr(mod, self._paramName)
        if not (current == self._paramSearch):
            return
        if self._verbose:
            print("Replaced %s.%s: %s => %s" % (name, self._paramName, current, self._paramValue))
        setattr(mod, self._paramName, self._paramValue)

    def leave(self,visitee):
        """Nothing to do when leaving a node."""
0128
def massSearchReplaceParam(sequence,paramName,paramOldValue,paramValue,verbose=False):
    """Set `paramName` to `paramValue` on every module of `sequence` where it equals `paramOldValue`.

    Only top-level module parameters are considered (see
    MassSearchReplaceParamVisitor); modules are modified in place.
    """
    replacer = MassSearchReplaceParamVisitor(paramName, paramOldValue, paramValue, verbose)
    sequence.visit(replacer)
0131
def massReplaceParameter(process,name="label",old="rawDataCollector",new="rawDataRepacker",verbose=False):
    """Apply massSearchReplaceParam(name, old -> new) to the whole process.

    Visited, in this order: every Path, every EndPath, then every Task associated
    with the process Schedule (when a schedule is set). Modules that sit only in
    Sequences not placed on any of these are not touched.

    Returns `process`, which is modified in place.
    """
    def _replaceIn(container):
        # same search/replace for every visited container
        massSearchReplaceParam(container, name, old, new, verbose)

    for pathLabel in process.paths_().keys():
        _replaceIn(getattr(process, pathLabel))
    for endPathLabel in process.endpaths_().keys():
        _replaceIn(getattr(process, endPathLabel))
    schedule = process.schedule_()
    if schedule is not None:
        for task in schedule._tasks:
            _replaceIn(task)
    return process
0141
if __name__=="__main__":
    # Self-tests for the helpers above; executed only when this file is run directly.
    import unittest
    class SwitchProducerTest(cms.SwitchProducer):
        # Test-only SwitchProducer offering four cases (test1..test4). Each case
        # function returns a (True, <int>) tuple that cms.SwitchProducer consumes;
        # presumably (enabled, priority) -- TODO confirm against cms.SwitchProducer.
        def __init__(self, **kargs):
            super(SwitchProducerTest,self).__init__(
                dict(
                    test1 = lambda: (True, -10),
                    test2 = lambda: (True, -9),
                    test3 = lambda: (True, -8),
                    test4 = lambda: (True, -7)
                ), **kargs)

    class TestModuleCommand(unittest.TestCase):

        def testMassSearchReplaceAnyInputTag(self):
            # Replace InputTag "b" by "new" in every module of sequence p.s.
            # Covers plain, untracked, nested-PSet, VPSet, VInputTag,
            # SwitchProducer-case and optional parameters; untracked parameters
            # must stay untracked after the replacement.
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"), usrc=cms.untracked.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("b"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("b")),
                                 nestedv = cms.VPSet(cms.PSet(src = cms.InputTag("b")), cms.PSet(src = cms.InputTag("d"))),
                                 unestedv = cms.untracked.VPSet(cms.untracked.PSet(src = cms.InputTag("b")), cms.untracked.PSet(src = cms.InputTag("d"))),
                                 vec = cms.VInputTag(cms.InputTag("a"), cms.InputTag("b"), cms.InputTag("c"), cms.InputTag("d")),
                                 uvec = cms.untracked.VInputTag(cms.untracked.InputTag("a"), cms.untracked.InputTag("b"), cms.untracked.InputTag("c"), cms.untracked.InputTag("d")),
                                )
            p.sp = SwitchProducerTest(
                test1 = cms.EDProducer("a", src = cms.InputTag("b"),
                                       nested = cms.PSet(src = cms.InputTag("b"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("b"))
                                       ),
                test2 = cms.EDProducer("b", src = cms.InputTag("c"),
                                       nested = cms.PSet(src = cms.InputTag("b"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("b"))
                                       ),
            )
            # Optional parameters: "src"/"vsrc" are assigned (plain strings for
            # p.op, InputTag objects for p.op2); "unset"/"vunset" stay unset.
            p.op = cms.EDProducer("op", src = cms.optional.InputTag, unset = cms.optional.InputTag, vsrc = cms.optional.VInputTag, vunset = cms.optional.VInputTag)
            p.op2 = cms.EDProducer("op2", src = cms.optional.InputTag, unset = cms.optional.InputTag, vsrc = cms.optional.VInputTag, vunset = cms.optional.VInputTag)
            p.op.src="b"
            p.op.vsrc = ["b"]
            p.op2.src=cms.InputTag("b")
            p.op2.vsrc = cms.VInputTag("b")
            p.s = cms.Sequence(p.a*p.b*p.c*p.sp*p.op*p.op2)
            massSearchReplaceAnyInputTag(p.s, cms.InputTag("b"), cms.InputTag("new"))
            self.assertNotEqual(cms.InputTag("new"), p.b.src)
            self.assertEqual(cms.InputTag("new"), p.c.src)
            self.assertEqual(cms.InputTag("new"), p.c.usrc)
            self.assertEqual(cms.InputTag("new"), p.c.nested.src)
            self.assertEqual(cms.InputTag("new"), p.c.nested.usrc)
            self.assertFalse(p.c.nested.usrc.isTracked())
            self.assertNotEqual(cms.InputTag("new"), p.c.nested.src2)
            self.assertEqual(cms.InputTag("new"), p.c.nestedv[0].src)
            self.assertNotEqual(cms.InputTag("new"), p.c.nestedv[1].src)
            self.assertEqual(cms.InputTag("new"), p.c.unestedv[0].src)
            self.assertNotEqual(cms.InputTag("new"), p.c.unestedv[1].src)
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[0])
            self.assertEqual(cms.InputTag("new"), p.c.vec[1])
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[2])
            self.assertNotEqual(cms.InputTag("new"), p.c.vec[3])
            self.assertNotEqual(cms.InputTag("new"), p.c.uvec[0])
            self.assertEqual(cms.InputTag("new"), p.c.uvec[1])
            self.assertNotEqual(cms.InputTag("new"), p.c.uvec[2])
            self.assertNotEqual(cms.InputTag("new"), p.c.uvec[3])
            self.assertFalse(p.c.uvec[0].isTracked())
            self.assertFalse(p.c.uvec[1].isTracked())
            self.assertFalse(p.c.uvec[2].isTracked())
            self.assertFalse(p.c.uvec[3].isTracked())
            self.assertEqual(cms.InputTag("new"), p.sp.test1.src)
            self.assertEqual(cms.InputTag("new"), p.sp.test1.nested.src)
            self.assertEqual(cms.InputTag("c"), p.sp.test1.nested.src2)
            self.assertEqual(cms.untracked.InputTag("new"), p.sp.test1.nested.usrc)
            self.assertEqual(cms.InputTag("c"), p.sp.test2.src)
            self.assertEqual(cms.InputTag("new"), p.sp.test2.nested.src)
            self.assertEqual(cms.InputTag("c"), p.sp.test2.nested.src2)
            self.assertEqual(cms.untracked.InputTag("new"), p.sp.test2.nested.usrc)
            self.assertEqual(cms.InputTag("new"), p.op.src)
            self.assertEqual(cms.InputTag("new"), p.op.vsrc[0])
            self.assertEqual(cms.InputTag("new"), p.op2.src)
            self.assertEqual(cms.InputTag("new"), p.op2.vsrc[0])

        def testMassReplaceInputTag(self):
            # An empty process must come out unchanged.
            process1 = cms.Process("test")
            massReplaceInputTag(process1, "a", "b", False, False, False)
            self.assertEqual(process1.dumpPython(), cms.Process('test').dumpPython())
            # Modules reached from paths (s1 modules, p.e), endpaths (p.f, p.g) and
            # schedule-associated tasks (p.h, p.i) get "a" -> "b". p.d sits only in
            # Sequence s2, which is on no path, so it keeps "a".
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("a"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("a")),
                                 nestedv = cms.VPSet(cms.PSet(src = cms.InputTag("a")), cms.PSet(src = cms.InputTag("d"))),
                                 unestedv = cms.untracked.VPSet(cms.untracked.PSet(src = cms.InputTag("a")), cms.untracked.PSet(src = cms.InputTag("d"))),
                                 vec = cms.VInputTag(cms.InputTag("a"), cms.InputTag("b"), cms.InputTag("c"), cms.InputTag("d")),
                                 uvec = cms.untracked.VInputTag(cms.untracked.InputTag("a"), cms.untracked.InputTag("b"), cms.untracked.InputTag("c"), cms.InputTag("d")),
                                )
            p.d = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.e = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.f = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.g = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.h = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.i = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.sp = SwitchProducerTest(
                test1 = cms.EDProducer("a", src = cms.InputTag("a"),
                                       nested = cms.PSet(src = cms.InputTag("a"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("a"))
                                       ),
                test2 = cms.EDProducer("b", src = cms.InputTag("c"),
                                       nested = cms.PSet(src = cms.InputTag("a"), src2 = cms.InputTag("c"), usrc = cms.untracked.InputTag("a"))
                                       ),
            )
            p.s1 = cms.Sequence(p.a*p.b*p.c*p.sp)
            p.path1 = cms.Path(p.s1)
            p.s2 = cms.Sequence(p.d)
            p.path2 = cms.Path(p.e)
            p.s3 = cms.Sequence(p.f)
            p.endpath1 = cms.EndPath(p.s3)
            p.endpath2 = cms.EndPath(p.g)
            p.t1 = cms.Task(p.h)
            p.t2 = cms.Task(p.i)
            p.schedule = cms.Schedule()
            p.schedule.associate(p.t1, p.t2)
            massReplaceInputTag(p, "a", "b", False, False, False)
            self.assertEqual(cms.InputTag("b"), p.b.src)
            self.assertEqual(cms.InputTag("b"), p.c.nested.src)
            self.assertEqual(cms.InputTag("b"), p.c.nested.usrc)
            self.assertFalse(p.c.nested.usrc.isTracked())
            self.assertEqual(cms.InputTag("b"), p.c.nestedv[0].src)
            self.assertEqual(cms.InputTag("b"), p.c.unestedv[0].src)
            self.assertEqual(cms.InputTag("b"), p.c.vec[0])
            self.assertEqual(cms.InputTag("c"), p.c.vec[2])
            self.assertEqual(cms.InputTag("b"), p.c.uvec[0])
            self.assertEqual(cms.InputTag("c"), p.c.uvec[2])
            self.assertFalse(p.c.uvec[0].isTracked())
            self.assertFalse(p.c.uvec[1].isTracked())
            self.assertFalse(p.c.uvec[2].isTracked())
            self.assertEqual(cms.InputTag("a"), p.d.src)
            self.assertEqual(cms.InputTag("b"), p.e.src)
            self.assertEqual(cms.InputTag("b"), p.f.src)
            self.assertEqual(cms.InputTag("b"), p.g.src)
            self.assertEqual(cms.InputTag("b"), p.h.src)
            self.assertEqual(cms.InputTag("b"), p.i.src)
            self.assertEqual(cms.InputTag("b"), p.sp.test1.src)
            self.assertEqual(cms.InputTag("b"), p.sp.test1.nested.src)
            self.assertEqual(cms.InputTag("c"), p.sp.test1.nested.src2)
            self.assertEqual(cms.untracked.InputTag("b"), p.sp.test1.nested.usrc)
            self.assertEqual(cms.InputTag("c"), p.sp.test2.src)
            self.assertEqual(cms.InputTag("b"), p.sp.test2.nested.src)
            self.assertEqual(cms.InputTag("c"), p.sp.test2.nested.src2)
            self.assertEqual(cms.untracked.InputTag("b"), p.sp.test2.nested.usrc)

        def testMassSearchReplaceParam(self):
            # Set top-level "src" to "a" where it equals InputTag("b"). Parameters
            # inside nested PSets keep their values (c.nested.src2,
            # sp.test1.nested.src), and each SwitchProducer case is handled.
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("c"), src2 = cms.InputTag("b"))
                                )
            p.d = cms.EDProducer("ac", src=cms.untracked.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("c"), src2 = cms.InputTag("b"))
                                )
            p.sp = SwitchProducerTest(
                test1 = cms.EDProducer("a", src = cms.InputTag("b"),
                                       nested = cms.PSet(src = cms.InputTag("b"))
                                       ),
                test2 = cms.EDProducer("b", src = cms.InputTag("b")),
            )
            p.s = cms.Sequence(p.a*p.b*p.c*p.d*p.sp)
            massSearchReplaceParam(p.s,"src",cms.InputTag("b"),"a")
            self.assertEqual(cms.InputTag("a"),p.c.src)
            self.assertEqual(cms.InputTag("c"),p.c.nested.src)
            self.assertEqual(cms.InputTag("b"),p.c.nested.src2)
            self.assertEqual(cms.untracked.InputTag("a"),p.d.src)
            self.assertEqual(cms.InputTag("c"),p.d.nested.src)
            self.assertEqual(cms.InputTag("b"),p.d.nested.src2)
            self.assertEqual(cms.InputTag("a"),p.sp.test1.src)
            self.assertEqual(cms.InputTag("b"),p.sp.test1.nested.src)
            self.assertEqual(cms.InputTag("a"),p.sp.test2.src)

        def testMassReplaceParam(self):
            # An empty process must come out unchanged.
            process1 = cms.Process("test")
            massReplaceParameter(process1, "src", cms.InputTag("a"), "b", False)
            self.assertEqual(process1.dumpPython(), cms.Process("test").dumpPython())
            # Process-wide "src": "a" -> "b" over paths, endpaths and scheduled tasks.
            # VInputTag entries (p.c.vec) and nested PSet parameters are not
            # changed; p.d (only in Sequence s2, on no path) keeps "a".
            p = cms.Process("test")
            p.a = cms.EDProducer("a", src=cms.InputTag("gen"))
            p.b = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.c = cms.EDProducer("ac", src=cms.InputTag("b"),
                                 nested = cms.PSet(src = cms.InputTag("a"), src2 = cms.InputTag("c")),
                                 nestedv = cms.VPSet(cms.PSet(src = cms.InputTag("a")), cms.PSet(src = cms.InputTag("d"))),
                                 vec = cms.VInputTag(cms.InputTag("a"), cms.InputTag("b"), cms.InputTag("c"), cms.InputTag("d"))
                                )
            p.d = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.e = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.f = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.g = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.h = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.i = cms.EDProducer("ab", src=cms.InputTag("a"))
            p.j = cms.EDProducer("ab", src=cms.untracked.InputTag("a"))
            p.sp = SwitchProducerTest(
                test1 = cms.EDProducer("a", src = cms.InputTag("a"),
                                       nested = cms.PSet(src = cms.InputTag("a"))
                                       ),
                test2 = cms.EDProducer("b", src = cms.InputTag("a")),
            )
            p.s1 = cms.Sequence(p.a*p.b*p.c*p.sp)
            p.path1 = cms.Path(p.s1)
            p.s2 = cms.Sequence(p.d)
            p.path2 = cms.Path(p.e)
            p.s3 = cms.Sequence(p.f)
            p.endpath1 = cms.EndPath(p.s3)
            p.endpath2 = cms.EndPath(p.g)
            p.t1 = cms.Task(p.h)
            p.t2 = cms.Task(p.i, p.j)
            p.schedule = cms.Schedule()
            p.schedule.associate(p.t1, p.t2)
            massReplaceParameter(p, "src",cms.InputTag("a"), "b", False)
            self.assertEqual(cms.InputTag("gen"), p.a.src)
            self.assertEqual(cms.InputTag("b"), p.b.src)
            self.assertEqual(cms.InputTag("a"), p.c.vec[0])
            self.assertEqual(cms.InputTag("c"), p.c.vec[2])
            self.assertEqual(cms.InputTag("a"), p.d.src)
            self.assertEqual(cms.InputTag("b"), p.e.src)
            self.assertEqual(cms.InputTag("b"), p.f.src)
            self.assertEqual(cms.InputTag("b"), p.g.src)
            self.assertEqual(cms.InputTag("b"), p.h.src)
            self.assertEqual(cms.InputTag("b"), p.i.src)
            self.assertEqual(cms.untracked.InputTag("b"), p.j.src)
            self.assertEqual(cms.InputTag("b"),p.sp.test1.src)
            self.assertEqual(cms.InputTag("a"),p.sp.test1.nested.src)
            self.assertEqual(cms.InputTag("b"),p.sp.test2.src)
    unittest.main()