Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:02:43

0001 #!/usr/bin/env python
0002 
0003 from __future__ import print_function
0004 from __future__ import absolute_import
0005 import itertools
0006 import unittest
0007 import sys
0008 from . import dataLoader
0009 import ROOT
0010 
0011 
# Module-level state read by the checks in BtagCalibConsistencyChecker.
# run_check_data() rebinds these before each suite run, and the __main__
# block switches ``verbose`` on.
data = None  # the data-loader object currently being checked
check_flavor = check_op = check_sys = True  # which completeness checks run
verbose = False  # print progress / loader contents while checking
0017 
0018 
def _eta_pt_discr_entries_generator(filter_keyfunc, op):
    """Yield ``(eta, pt, discr, entries)`` for every test point of ``data``.

    ``data.entries`` is first reduced with ``filter_keyfunc``.  If any of the
    remaining entries has a negative ``etaMin``, the signed
    ``data.eta_test_points`` are used, otherwise ``data.abseta_test_points``.
    For each eta/pt test point, ``entries`` lists the filtered entries whose
    open (strict) eta and pt ranges contain the point.

    For ``op == 3`` the points are additionally scanned over
    ``data.discr_test_points`` and ``entries`` is further narrowed to the
    open discr range; for any other ``op``, ``discr`` is ``None``.
    """
    assert data
    selected = list(filter(filter_keyfunc, data.entries))

    has_negative_eta = any(entry.params.etaMin < 0. for entry in selected)
    eta_points = (data.eta_test_points if has_negative_eta
                  else data.abseta_test_points)

    for eta in eta_points:
        for pt in data.pt_test_points:
            in_eta_pt = [
                entry for entry in selected
                if entry.params.etaMin < eta < entry.params.etaMax
                and entry.params.ptMin < pt < entry.params.ptMax
            ]
            if op != 3:
                yield eta, pt, None, in_eta_pt
                continue
            for discr in data.discr_test_points:
                in_discr = [
                    entry for entry in in_eta_pt
                    if entry.params.discrMin < discr < entry.params.discrMax
                ]
                yield eta, pt, discr, in_discr
