Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:02:43

from __future__ import print_function
import itertools
import sys

import ROOT
# Make sure the standalone BTagEntry class is known to ROOT; if it is not
# loaded yet, ask ROOT to load/compile BTagCalibrationStandalone.cpp from the
# current working directory.
try:
    ROOT.BTagEntry
except AttributeError:
    ROOT.gROOT.ProcessLine('.L BTagCalibrationStandalone.cpp+')

# Re-check: loading may have failed (e.g. the source files are not in the
# working directory), in which case nothing below can work.
try:
    ROOT.BTagEntry
except AttributeError:
    # sys.exit instead of the site-provided exit(): the latter is missing
    # when Python runs without the site module (-S) or is embedded.
    print('ROOT.BTagEntry is needed! Please copy '
          'BTagCalibrationStandalone.[h|cpp] to the working directory. Exit.',
          file=sys.stderr)
    sys.exit(-1)

# When False (default), entries of all operating points / jet flavours of a
# measurement type are grouped into one DataLoader; when True, get_data_csv
# builds one loader per operating point and/or per flavour.
separate_by_op = False
separate_by_flav = False
0018 
0019 
class DataLoader(object):
    """Collect the calibration entries of one measurement type and derive the
    phase-space bounds and test points used to check them.

    Every non-blank line of ``csv_data`` is parsed with ``ROOT.BTagEntry``.
    An entry is kept when its measurement type equals ``measurement_type``.
    When the module flags ``separate_by_op`` / ``separate_by_flav`` are set,
    it must also match ``operating_point`` / ``flavour``.

    If no entry is kept, only ``meas_type``, ``op``, ``flav`` and an empty
    ``entries`` list are set; the derived attributes (``ops``, ``etas``,
    ``PT_MIN``, test points, ...) are left undefined.
    """

    # operatingPoint value for which discriminator ranges are collected and
    # checked (see the "only needed for operatingPoint==3" note below).
    _DISCR_OP = 3

    def __init__(self, csv_data, measurement_type, operating_point, flavour):
        """
        :param csv_data: iterable of CSV data lines (header already removed).
        :param measurement_type: measurement type string to select.
        :param operating_point: operating point to select (only used when
            ``separate_by_op`` is True).
        :param flavour: jet flavour to select (only used when
            ``separate_by_flav`` is True).
        :raises RuntimeError: if ``ROOT.BTagEntry`` rejects a line with a
            TypeError.
        """
        self.meas_type = measurement_type
        self.op = operating_point
        self.flav = flavour

        # list of selected entries
        ens = []
        for line in csv_data:
            if not line.strip():
                continue  # skip empty lines
            try:
                e = ROOT.BTagEntry(line)
            except TypeError:
                raise RuntimeError("Error: can not interpret line: " + line)
            if (e.params.measurementType == measurement_type
                    and (not separate_by_op
                         or e.params.operatingPoint == operating_point)
                    and (not separate_by_flav
                         or e.params.jetFlavor == flavour)):
                ens.append(e)
        self.entries = ens

        if not ens:
            return

        # fixed data
        self.ops = set(e.params.operatingPoint for e in ens)
        self.flavs = set(e.params.jetFlavor for e in ens)
        self.syss = set(e.params.sysType for e in ens)
        self.etas = set((e.params.etaMin, e.params.etaMax) for e in ens)
        self.pts = set((e.params.ptMin, e.params.ptMax) for e in ens)
        discr_ens = [e for e in ens
                     if e.params.operatingPoint == self._DISCR_OP]
        self.discrs = set((e.params.discrMin, e.params.discrMax)
                          for e in discr_ens)

        self.ETA_MIN = -2.4
        self.ETA_MAX = 2.4
        self.PT_MIN = min(e.params.ptMin for e in ens)
        self.PT_MAX = max(e.params.ptMax for e in ens)
        if discr_ens:
            self.DISCR_MIN = min(e.params.discrMin for e in discr_ens)
            self.DISCR_MAX = max(e.params.discrMax for e in discr_ens)
        else:
            self.DISCR_MIN = 0.
            self.DISCR_MAX = 1.

        # test points for variable data (using bound +- epsilon)
        eps = 1e-4
        self.eta_test_points = self._bound_test_points(
            self.etas, self.ETA_MIN, self.ETA_MAX, eps)
        # |eta| variant: same bin edges, but the allowed range starts at 0
        self.abseta_test_points = self._bound_test_points(
            self.etas, 0., self.ETA_MAX, eps)
        self.pt_test_points = self._bound_test_points(
            self.pts, self.PT_MIN, self.PT_MAX, eps)
        self.discr_test_points = self._bound_test_points(
            self.discrs, self.DISCR_MIN, self.DISCR_MAX, eps)

    @staticmethod
    def _bound_test_points(ranges, lower, upper, eps):
        """Return test points around the edges of ``ranges``.

        The candidates are every bin edge shifted by ``+eps`` and ``-eps``,
        plus ``lower + eps`` and ``upper - eps``. Only candidates strictly
        inside ``(lower, upper)`` are kept, rounded to 5 decimals.

        :param ranges: iterable of ``(low, high)`` bin-edge pairs; it is
            iterated four times, so pass a re-iterable container.
        :returns: set of floats.
        """
        candidates = itertools.chain(
            (a + eps for a, _ in ranges),
            (a - eps for a, _ in ranges),
            (b + eps for _, b in ranges),
            (b - eps for _, b in ranges),
            (lower + eps, upper - eps),
        )
        # filter on the unrounded value, then round (same order as before)
        return set(round(x, 5) for x in candidates if lower < x < upper)

    def print_data(self):
        """Print the collected ranges and test points to stdout."""
        if not self.entries:
            # derived attributes are not set for an empty loader
            print("\nNo entries found for measurement type %s." % self.meas_type)
            print("")
            return

        print("\nFound operating points:")
        print(self.ops)

        print("\nFound jet flavors:")
        print(self.flavs)

        print("\nFound sys types (need at least 'central', 'up', 'down'; "
              "also 'up_SYS'/'down_SYS' compatibility is checked):")
        print(self.syss)

        print("\nFound eta ranges: (need everything covered from %g or 0. "
              "up to %g):" % (self.ETA_MIN, self.ETA_MAX))
        print(self.etas)

        print("\nFound pt ranges: (need everything covered from %g "
              "to %g):" % (self.PT_MIN, self.PT_MAX))
        print(self.pts)

        print("\nFound discr ranges: (only needed for operatingPoint==3, "
              "covered from %g to %g):" % (self.DISCR_MIN, self.DISCR_MAX))
        print(self.discrs)

        print("\nTest points for eta (bounds +- epsilon):")
        print(self.eta_test_points)

        print("\nTest points for pt (bounds +- epsilon):")
        print(self.pt_test_points)

        print("\nTest points for discr (bounds +- epsilon):")
        print(self.discr_test_points)
        print("")
