Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-11-25 02:29:23

0001 import itertools
0002 import ROOT
def _btag_entry_available():
    """Return True when the attribute lookup ``ROOT.BTagEntry`` succeeds."""
    try:
        ROOT.BTagEntry
    except AttributeError:
        return False
    return True


# If the class is not known to ROOT yet, ask ROOT to load the standalone
# implementation that is expected to sit in the working directory.
if not _btag_entry_available():
    ROOT.gROOT.ProcessLine('.L BTagCalibrationStandalone.cpp+')

# Probe again: without ROOT.BTagEntry nothing below can work, so bail out.
if not _btag_entry_available():
    print('ROOT.BTagEntry is needed! Please copy ' \
          'BTagCalibrationStandalone.[h|cpp] to the working directory. Exit.')
    exit(-1)

# Module-wide switches read by DataLoader and get_data_csv: when False, the
# entries are not filtered by operating point / jet flavour and a single
# 'all' placeholder is used for that dimension.
separate_by_op   = False
separate_by_flav = False
0017 
0018 
class DataLoader(object):
    """Collect the BTagEntry objects of one measurement type from CSV lines.

    Every non-empty line of ``csv_data`` is turned into a ``ROOT.BTagEntry``.
    Entries are kept when their measurement type matches and, if the
    module-level switches ``separate_by_op`` / ``separate_by_flav`` are set,
    when operating point / jet flavour match as well.

    Attributes set in all cases:
        meas_type, op, flav: the constructor arguments, stored verbatim.
        entries: list of the selected BTagEntry objects (possibly empty).

    Attributes set only when ``entries`` is non-empty:
        ops, flavs, syss: sets of operating points, jet flavours, sys types.
        etas, pts, discrs: sets of (min, max) tuples; ``discrs`` only uses
            entries with ``operatingPoint == 3``.
        ETA_MIN, ETA_MAX, PT_MIN, PT_MAX, DISCR_MIN, DISCR_MAX: overall bounds.
        eta_test_points, abseta_test_points, pt_test_points,
        discr_test_points: sets of probe values just inside/outside every
            bin edge (see ``_boundary_test_points``).
    """

    def __init__(self, csv_data, measurement_type, operating_point, flavour):
        """Parse ``csv_data`` and derive ranges and test points.

        Args:
            csv_data: iterable of CSV lines (strings); blank lines are skipped.
            measurement_type: measurement type an entry must have to be kept.
            operating_point: required operating point; only checked when
                ``separate_by_op`` is true.
            flavour: required jet flavour; only checked when
                ``separate_by_flav`` is true.

        Raises:
            RuntimeError: if handling a line raises ``TypeError``.
        """
        self.meas_type = measurement_type
        self.op = operating_point
        self.flav = flavour

        # list of entries
        ens = []
        for l in csv_data:
            if not l.strip():
                continue  # skip empty lines
            try:
                e = ROOT.BTagEntry(l)
                if (e.params.measurementType == measurement_type
                    and ((not separate_by_op)
                            or e.params.operatingPoint == operating_point)
                    and ((not separate_by_flav)
                            or e.params.jetFlavor == flavour)
                ):
                    ens.append(e)
            except TypeError:
                raise RuntimeError("Error: can not interpret line: " + l)
        self.entries = ens

        # Nothing selected: leave the derived attributes unset. Callers are
        # expected to discard loaders whose ``entries`` list is empty.
        if not ens:
            return

        # fixed data
        self.ops = set(e.params.operatingPoint for e in ens)
        self.flavs = set(e.params.jetFlavor for e in ens)
        self.syss = set(e.params.sysType for e in ens)
        self.etas = set((e.params.etaMin, e.params.etaMax) for e in ens)
        self.pts = set((e.params.ptMin, e.params.ptMax) for e in ens)
        # discriminator ranges are only taken from operatingPoint == 3 entries
        op3 = [e for e in ens if e.params.operatingPoint == 3]
        self.discrs = set((e.params.discrMin, e.params.discrMax) for e in op3)

        self.ETA_MIN = -2.4
        self.ETA_MAX = 2.4
        self.PT_MIN = min(e.params.ptMin for e in ens)
        self.PT_MAX = max(e.params.ptMax for e in ens)
        if op3:
            self.DISCR_MIN = min(e.params.discrMin for e in op3)
            self.DISCR_MAX = max(e.params.discrMax for e in op3)
        else:
            self.DISCR_MIN = 0.
            self.DISCR_MAX = 1.

        # test points for variable data (using bound +- epsilon)
        eps = 1e-4
        self.eta_test_points = self._boundary_test_points(
            self.etas, self.ETA_MIN, self.ETA_MAX, eps)
        # same eta bins, but probed on the positive half only
        self.abseta_test_points = self._boundary_test_points(
            self.etas, 0., self.ETA_MAX, eps)
        self.pt_test_points = self._boundary_test_points(
            self.pts, self.PT_MIN, self.PT_MAX, eps)
        self.discr_test_points = self._boundary_test_points(
            self.discrs, self.DISCR_MIN, self.DISCR_MAX, eps)

    @staticmethod
    def _boundary_test_points(ranges, lower, upper, eps):
        """Build probe values at ``edge +- eps`` for every bin edge.

        Uses a plain generator filter instead of ``itertools.ifilter``, which
        only exists on Python 2, so the code runs on Python 2 and 3 alike.

        Args:
            ranges: iterable of (min, max) tuples.
            lower, upper: overall bounds; only values strictly between them
                are kept, and ``lower + eps`` / ``upper - eps`` are always
                offered as candidates.
            eps: offset applied to each edge.

        Returns:
            set of floats, each rounded to 5 decimals.
        """
        ranges = list(ranges)  # iterated several times below
        candidates = itertools.chain(
            (a + eps for a, _ in ranges),
            (a - eps for a, _ in ranges),
            (b + eps for _, b in ranges),
            (b - eps for _, b in ranges),
            (lower + eps, upper - eps),
        )
        return set(round(x, 5) for x in candidates if lower < x < upper)

    def print_data(self):
        """Print the collected sets, ranges and test points to stdout.

        Reads attributes that are only set when ``entries`` is non-empty.
        """
        print("\nFound operating points:")
        print(self.ops)

        print("\nFound jet flavors:")
        print(self.flavs)

        print("\nFound sys types (need at least 'central', 'up', 'down'; " \
              "also 'up_SYS'/'down_SYS' compatibility is checked):")
        print(self.syss)

        print("\nFound eta ranges: (need everything covered from %g or 0. " \
              "up to %g):" % (self.ETA_MIN, self.ETA_MAX))
        print(self.etas)

        print("\nFound pt ranges: (need everything covered from %g " \
              "to %g):" % (self.PT_MIN, self.PT_MAX))
        print(self.pts)

        print("\nFound discr ranges: (only needed for operatingPoint==3, " \
              "covered from %g to %g):" % (self.DISCR_MIN, self.DISCR_MAX))
        print(self.discrs)

        print("\nTest points for eta (bounds +- epsilon):")
        print(self.eta_test_points)

        print("\nTest points for pt (bounds +- epsilon):")
        print(self.pt_test_points)

        print("\nTest points for discr (bounds +- epsilon):")
        print(self.discr_test_points)
        print("")
