Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:24:07

0001 #!/usr/bin/env python3
0002 # from https://github.com/nsmith-/correctionlib/blob/master/tests/test_core.py
0003 
0004 import json
0005 import math
0006 
0007 import pytest
0008 
0009 import correctionlib._core as core
0010 from correctionlib import schemav1
0011 
0012 
def test_evaluator_v1():
    """Exercise loading and evaluating schema-v1 corrections through core.

    Covers three areas:
      * malformed or incomplete JSON is rejected by CorrectionSet.from_string;
      * a correction with no inputs whose data is a bare constant;
      * a binning node on the real input "pt" whose bins hold category nodes
        keyed on the string input "syst", with constant and TFormula leaves,
        including wrong-arity and wrong-type evaluate() calls.
    """
    # Invalid documents must be rejected at load time. The would-be results
    # are never used, so they are not bound to a name.
    with pytest.raises(RuntimeError):
        core.CorrectionSet.from_string("{")

    with pytest.raises(RuntimeError):
        core.CorrectionSet.from_string("{}")

    with pytest.raises(RuntimeError):
        core.CorrectionSet.from_string('{"schema_version": "blah"}')

    def wrap(*corrs):
        """Serialize the given schemav1 corrections and load them with core."""
        cset = schemav1.CorrectionSet(
            schema_version=1,
            corrections=list(corrs),
        )
        return core.CorrectionSet.from_string(cset.json())

    # Constant correction: no inputs, data is a plain number.
    cset = wrap(
        schemav1.Correction(
            name="test corr",
            version=2,
            inputs=[],
            output=schemav1.Variable(name="a scale", type="real"),
            data=1.234,
        )
    )
    assert set(cset) == {"test corr"}
    sf = cset["test corr"]
    assert sf.version == 2
    assert sf.description == ""

    with pytest.raises(RuntimeError):
        # too many inputs: this correction declares none
        sf.evaluate(0, 1.2, 35.0, 0.01)

    assert sf.evaluate() == 1.234

    # Binning on "pt" with edges [0, 20, 40]; each bin is a category on "syst".
    cset = wrap(
        schemav1.Correction(
            name="test corr",
            version=2,
            inputs=[
                schemav1.Variable(name="pt", type="real"),
                schemav1.Variable(name="syst", type="string"),
            ],
            output=schemav1.Variable(name="a scale", type="real"),
            data=schemav1.Binning.parse_obj(
                {
                    "nodetype": "binning",
                    "edges": [0, 20, 40],
                    "content": [
                        {
                            "nodetype": "category",
                            "keys": ["blah", "blah2"],
                            "content": [1.1, 2.2],
                        },
                        {
                            "nodetype": "category",
                            "keys": ["blah2", "blah3"],
                            "content": [
                                1.3,
                                {
                                    "expression": "0.25*x + exp(3.1)",
                                    "parser": "TFormula",
                                    "parameters": [0],
                                },
                            ],
                        },
                    ],
                }
            ),
        )
    )
    assert set(cset) == {"test corr"}
    sf = cset["test corr"]
    assert sf.version == 2
    assert sf.description == ""

    with pytest.raises(RuntimeError):
        # too many inputs
        sf.evaluate(0, 1.2, 35.0, 0.01)

    with pytest.raises(RuntimeError):
        # not enough inputs
        sf.evaluate(1.2)

    # The wrong-type checks pass the correct number of arguments, so a
    # failure can only come from the mistyped argument, not from the count.
    with pytest.raises(RuntimeError):
        # wrong type: string given for the real input "pt"
        sf.evaluate("asdf", "blah")

    with pytest.raises(RuntimeError):
        # wrong type: int given for the string input "syst"
        sf.evaluate(12.0, 5)

    # First bin (0 to 20): constant leaves.
    assert sf.evaluate(12.0, "blah") == 1.1
    assert sf.evaluate(12.0, "blah2") == 2.2
    # Second bin (20 to 40): constant leaf for "blah2", TFormula for "blah3"
    # (parameters [0] feeds input 0, "pt", to the formula variable x).
    assert sf.evaluate(31.0, "blah2") == 1.3
    # The C++ formula evaluator and Python's math module are not guaranteed
    # to round identically, so compare approximately.
    assert sf.evaluate(31.0, "blah3") == pytest.approx(0.25 * 31.0 + math.exp(3.1))
0109 
0110 
def test_tformula():
    """Check TFormula expressions evaluated by core against Python references.

    A single correction uses a category node on the int input "index" to pick
    one formula from ``formulas``; every formula reads the real input "x"
    (``"parameters": [1]`` maps formula variable x to input index 1, as the
    reference lambdas below assume).
    """
    # (TFormula expression, equivalent Python reference implementation)
    formulas = [
        ("23.*x", lambda x: 23.0 * x),
        ("23.*log(max(x, 0.1))", lambda x: 23.0 * math.log(max(x, 0.1))),
    ]
    cset = {
        "schema_version": 1,
        "corrections": [
            {
                "name": "test",
                "version": 1,
                "inputs": [
                    {"name": "index", "type": "int"},
                    {"name": "x", "type": "real"},
                ],
                "output": {"name": "f", "type": "real"},
                "data": {
                    "nodetype": "category",
                    "keys": list(range(len(formulas))),
                    "content": [
                        {"expression": expr, "parser": "TFormula", "parameters": [1]}
                        for expr, _ in formulas
                    ],
                },
            }
        ],
    }
    # Validate the raw dict against the schemav1 model; the parsed object is
    # not needed, only the check.
    schemav1.CorrectionSet.parse_obj(cset)
    corr = core.CorrectionSet.from_string(json.dumps(cset))["test"]
    test_values = [1.0, 32.0, -3.0, 1550.0]
    for i, (_, expected) in enumerate(formulas):
        for x in test_values:
            # The C++ formula evaluator and Python's math module are not
            # guaranteed to round identically, so compare approximately.
            assert corr.evaluate(i, x) == pytest.approx(expected(x))