0156 
0157 
def _csv_fields(line):
    """Split a calibration CSV line into its 11 stripped fields.

    Fields are separated by commas; surrounding whitespace is stripped and
    empty tokens are dropped (presumably mirroring the tokenisation in
    BTagEntry's C++ line parser -- TODO confirm against
    BTagCalibrationStandalone.cpp).

    :returns: list of 11 strings, or None if the line does not have exactly
        11 non-empty fields (e.g. blank or malformed lines).
    """
    fields = [f.strip() for f in line.split(',')]
    fields = [f for f in fields if f]
    return fields if len(fields) == 11 else None


def get_data_csv(csv_data):
    """Build DataLoaders for all measurement types found in ``csv_data``.

    One loader is created per measurement type. When ``separate_by_op`` /
    ``separate_by_flav`` are set, loaders are also split per operating point
    / jet flavour. Otherwise the placeholder ``'all'`` is passed instead.

    :param csv_data: iterable of CSV data lines (header already removed).
    :returns: list of DataLoader instances that have at least one entry.
    """
    # materialise: the lines are scanned here and again by every DataLoader
    csv_data = list(csv_data)

    # Fields are comma-separated. The old whitespace split() silently
    # skipped lines without spaces after commas, or with spaces inside the
    # formula.
    rows = [row for row in (_csv_fields(l) for l in csv_data)
            if row is not None]

    # grab measurement types
    meas_types = set(row[1] for row in rows)

    # grab operating points
    ops = set(int(row[0]) for row in rows) if separate_by_op else ['all']

    # grab flavors
    flavs = set(int(row[3]) for row in rows) if separate_by_flav else ['all']

    # make loaders and filter empty ones
    loaders = (
        DataLoader(csv_data, mt, op, fl)
        for mt in meas_types
        for op in ops
        for fl in flavs
    )
    return [ld for ld in loaders if ld.entries]
0189 
0190 
def get_data(filename):
    """Read a calibration CSV file and build its DataLoaders.

    The first line must be the column header (it has to contain the text
    ``OperatingPoint``). It is dropped, and the remaining lines are handed to
    ``get_data_csv``.

    :param filename: path of the CSV file.
    :returns: the list returned by ``get_data_csv``, or ``False`` (after
        printing a message) when the file is empty or lacks the header.
    """
    with open(filename) as csv_file:
        lines = csv_file.readlines()

    has_header = bool(lines) and "OperatingPoint" in lines[0]
    if not has_header:
        print("Data file does not contain typical header: %s. Exit" % filename)
        return False

    # everything after the header line is calibration data
    return get_data_csv(lines[1:])