Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-11-25 02:29:23

0001 #!/usr/bin/env python
0002 
0003 import itertools
0004 import unittest
0005 import sys
0006 from . import dataLoader
0007 import ROOT
0008 
0009 
# Module-level state read by BtagCalibConsistencyChecker.  run_check_data()
# rebinds ``data`` and the three ``check_*`` switches before every run;
# the __main__ block turns ``verbose`` on.
data = None  # data loader currently under test
# Switches for the optional flavour / operating-point / systematics checks.
check_flavor = check_op = check_sys = True
verbose = False  # print progress messages and dump the data before checking
0015 
0016 
0017 def _eta_pt_discr_entries_generator(filter_keyfunc, op):
0018     assert data
0019     entries = list(filter(filter_keyfunc, data.entries))
0020 
0021     # use full or half eta range?
0022     if any(e.params.etaMin < 0. for e in entries):
0023         eta_test_points = data.eta_test_points
0024     else:
0025         eta_test_points = data.abseta_test_points
0026 
0027     for eta in eta_test_points:
0028         for pt in data.pt_test_points:
0029             ens_pt_eta = [e for e in entries if e.params.etaMin < eta < e.params.etaMax and
0030                 e.params.ptMin < pt < e.params.ptMax]
0031             if op == 3:
0032                 for discr in data.discr_test_points:
0033                     ens_pt_eta_discr = [e for e in ens_pt_eta if e.params.discrMin < discr < e.params.discrMax]
0034                     yield eta, pt, discr, ens_pt_eta_discr
0035             else:
0036                 yield eta, pt, None, ens_pt_eta
0037 
0038 
0039 class BtagCalibConsistencyChecker(unittest.TestCase):
0040     def test_lowercase(self):
0041         for item in [data.meas_type] + list(data.syss):
0042             self.assertEqual(
0043                 item, item.lower(),
0044                 "Item is not lowercase: %s" % item
0045             )
0046 
0047     def test_ops_tight(self):
0048         if check_op:
0049             self.assertIn(2, data.ops, "OP_TIGHT is missing")
0050 
0051     def test_ops_medium(self):
0052         if check_op:
0053             self.assertIn(1, data.ops, "OP_MEDIUM is missing")
0054 
0055     def test_ops_loose(self):
0056         if check_op:
0057             self.assertIn(0, data.ops, "OP_LOOSE is missing")
0058 
0059     def test_flavs_b(self):
0060         if check_flavor:
0061             self.assertIn(0, data.flavs, "FLAV_B is missing")
0062 
0063     def test_flavs_c(self):
0064         if check_flavor:
0065             self.assertIn(1, data.flavs, "FLAV_C is missing")
0066 
0067     def test_flavs_udsg(self):
0068         if check_flavor:
0069             self.assertIn(2, data.flavs, "FLAV_UDSG is missing")
0070 
0071     def test_systematics_central(self):
0072         if check_sys:
0073             self.assertIn("central", data.syss,
0074                           "'central' sys. uncert. is missing")
0075 
0076     def test_systematics_up(self):
0077         if check_sys:
0078             self.assertIn("up", data.syss, "'up' sys. uncert. is missing")
0079 
0080     def test_systematics_down(self):
0081         if check_sys:
0082             self.assertIn("down", data.syss, "'down' sys. uncert. is missing")
0083 
0084     def test_systematics_name(self):
0085         if check_sys:
0086             for syst in data.syss:
0087                 if syst == 'central':
0088                     continue
0089                 self.assertTrue(
0090                     syst.startswith("up") or syst.startswith("down"),
0091                     "sys. uncert name must start with 'up' or 'down' : %s"
0092                     % syst
0093                 )
0094 
0095     def test_systematics_doublesidedness(self):
0096         if check_sys:
0097             for syst in data.syss:
0098                 if "up" in syst:
0099                     other = syst.replace("up", "down")
0100                     self.assertIn(other, data.syss,
0101                                   "'%s' sys. uncert. is missing" % other)
0102                 elif "down" in syst:
0103                     other = syst.replace("down", "up")
0104                     self.assertIn(other, data.syss,
0105                                   "'%s' sys. uncert. is missing" % other)
0106 
0107     def test_systematics_values_vs_centrals(self):
0108         if check_sys:
0109             res = list(itertools.chain.from_iterable(
0110                 self._check_sys_side(op, flav)
0111                 for flav in data.flavs
0112                 for op in data.ops
0113             ))
0114             self.assertFalse(bool(res), "\n"+"\n".join(res))
0115 
0116     def _check_sys_side(self, op, flav):
0117         region = "op=%d, flav=%d" % (op, flav)
0118         if verbose:
0119             print("Checking sys side correctness for", region)
0120 
0121         res = []
0122         for eta, pt, discr, entries in _eta_pt_discr_entries_generator(
0123             lambda e:
0124             e.params.operatingPoint == op and
0125             e.params.jetFlavor == flav,
0126             op
0127         ):
0128             if not entries:
0129                 continue
0130 
0131             for e in entries:  # do a little monkey patching with tf1's
0132                 if not hasattr(e, 'tf1_func'):
0133                     e.tf1_func = ROOT.TF1("", e.formula)
0134 
0135             sys_dict = dict((e.params.sysType, e) for e in entries)
0136             assert len(sys_dict) == len(entries)
0137             sys_cent = sys_dict.pop('central', None)
0138             x = discr if op == 3 else pt
0139             for syst, e in sys_dict.items():
0140                 sys_val = e.tf1_func.Eval(x)
0141                 cent_val = sys_cent.tf1_func.Eval(x)
0142                 if syst.startswith('up') and not sys_val > cent_val:
0143                     res.append(
0144                         ("Up variation '%s' not larger than 'central': %s "
0145                          "eta=%f, pt=%f " % (syst, region, eta, pt))
0146                         + ((", discr=%f" % discr) if discr else "")
0147                     )
0148                 elif syst.startswith('down') and not sys_val < cent_val:
0149                     res.append(
0150                         ("Down variation '%s' not smaller than 'central': %s "
0151                          "eta=%f, pt=%f " % (syst, region, eta, pt))
0152                         + ((", discr=%f" % discr) if discr else "")
0153                     )
0154         return res
0155 
0156     def test_eta_ranges(self):
0157         for a, b in data.etas:
0158             self.assertLess(a, b)
0159             self.assertGreater(a, data.ETA_MIN - 1e-7)
0160             self.assertLess(b, data.ETA_MAX + 1e-7)
0161 
0162     def test_pt_ranges(self):
0163         for a, b in data.pts:
0164             self.assertLess(a, b)
0165             self.assertGreater(a, data.PT_MIN - 1e-7)
0166             self.assertLess(b, data.PT_MAX + 1e-7)
0167 
0168     def test_discr_ranges(self):
0169         for a, b in data.discrs:
0170             self.assertLess(a, b)
0171             self.assertGreater(a, data.DISCR_MIN - 1e-7)
0172             self.assertLess(b, data.DISCR_MAX + 1e-7)
0173 
0174     def test_coverage(self):
0175         res = list(itertools.chain.from_iterable(
0176             self._check_coverage(op, syst, flav)
0177             for flav in data.flavs
0178             for syst in data.syss
0179             for op in data.ops
0180         ))
0181         self.assertFalse(bool(res), "\n"+"\n".join(res))
0182 
0183     def _check_coverage(self, op, syst, flav):
0184         region = "op=%d, %s, flav=%d" % (op, syst, flav)
0185         if verbose:
0186             print("Checking coverage for", region)
0187 
0188         # walk over all testpoints
0189         res = []
0190         for eta, pt, discr, entries in _eta_pt_discr_entries_generator(
0191             lambda e:
0192             e.params.operatingPoint == op and
0193             e.params.sysType == syst and
0194             e.params.jetFlavor == flav,
0195             op
0196         ):
0197             size = len(entries)
0198             if size == 0:
0199                 res.append(
0200                     ("Region not covered: %s eta=%f, pt=%f "
0201                      % (region, eta, pt))
0202                     + ((", discr=%f" % discr) if discr else "")
0203                 )
0204             elif size > 1:
0205                 res.append(
0206                     ("Region covered %d times: %s eta=%f, pt=%f"
0207                      % (size, region, eta, pt))
0208                     + ((", discr=%f" % discr) if discr else "")
0209                 )
0210         return res
0211 
0212 
def run_check(filename, op=True, sys=True, flavor=True):
    """Load the calibration file *filename* and run the consistency checks.

    The flags are forwarded unchanged to run_check_data(), whose per-loader
    pass/fail list is returned.
    """
    return run_check_data(dataLoader.get_data(filename), op, sys, flavor)
