File indexing completed on 2023-03-17 10:47:44
from __future__ import print_function
import itertools
import sys
import ROOT
# Make ROOT.BTagEntry available: if it is not already loaded, ask ROOT to load
# the standalone calibration source from the working directory (".L ...+").
try:
    ROOT.BTagEntry
except AttributeError:
    ROOT.gROOT.ProcessLine('.L BTagCalibrationStandalone.cpp+')

# Re-check after the load attempt; without BTagEntry nothing below can work.
try:
    ROOT.BTagEntry
except AttributeError:
    # Report on stderr and use sys.exit: the bare exit() builtin is injected by
    # the site module for interactive use and is not guaranteed in scripts.
    print('ROOT.BTagEntry is needed! Please copy '
          'BTagCalibrationStandalone.[h|cpp] to the working directory. Exit.',
          file=sys.stderr)
    sys.exit(-1)

# Module-level switches. When False, DataLoader does not filter entries by
# operating point / jet flavour, and get_data_csv groups everything under a
# single 'all' value for that dimension.
separate_by_op = False
separate_by_flav = False
0018
0019
class DataLoader(object):
    """Collect and summarise the BTagEntry rows of one csv measurement group.

    An instance keeps the entries whose measurement type equals
    ``measurement_type``. When the module switches ``separate_by_op`` /
    ``separate_by_flav`` are true, the entries must also match
    ``operating_point`` / ``flavour``. From the kept entries it derives the
    observed parameter ranges and sets of test points placed just inside and
    outside every range boundary.
    """

    def __init__(self, csv_data, measurement_type, operating_point, flavour):
        """Parse ``csv_data`` and compute ranges and test points.

        :param csv_data: iterable of csv lines (header already removed);
            blank lines are skipped.
        :param measurement_type: compared with ``params.measurementType``.
        :param operating_point: compared with ``params.operatingPoint`` only
            when ``separate_by_op`` is true.
        :param flavour: compared with ``params.jetFlavor`` only when
            ``separate_by_flav`` is true.
        :raises RuntimeError: if building or inspecting the BTagEntry for a
            line raises TypeError.

        If no entry matches, only ``meas_type``, ``op``, ``flav`` and the
        empty ``entries`` list are set; the range and test-point attributes
        are left undefined.
        """
        self.meas_type = measurement_type
        self.op = operating_point
        self.flav = flavour

        ens = []
        for l in csv_data:
            if not l.strip():
                continue
            try:
                e = ROOT.BTagEntry(l)
                if (e.params.measurementType == measurement_type
                        and ((not separate_by_op)
                             or e.params.operatingPoint == operating_point)
                        and ((not separate_by_flav)
                             or e.params.jetFlavor == flavour)):
                    ens.append(e)
            except TypeError:
                raise RuntimeError("Error: can not interpret line: " + l)
        self.entries = ens

        if not ens:
            return

        self.ops = set(e.params.operatingPoint for e in ens)
        self.flavs = set(e.params.jetFlavor for e in ens)
        self.syss = set(e.params.sysType for e in ens)
        self.etas = set((e.params.etaMin, e.params.etaMax) for e in ens)
        self.pts = set((e.params.ptMin, e.params.ptMax) for e in ens)

        # Discriminator ranges are only relevant for operatingPoint == 3
        # entries; select them once and reuse the selection below.
        reshaping = [e for e in ens if e.params.operatingPoint == 3]
        self.discrs = set((e.params.discrMin, e.params.discrMax)
                          for e in reshaping)

        self.ETA_MIN = -2.4
        self.ETA_MAX = 2.4
        self.PT_MIN = min(e.params.ptMin for e in ens)
        self.PT_MAX = max(e.params.ptMax for e in ens)
        if reshaping:
            self.DISCR_MIN = min(e.params.discrMin for e in reshaping)
            self.DISCR_MAX = max(e.params.discrMax for e in reshaping)
        else:
            self.DISCR_MIN = 0.
            self.DISCR_MAX = 1.

        eps = 1e-4
        self.eta_test_points = self._bound_test_points(
            self.etas, self.ETA_MIN, self.ETA_MAX, eps)
        # |eta| variant: same bin edges, but the valid interval starts at 0.
        self.abseta_test_points = self._bound_test_points(
            self.etas, 0., self.ETA_MAX, eps)
        self.pt_test_points = self._bound_test_points(
            self.pts, self.PT_MIN, self.PT_MAX, eps)
        self.discr_test_points = self._bound_test_points(
            self.discrs, self.DISCR_MIN, self.DISCR_MAX, eps)

    @staticmethod
    def _bound_test_points(ranges, lower, upper, eps):
        """Return the set of test points around the edges of ``ranges``.

        For every ``(a, b)`` in ``ranges`` the values ``a - eps``,
        ``a + eps``, ``b - eps`` and ``b + eps`` are candidates, as are
        ``lower + eps`` and ``upper - eps``. Only candidates strictly between
        ``lower`` and ``upper`` are kept, each rounded to 5 decimals.

        A generator-expression filter replaces ``itertools.ifilter``, which
        does not exist on Python 3.
        """
        candidates = itertools.chain(
            (a + eps for a, _ in ranges),
            (a - eps for a, _ in ranges),
            (b + eps for _, b in ranges),
            (b - eps for _, b in ranges),
            (lower + eps, upper - eps),
        )
        return set(round(x, 5) for x in candidates if lower < x < upper)

    def print_data(self):
        """Print the discovered ranges and test points to stdout.

        Relies on the summary attributes set by ``__init__``, so it is only
        meaningful for a loader with at least one entry.
        """
        print("\nFound operating points:")
        print(self.ops)

        print("\nFound jet flavors:")
        print(self.flavs)

        print("\nFound sys types (need at least 'central', 'up', 'down'; "
              "also 'up_SYS'/'down_SYS' compatibility is checked):")
        print(self.syss)

        print("\nFound eta ranges: (need everything covered from %g or 0. "
              "up to %g):" % (self.ETA_MIN, self.ETA_MAX))
        print(self.etas)

        print("\nFound pt ranges: (need everything covered from %g "
              "to %g):" % (self.PT_MIN, self.PT_MAX))
        print(self.pts)

        print("\nFound discr ranges: (only needed for operatingPoint==3, "
              "covered from %g to %g):" % (self.DISCR_MIN, self.DISCR_MAX))
        print(self.discrs)

        print("\nTest points for eta (bounds +- epsilon):")
        print(self.eta_test_points)

        print("\nTest points for pt (bounds +- epsilon):")
        print(self.pt_test_points)

        print("\nTest points for discr (bounds +- epsilon):")
        print(self.discr_test_points)
        print("")
