# File indexing completed on 2024-11-25 02:29:23
0001 import itertools
0002 import ROOT
# Make sure ROOT knows the BTagEntry class. If the attribute lookup fails,
# ask ROOT to load and compile the standalone calibration source, which is
# looked up under the relative name 'BTagCalibrationStandalone.cpp'.
try:
    ROOT.BTagEntry
except AttributeError:
    ROOT.gROOT.ProcessLine('.L BTagCalibrationStandalone.cpp+')

# Second lookup: if the class is still missing, nothing below can work, so
# tell the user what to do and terminate with exit status -1.
try:
    ROOT.BTagEntry
except AttributeError:
    print('ROOT.BTagEntry is needed! Please copy '
          'BTagCalibrationStandalone.[h|cpp] to the working directory. Exit.')
    # SystemExit is raised directly instead of calling exit(): the exit()
    # helper is injected by the `site` module and is not guaranteed to exist.
    raise SystemExit(-1)
0014
# Module-wide switches read by DataLoader and get_data_csv.
# When False (the default) entries are not split along that axis: every
# operating point / jet flavour ends up in the same DataLoader.
separate_by_op = False
separate_by_flav = False
0017
0018
class DataLoader(object):
    """Collect the BTagEntry objects of one measurement type from CSV lines.

    Every non-blank line of ``csv_data`` is parsed with ``ROOT.BTagEntry``.
    An entry is kept when its measurement type equals ``measurement_type``
    and, only if the module flags ``separate_by_op`` / ``separate_by_flav``
    are set, its operating point / jet flavour equal ``operating_point`` /
    ``flavour``.

    Attributes always set: ``meas_type``, ``op``, ``flav``, ``entries``.
    All other attributes (``ops``, ``flavs``, ``syss``, ``etas``, ``pts``,
    ``discrs``, the ``*_MIN`` / ``*_MAX`` bounds and the ``*_test_points``
    sets) are only set when at least one entry was kept.
    """

    def __init__(self, csv_data, measurement_type, operating_point, flavour):
        """Parse ``csv_data`` and derive ranges, bounds and test points.

        Args:
            csv_data: iterable of CSV text lines (header already removed).
            measurement_type: measurement type an entry must have to be kept.
            operating_point: operating point to keep; only compared when
                ``separate_by_op`` is true.
            flavour: jet flavour to keep; only compared when
                ``separate_by_flav`` is true.

        Raises:
            RuntimeError: if handling a line raises ``TypeError``.
        """
        self.meas_type = measurement_type
        self.op = operating_point
        self.flav = flavour

        # list of entries
        ens = []
        for l in csv_data:
            if not l.strip():
                continue  # skip empty lines
            try:
                e = ROOT.BTagEntry(l)
                if (e.params.measurementType == measurement_type
                        and ((not separate_by_op)
                             or e.params.operatingPoint == operating_point)
                        and ((not separate_by_flav)
                             or e.params.jetFlavor == flavour)):
                    ens.append(e)
            except TypeError:
                raise RuntimeError("Error: can not interpret line: " + l)
        self.entries = ens

        # Nothing matched: leave the derived attributes unset. Callers are
        # expected to check ``entries`` before using the loader.
        if not ens:
            return

        # distinct values / ranges found in the kept entries
        self.ops = set(e.params.operatingPoint for e in ens)
        self.flavs = set(e.params.jetFlavor for e in ens)
        self.syss = set(e.params.sysType for e in ens)
        self.etas = set((e.params.etaMin, e.params.etaMax) for e in ens)
        self.pts = set((e.params.ptMin, e.params.ptMax) for e in ens)
        # discriminator ranges are only collected for operatingPoint == 3
        self.discrs = set((e.params.discrMin, e.params.discrMax)
                          for e in ens
                          if e.params.operatingPoint == 3)

        # overall bounds: eta is fixed, pt and discr come from the data
        self.ETA_MIN = -2.4
        self.ETA_MAX = 2.4
        self.PT_MIN = min(e.params.ptMin for e in ens)
        self.PT_MAX = max(e.params.ptMax for e in ens)
        if any(e.params.operatingPoint == 3 for e in ens):
            self.DISCR_MIN = min(
                e.params.discrMin
                for e in ens
                if e.params.operatingPoint == 3
            )
            self.DISCR_MAX = max(
                e.params.discrMax
                for e in ens
                if e.params.operatingPoint == 3
            )
        else:
            self.DISCR_MIN = 0.
            self.DISCR_MAX = 1.

        # test points: every range edge +- epsilon, limited to the bounds
        self.eta_test_points = self._edge_test_points(
            self.etas, self.ETA_MIN, self.ETA_MAX)
        # the |eta| variant uses the same ranges but only the positive side
        self.abseta_test_points = self._edge_test_points(
            self.etas, 0., self.ETA_MAX)
        self.pt_test_points = self._edge_test_points(
            self.pts, self.PT_MIN, self.PT_MAX)
        self.discr_test_points = self._edge_test_points(
            self.discrs, self.DISCR_MIN, self.DISCR_MAX)

    @staticmethod
    def _edge_test_points(ranges, lower, upper, eps=1e-4):
        """Return the points just beside every range edge, as a set.

        Candidates are ``edge + eps`` and ``edge - eps`` for both edges of
        every ``(low, high)`` pair in ``ranges``, plus ``lower + eps`` and
        ``upper - eps``. Only candidates strictly between ``lower`` and
        ``upper`` are kept; the survivors are rounded to 5 decimals.

        A generator expression does the filtering so the code does not rely
        on ``itertools.ifilter``, which is not available on Python 3.
        """
        candidates = itertools.chain(
            (low + eps for low, _ in ranges),
            (low - eps for low, _ in ranges),
            (high + eps for _, high in ranges),
            (high - eps for _, high in ranges),
            (lower + eps, upper - eps),
        )
        return set(round(x, 5) for x in candidates if lower < x < upper)

    def print_data(self):
        """Print the collected values, ranges, bounds and test points.

        Reads attributes that exist only when ``entries`` is non-empty.
        """
        print("\nFound operating points:")
        print(self.ops)

        print("\nFound jet flavors:")
        print(self.flavs)

        print("\nFound sys types (need at least 'central', 'up', 'down'; "
              "also 'up_SYS'/'down_SYS' compatibility is checked):")
        print(self.syss)

        print("\nFound eta ranges: (need everything covered from %g or 0. "
              "up to %g):" % (self.ETA_MIN, self.ETA_MAX))
        print(self.etas)

        print("\nFound pt ranges: (need everything covered from %g "
              "to %g):" % (self.PT_MIN, self.PT_MAX))
        print(self.pts)

        print("\nFound discr ranges: (only needed for operatingPoint==3, "
              "covered from %g to %g):" % (self.DISCR_MIN, self.DISCR_MAX))
        print(self.discrs)

        print("\nTest points for eta (bounds +- epsilon):")
        print(self.eta_test_points)

        print("\nTest points for pt (bounds +- epsilon):")
        print(self.pt_test_points)

        print("\nTest points for discr (bounds +- epsilon):")
        print(self.discr_test_points)
        print("")