0216 
0217 
def run_check_csv(csv_data, op=True, sys=True, flavor=True):
    """Run the consistency checks on in-memory CSV data *csv_data*.

    The flags are forwarded unchanged to run_check_data(), whose per-loader
    pass/fail list is returned.
    """
    return run_check_data(dataLoader.get_data_csv(csv_data), op, sys, flavor)
0221 
0222 
0223 def run_check_data(data_loaders,
0224                    op=True, sys=True, flavor=True):
0225     global data, check_op, check_sys, check_flavor
0226     check_op, check_sys, check_flavor = op, sys, flavor
0227 
0228     all_res = []
0229     for dat in data_loaders:
0230         data = dat
0231         print('\n\n')
0232         print('# Checking csv data for type / op / flavour:', \
0233             data.meas_type, data.op, data.flav)
0234         print('='*60 + '\n')
0235         if verbose:
0236             data.print_data()
0237         testsuite = unittest.TestLoader().loadTestsFromTestCase(
0238             BtagCalibConsistencyChecker)
0239         res = unittest.TextTestRunner().run(testsuite)
0240         all_res.append(not bool(res.failures))
0241     return all_res
0242 
0243 
if __name__ == '__main__':
    # Command-line entry point: check the CSV file given as first argument
    # and exit with status -1 (255) on any failed check.
    if len(sys.argv) < 2:
        print('Need csv data file as first argument.')
        print('Options:')
        print('    --light (do not check op, sys, flav)')
        print('    --separate-by-op')
        print('    --separate-by-flav')
        print('    --separate-all (both of the above)')
        print('Exit.')
        # sys.exit rather than the site-module exit(), which is meant for
        # interactive use and is absent e.g. under "python -S".
        sys.exit(-1)

    ck_op = ck_sy = ck_fl = '--light' not in sys.argv

    dataLoader.separate_by_op   = '--separate-by-op'   in sys.argv
    dataLoader.separate_by_flav = '--separate-by-flav' in sys.argv

    if '--separate-all' in sys.argv:
        dataLoader.separate_by_op = dataLoader.separate_by_flav = True

    # When the data are split per op / flavour, the "all ops / all flavours
    # present" checks are switched off (presumably because each separated
    # chunk only holds one of them).
    if dataLoader.separate_by_op:
        ck_op = False
    if dataLoader.separate_by_flav:
        ck_fl = False

    verbose = True
    if not all(run_check(sys.argv[1], ck_op, ck_sy, ck_fl)):
        sys.exit(-1)
0271