0156
0157
def get_data_csv(csv_data):
    """Build the non-empty DataLoaders for every measurement group in the csv.

    :param csv_data: iterable of csv lines without the header line.
    :returns: list of DataLoader instances, one per combination of
        measurement type, operating point and jet flavour (the latter two
        collapse to ``'all'`` unless ``separate_by_op`` / ``separate_by_flav``
        are set). Loaders that matched no entries are dropped.
    """
    # The lines are scanned several times here and again by every DataLoader,
    # so materialise once; a one-shot iterator would otherwise be exhausted.
    csv_data = list(csv_data)

    # Count fields with the same comma split that is used for indexing.
    # Counting whitespace tokens silently dropped valid lines written without
    # a space after each comma, or with spaces inside the formula.
    rows = [l.split(',') for l in csv_data]
    rows = [r for r in rows if len(r) == 11]

    meas_types = set(r[1].strip() for r in rows)

    ops = set(int(r[0]) for r in rows) if separate_by_op else ['all']

    flavs = set(int(r[3]) for r in rows) if separate_by_flav else ['all']

    lds = [
        DataLoader(csv_data, mt, op, fl)
        for mt in meas_types
        for op in ops
        for fl in flavs
    ]
    return [d for d in lds if d.entries]
0189
0190
def get_data(filename):
    """Load a calibration csv file and return its DataLoaders.

    The first line must contain "OperatingPoint". If the file is empty or
    that header is missing, a message is printed and False is returned.
    Otherwise the header line is removed and the rest is passed to
    get_data_csv.
    """
    with open(filename) as handle:
        lines = handle.readlines()

    has_header = bool(lines) and "OperatingPoint" in lines[0]
    if not has_header:
        print("Data file does not contain typical header: %s. Exit" % filename)
        return False

    del lines[0]
    return get_data_csv(lines)