0039 
0040 
class BtagCalibConsistencyChecker(unittest.TestCase):
    """Consistency checks for one set of b-tag calibration entries.

    The checks do not use instance state: they read the module-level
    ``data`` loader and the ``check_op`` / ``check_flavor`` / ``check_sys`` /
    ``verbose`` switches, which run_check_data() rebinds before running the
    suite for each loader.

    Integer codes, as named in the assertion messages below: operating
    points 0/1/2 are OP_LOOSE/OP_MEDIUM/OP_TIGHT; flavours 0/1/2 are
    FLAV_B/FLAV_C/FLAV_UDSG.  For operating point 3 the entries are binned
    in, and evaluated as a function of, the discriminator value instead of
    pt (presumably the reshaping OP -- confirm against dataLoader).
    """

    @staticmethod
    def _discr_suffix(discr):
        """Return the ``", discr=..."`` message suffix, or '' if no discr.

        ``discr`` is ``None`` for operating points that are not binned in
        the discriminator.  Compare against ``None`` explicitly so that a
        test point at ``discr == 0.0`` is still reported.
        """
        return (", discr=%f" % discr) if discr is not None else ""

    def test_lowercase(self):
        """The measurement type and every systematic name are lowercase."""
        for item in [data.meas_type] + list(data.syss):
            self.assertEqual(
                item, item.lower(),
                "Item is not lowercase: %s" % item
            )

    def test_ops_tight(self):
        """OP_TIGHT (2) is present, unless op checks are disabled."""
        if check_op:
            self.assertIn(2, data.ops, "OP_TIGHT is missing")

    def test_ops_medium(self):
        """OP_MEDIUM (1) is present, unless op checks are disabled."""
        if check_op:
            self.assertIn(1, data.ops, "OP_MEDIUM is missing")

    def test_ops_loose(self):
        """OP_LOOSE (0) is present, unless op checks are disabled."""
        if check_op:
            self.assertIn(0, data.ops, "OP_LOOSE is missing")

    def test_flavs_b(self):
        """FLAV_B (0) is present, unless flavour checks are disabled."""
        if check_flavor:
            self.assertIn(0, data.flavs, "FLAV_B is missing")

    def test_flavs_c(self):
        """FLAV_C (1) is present, unless flavour checks are disabled."""
        if check_flavor:
            self.assertIn(1, data.flavs, "FLAV_C is missing")

    def test_flavs_udsg(self):
        """FLAV_UDSG (2) is present, unless flavour checks are disabled."""
        if check_flavor:
            self.assertIn(2, data.flavs, "FLAV_UDSG is missing")

    def test_systematics_central(self):
        """A 'central' systematic is present."""
        if check_sys:
            self.assertIn("central", data.syss,
                          "'central' sys. uncert. is missing")

    def test_systematics_up(self):
        """An 'up' systematic is present."""
        if check_sys:
            self.assertIn("up", data.syss, "'up' sys. uncert. is missing")

    def test_systematics_down(self):
        """A 'down' systematic is present."""
        if check_sys:
            self.assertIn("down", data.syss, "'down' sys. uncert. is missing")

    def test_systematics_name(self):
        """Every non-central systematic name starts with 'up' or 'down'."""
        if check_sys:
            for syst in data.syss:
                if syst == 'central':
                    continue
                self.assertTrue(
                    syst.startswith(("up", "down")),
                    "sys. uncert name must start with 'up' or 'down' : %s"
                    % syst
                )

    def test_systematics_doublesidedness(self):
        """Every 'up*' systematic has a 'down*' partner and vice versa.

        Only the leading 'up'/'down' prefix is swapped.  A plain
        ``str.replace`` would also rewrite later occurrences, e.g.
        'down_pileup' -> 'down_piledown', and report a spurious missing
        partner.  Names without such a prefix are rejected by
        test_systematics_name, so they are skipped here.
        """
        if check_sys:
            for syst in data.syss:
                if syst.startswith("up"):
                    other = "down" + syst[len("up"):]
                elif syst.startswith("down"):
                    other = "up" + syst[len("down"):]
                else:
                    continue
                self.assertIn(other, data.syss,
                              "'%s' sys. uncert. is missing" % other)

    def test_systematics_values_vs_centrals(self):
        """Up/down variations lie above/below 'central' at every point."""
        if check_sys:
            res = list(itertools.chain.from_iterable(
                self._check_sys_side(op, flav)
                for flav in data.flavs
                for op in data.ops
            ))
            self.assertFalse(bool(res), "\n"+"\n".join(res))

    def _check_sys_side(self, op, flav):
        """Return messages for variations on the wrong side of 'central'.

        For every test point of (op, flav), each entry's formula is
        evaluated with ROOT.TF1 at x = discr (op 3) or x = pt.  An 'up*'
        variation must be strictly larger than 'central', and a 'down*'
        variation strictly smaller.  A point with entries but no 'central'
        one is also reported.

        :returns: list of human-readable problem descriptions (empty if OK).
        """
        region = "op=%d, flav=%d" % (op, flav)
        if verbose:
            print("Checking sys side correctness for", region)

        res = []
        for eta, pt, discr, entries in _eta_pt_discr_entries_generator(
            lambda e:
            e.params.operatingPoint == op and
            e.params.jetFlavor == flav,
            op
        ):
            if not entries:
                continue

            # Compile each entry's formula once and cache it on the entry
            # itself (monkey patching), so it is reused across test points.
            for e in entries:
                if not hasattr(e, 'tf1_func'):
                    e.tf1_func = ROOT.TF1("", e.formula)

            sys_dict = dict((e.params.sysType, e) for e in entries)
            # Two entries with the same sysType at one point means the bins
            # overlap; test_coverage reports that per point.
            assert len(sys_dict) == len(entries)
            sys_cent = sys_dict.pop('central', None)
            if sys_cent is None:
                # Nothing to compare against: report it instead of crashing
                # with an AttributeError on None.
                res.append(
                    ("No 'central' entry to compare against: %s "
                     "eta=%f, pt=%f " % (region, eta, pt))
                    + self._discr_suffix(discr)
                )
                continue

            x = discr if op == 3 else pt
            cent_val = sys_cent.tf1_func.Eval(x)  # same for every variation
            for syst, e in sys_dict.items():
                sys_val = e.tf1_func.Eval(x)
                if syst.startswith('up') and not sys_val > cent_val:
                    res.append(
                        ("Up variation '%s' not larger than 'central': %s "
                         "eta=%f, pt=%f " % (syst, region, eta, pt))
                        + self._discr_suffix(discr)
                    )
                elif syst.startswith('down') and not sys_val < cent_val:
                    res.append(
                        ("Down variation '%s' not smaller than 'central': %s "
                         "eta=%f, pt=%f " % (syst, region, eta, pt))
                        + self._discr_suffix(discr)
                    )
        return res

    def test_eta_ranges(self):
        """Each eta bin (a, b) has a < b and lies in [ETA_MIN, ETA_MAX]."""
        for a, b in data.etas:
            self.assertLess(a, b)
            # 1e-7 tolerance for floating-point bin edges
            self.assertGreater(a, data.ETA_MIN - 1e-7)
            self.assertLess(b, data.ETA_MAX + 1e-7)

    def test_pt_ranges(self):
        """Each pt bin (a, b) has a < b and lies in [PT_MIN, PT_MAX]."""
        for a, b in data.pts:
            self.assertLess(a, b)
            self.assertGreater(a, data.PT_MIN - 1e-7)
            self.assertLess(b, data.PT_MAX + 1e-7)

    def test_discr_ranges(self):
        """Each discr bin (a, b) has a < b and lies in [DISCR_MIN, DISCR_MAX]."""
        for a, b in data.discrs:
            self.assertLess(a, b)
            self.assertGreater(a, data.DISCR_MIN - 1e-7)
            self.assertLess(b, data.DISCR_MAX + 1e-7)

    def test_coverage(self):
        """Every test point is covered exactly once per (op, sys, flav)."""
        res = list(itertools.chain.from_iterable(
            self._check_coverage(op, syst, flav)
            for flav in data.flavs
            for syst in data.syss
            for op in data.ops
        ))
        self.assertFalse(bool(res), "\n"+"\n".join(res))

    def _check_coverage(self, op, syst, flav):
        """Return messages for test points covered zero or several times.

        :returns: list of human-readable problem descriptions (empty if OK).
        """
        region = "op=%d, %s, flav=%d" % (op, syst, flav)
        if verbose:
            print("Checking coverage for", region)

        # walk over all testpoints
        res = []
        for eta, pt, discr, entries in _eta_pt_discr_entries_generator(
            lambda e:
            e.params.operatingPoint == op and
            e.params.sysType == syst and
            e.params.jetFlavor == flav,
            op
        ):
            size = len(entries)
            if size == 0:
                res.append(
                    ("Region not covered: %s eta=%f, pt=%f "
                     % (region, eta, pt))
                    + self._discr_suffix(discr)
                )
            elif size > 1:
                res.append(
                    ("Region covered %d times: %s eta=%f, pt=%f"
                     % (size, region, eta, pt))
                    + self._discr_suffix(discr)
                )
        return res