0155
0156
def get_data_csv(csv_data):
    """Build one DataLoader per measurement type found in ``csv_data``.

    When ``separate_by_op`` / ``separate_by_flav`` are set, one loader is
    built per (measurement type, operating point, flavour) combination;
    otherwise the placeholder ``'all'`` is passed for that axis.

    Args:
        csv_data: list of CSV text lines without the header line. It is
            iterated several times, so it must not be a one-shot iterator.

    Returns:
        List of DataLoader objects that hold at least one entry.
    """
    # A data row is a line with exactly 11 comma-separated fields. The count
    # is taken on the comma split because that is also how the individual
    # fields are addressed below; counting whitespace-separated tokens would
    # drop rows written without spaces after the commas.
    rows = []
    for l in csv_data:
        fields = l.split(',')
        if len(fields) == 11:
            rows.append(fields)

    # grab measurement types (field 1)
    meas_types = set(fields[1].strip() for fields in rows)

    # grab operating points (field 0)
    ops = set(
        int(fields[0]) for fields in rows
    ) if separate_by_op else ['all']

    # grab flavors (field 3)
    flavs = set(
        int(fields[3]) for fields in rows
    ) if separate_by_flav else ['all']

    # make loaders and keep only the ones that found entries
    lds = [
        DataLoader(csv_data, mt, op, fl)
        for mt in meas_types
        for op in ops
        for fl in flavs
    ]
    return [d for d in lds if d.entries]
0188
0189
def get_data(filename):
    """Read a calibration CSV file and return its DataLoader objects.

    Args:
        filename: path of the CSV file to read.

    Returns:
        The result of ``get_data_csv`` on the lines after the header, or
        ``False`` when the file is empty or its first line does not contain
        the text ``OperatingPoint``.
    """
    with open(filename) as f:
        csv_data = f.readlines()

    # the first line must look like the usual column header
    if not (csv_data and "OperatingPoint" in csv_data[0]):
        print("Data file does not contain typical header: %s. Exit" % filename)
        return False

    csv_data.pop(0)  # remove header
    return get_data_csv(csv_data)