0155 
0156 
def get_data_csv(csv_data):
    """Build one DataLoader per measurement type (and op / flavour) found.

    Only lines that have exactly 11 comma-separated fields are used to
    discover measurement types, operating points and flavours. The field
    count uses the same ',' delimiter as the value extraction, so lines
    written without blanks after the commas are recognised too.

    Args:
        csv_data: re-iterable sequence of CSV lines without the header line;
            it is handed unchanged to every DataLoader.

    Returns:
        list of DataLoader objects whose ``entries`` list is non-empty.
    """
    # split every line once and keep only the well-formed rows
    rows = []
    for l in csv_data:
        fields = [f.strip() for f in l.split(',')]
        if len(fields) == 11:
            rows.append(fields)

    # grab measurement types
    meas_types = set(r[1] for r in rows)

    # grab operating points
    ops = set(int(r[0]) for r in rows) if separate_by_op else ['all']

    # grab flavors
    flavs = set(int(r[3]) for r in rows) if separate_by_flav else ['all']

    # make loaders and filter empty ones
    lds = [
        DataLoader(csv_data, mt, op, fl)
        for mt in meas_types
        for op in ops
        for fl in flavs
    ]
    return [d for d in lds if d.entries]
0188 
0189 
def get_data(filename):
    """Read a CSV file and return the loaders built from its body.

    The first line must contain the text "OperatingPoint"; it is treated as
    the header and dropped before the remaining lines go to get_data_csv.

    Returns:
        The result of get_data_csv, or False (after printing a message) when
        the file is empty or its first line lacks the expected header text.
    """
    with open(filename) as handle:
        lines = handle.readlines()

    header_found = bool(lines) and "OperatingPoint" in lines[0]
    if header_found:
        # everything after the header line is payload
        return get_data_csv(lines[1:])

    print("Data file does not contain typical header: %s. Exit" % filename)
    return False