0213 
0214 
def run_check(filename, op=True, sys=True, flavor=True):
    """Run the consistency checks on the csv file ``filename``.

    The file is turned into data loaders with ``dataLoader.get_data``.
    ``op``, ``sys`` and ``flavor`` are forwarded to run_check_data(), whose
    per-loader result list is returned unchanged.
    """
    return run_check_data(dataLoader.get_data(filename), op, sys, flavor)
0218 
0219 
def run_check_csv(csv_data, op=True, sys=True, flavor=True):
    """Run the consistency checks on already-available ``csv_data``.

    Like run_check(), but builds the loaders with
    ``dataLoader.get_data_csv`` (presumably from csv content rather than a
    file path -- confirm in dataLoader).  Returns run_check_data()'s result.
    """
    return run_check_data(dataLoader.get_data_csv(csv_data), op, sys, flavor)
0223 
0224 
def run_check_data(data_loaders,
                   op=True, sys=True, flavor=True):
    """Run the BtagCalibConsistencyChecker suite once per data loader.

    The checker reads its input from module-level globals.  This function
    therefore sets the ``check_*`` switches once, then rebinds ``data`` to
    each loader in turn and runs a freshly loaded suite on it.

    :param data_loaders: iterable of loader objects, as returned by
        ``dataLoader.get_data`` / ``dataLoader.get_data_csv``.
    :param op: require the OP_LOOSE/OP_MEDIUM/OP_TIGHT operating points.
    :param sys: run the systematic-uncertainty checks.
    :param flavor: require the FLAV_B/FLAV_C/FLAV_UDSG flavours.
    :returns: list with one bool per loader, True if its whole suite passed.
    """
    global data, check_op, check_sys, check_flavor
    check_op, check_sys, check_flavor = op, sys, flavor

    all_res = []
    for dat in data_loaders:
        data = dat
        print('\n\n')
        print('# Checking csv data for type / op / flavour:',
              data.meas_type, data.op, data.flav)
        print('=' * 60 + '\n')
        if verbose:
            data.print_data()
        testsuite = unittest.TestLoader().loadTestsFromTestCase(
            BtagCalibConsistencyChecker)
        res = unittest.TextTestRunner().run(testsuite)
        # wasSuccessful() also counts errors (unexpected exceptions inside a
        # check).  Looking only at res.failures let a crashing check pass.
        all_res.append(res.wasSuccessful())
    return all_res
0244 
0245 
if __name__ == '__main__':
    # Command-line entry point: check the csv file named by the first
    # argument and exit with status -1 if any loader's checks fail.
    # NOTE(review): because of the relative ``from . import dataLoader``
    # import, this presumably has to be run as a module (python -m ...).
    if len(sys.argv) < 2:
        print('Need csv data file as first argument.')
        print('Options:')
        print('    --light (do not check op, sys, flav)')
        print('    --separate-by-op')
        print('    --separate-by-flav')
        print('    --separate-all (both of the above)')
        print('Exit.')
        # sys.exit rather than the site-provided exit(), which is meant for
        # the interactive shell and is missing under ``python -S``.
        sys.exit(-1)

    # --light turns off the op / sys / flavour completeness checks.
    ck_op = ck_sy = ck_fl = '--light' not in sys.argv

    dataLoader.separate_by_op = '--separate-by-op' in sys.argv
    dataLoader.separate_by_flav = '--separate-by-flav' in sys.argv
    if '--separate-all' in sys.argv:
        dataLoader.separate_by_op = dataLoader.separate_by_flav = True

    # When splitting by op / flavour, don't require every op / flavour to be
    # present (presumably each loader then holds only one of them).
    if dataLoader.separate_by_op:
        ck_op = False
    if dataLoader.separate_by_flav:
        ck_fl = False

    verbose = True  # rebinds the module-level flag read by the checks
    if not all(run_check(sys.argv[1], ck_op, ck_sy, ck_fl)):
        sys.exit(-1)
0273