Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:23:46

0001 from PhysicsTools.NanoAODTools.postprocessing.framework.eventloop import Module
0002 from PhysicsTools.NanoAODTools.postprocessing.framework.datamodel import Collection
0003 import ROOT
0004 import numpy as np
0005 import itertools
0006 ROOT.PyConfig.IgnoreCommandLineOptions = True
0007 
0008 _rootLeafType2rootBranchType = {
0009     'UChar_t': 'b',
0010     'Char_t': 'B',
0011     'UInt_t': 'i',
0012     'Int_t': 'I',
0013     'Float_t': 'F',
0014     'Double_t': 'D',
0015     'ULong64_t': 'l',
0016     'Long64_t': 'L',
0017     'Bool_t': 'O'
0018 }
0019 
0020 
class collectionMerger(Module):
    """Merge several collections into a single output collection.

    For every event the objects of all ``input`` collections are gathered,
    optionally filtered by a per-collection selector, sorted with
    ``sortkey`` and optionally truncated to ``maxObjects``. The result is
    written as the collection ``output``. Its branches are the union of the
    branch suffixes of the input collections. For an object whose source
    collection lacks a given suffix, that branch is filled with 0.

    Args:
        input: list of input collection names (e.g. ``["Electron", "Muon"]``).
        output: name of the merged output collection.
        sortkey: function applied to an object to order the merged list.
        reverse: passed to ``list.sort``; True sorts in descending order.
        selector: optional dict mapping collection name -> predicate on an
            object. Collections absent from the dict accept every object.
        maxObjects: keep only the first ``maxObjects`` objects after
            selection and sorting. A falsy value (``None`` or ``0``) means
            no limit.
    """

    def __init__(self,
                 input,
                 output,
                 sortkey=lambda x: x.pt,
                 reverse=True,
                 selector=None,
                 maxObjects=None):
        self.input = input
        self.output = output
        self.nInputs = len(self.input)
        # Objects are handled as (obj, collection_index, object_index)
        # tuples in analyze(); sort on the object itself.
        self.sortkey = lambda obj_j_i1: sortkey(obj_j_i1[0])
        self.reverse = reverse
        # pass dict([(collection_name,lambda obj : selection(obj)])
        # One predicate per input collection, in the same order as
        # self.input, so it can be indexed by collection index.
        self.selector = [(selector[coll] if coll in selector else
                          (lambda x: True))
                         for coll in self.input] if selector else None
        # save only the first maxObjects objects passing the selection in the merged collection
        self.maxObjects = maxObjects
        # Branch suffix (name without "<collection>_") -> ROOT leaf type name.
        self.branchType = {}

    def beginJob(self):
        pass

    def endJob(self):
        pass

    @staticmethod
    def _activeBranches(tree):
        """Return the branches of ``tree`` whose branch status is enabled."""
        brlist = tree.GetListOfBranches()
        return [
            br for br in (brlist.At(i) for i in range(brlist.GetEntries()))
            if tree.GetBranchStatus(br.GetName())
        ]

    def beginFile(self, inputFile, outputFile, inputTree, wrappedOutputTree):
        """Discover the branches of the input collections and book the output ones."""
        # Use the activated branches of both the input tree and the output
        # tree (the latter may hold branches booked by earlier modules).
        branches = (self._activeBranches(inputTree) +
                    self._activeBranches(wrappedOutputTree._tree))

        # Only keep branches with right collection name (one list per input)
        self.brlist_sep = [
            self.filterBranchNames(branches, x) for x in self.input
        ]
        # Union of all suffixes. It is sorted so the booking order of the
        # output branches is reproducible; iterating a set of str depends on
        # the interpreter's hash seed.
        self.brlist_all = sorted(set(itertools.chain(*self.brlist_sep)))

        # is_there[b][j] is True when suffix b exists in input collection j.
        sep_sets = [set(names) for names in self.brlist_sep]
        self.is_there = np.zeros(shape=(len(self.brlist_all), self.nInputs),
                                 dtype=bool)
        for bridx, br in enumerate(self.brlist_all):
            for j, names in enumerate(sep_sets):
                self.is_there[bridx][j] = br in names

        # Create output branches, sized by the "n<output>" counter
        self.out = wrappedOutputTree
        for br in self.brlist_all:
            self.out.branch("%s_%s" % (self.output, br),
                            _rootLeafType2rootBranchType[self.branchType[br]],
                            lenVar="n%s" % self.output)

    def endFile(self, inputFile, outputFile, inputTree, wrappedOutputTree):
        pass

    def filterBranchNames(self, branches, collection):
        """Return the suffixes of the branches that belong to ``collection``.

        A branch belongs to ``collection`` when its name starts with
        ``collection + '_'``. The returned entry is the name with that leading
        prefix removed. As a side effect, the ROOT leaf type name of each
        suffix is recorded in ``self.branchType``. If several collections
        share a suffix, the last one processed wins.
        """
        prefix = collection + '_'
        out = []
        for br in branches:
            name = br.GetName()
            if not name.startswith(prefix):
                continue
            # Strip only the leading prefix. str.replace would also delete
            # any later occurrence of "<collection>_" inside the name.
            suffix = name[len(prefix):]
            out.append(suffix)
            self.branchType[suffix] = br.FindLeaf(name).GetTypeName()
        return out

    def analyze(self, event):
        """process event, return True (go to next module) or False (fail, go to next event)

        Fills every merged branch with one entry per selected object, in
        sorted order. The method always returns True.
        """
        coll = [Collection(event, x) for x in self.input]
        # (object, collection index, index within that collection)
        objects = [(coll[j][i], j, i) for j in range(self.nInputs)
                   for i in range(len(coll[j]))]
        if self.selector:
            objects = [
                obj_j_i for obj_j_i in objects
                if self.selector[obj_j_i[1]](obj_j_i[0])
            ]
        objects.sort(key=self.sortkey, reverse=self.reverse)
        if self.maxObjects:
            objects = objects[:self.maxObjects]
        for bridx, br in enumerate(self.brlist_all):
            # Objects from a collection lacking this branch get a 0 filler.
            out = [getattr(obj, br) if self.is_there[bridx][j] else 0
                   for obj, j, i in objects]
            self.out.fillBranch("%s_%s" % (self.output, br), out)
        return True
0123 
0124 
0125 # define modules using the syntax 'name = lambda : constructor' to avoid having them loaded when not needed
0126 
# Merge all electrons and muons into one "Lepton" collection, ordered by
# the collectionMerger defaults (descending pt, no selection, no limit).
lepMerger = lambda: collectionMerger(
    output="Lepton",
    input=["Electron", "Muon"],
)
# Example with a selection: keep only the two leading leptons among
# electrons with pt > 20 and muons with pt > 40.
lepMerger_exampleSelection = lambda: collectionMerger(
    input=["Electron", "Muon"],
    output="Lepton",
    maxObjects=2,
    selector={
        "Electron": lambda x: x.pt > 20,
        "Muon": lambda x: x.pt > 40,
    },
)