File indexing completed on 2024-04-06 12:24:07
0001
0002
0003
0004 import json
0005 import math
0006
0007 import pytest
0008
0009 import correctionlib._core as core
0010 from correctionlib import schemav1
0011
0012
def test_evaluator_v1():
    """Exercise the schema-v1 evaluator: load errors, metadata and evaluation.

    Covers three areas:

    * malformed JSON, an empty object, and a bad ``schema_version`` must all be
      rejected by ``core.CorrectionSet.from_string`` with ``RuntimeError``;
    * an input-less correction whose data is a constant exposes its name,
      version and an empty description, rejects extra arguments, and returns
      the constant;
    * a binning -> category -> {constant | TFormula} tree rejects calls with
      the wrong arguments and evaluates both a constant and a formula leaf.
    """
    # Each of these must fail to load; the return value is never used, so it
    # is not bound to a name.
    with pytest.raises(RuntimeError):
        core.CorrectionSet.from_string("{")

    with pytest.raises(RuntimeError):
        core.CorrectionSet.from_string("{}")

    with pytest.raises(RuntimeError):
        core.CorrectionSet.from_string('{"schema_version": "blah"}')

    def wrap(*corrs):
        """Serialize ``corrs`` as a schema-v1 CorrectionSet and load it with ``core``."""
        cset = schemav1.CorrectionSet(
            schema_version=1,
            corrections=list(corrs),
        )
        return core.CorrectionSet.from_string(cset.json())

    # Simplest case: no inputs, data is a bare constant.
    cset = wrap(
        schemav1.Correction(
            name="test corr",
            version=2,
            inputs=[],
            output=schemav1.Variable(name="a scale", type="real"),
            data=1.234,
        )
    )
    assert set(cset) == {"test corr"}
    sf = cset["test corr"]
    assert sf.version == 2
    assert sf.description == ""

    # Supplying arguments to an input-less correction is an error.
    with pytest.raises(RuntimeError):
        sf.evaluate(0, 1.2, 35.0, 0.01)

    # The constant survives the JSON round trip unchanged, so exact
    # comparison is appropriate here.
    assert sf.evaluate() == 1.234

    # Two inputs (pt, syst): a binning on pt whose bins hold categories on
    # syst. The second bin's "blah3" leaf is a TFormula; its "parameters": [0]
    # binds the formula variable x to input 0 (pt), as the final assertion
    # below relies on.
    cset = wrap(
        schemav1.Correction(
            name="test corr",
            version=2,
            inputs=[
                schemav1.Variable(name="pt", type="real"),
                schemav1.Variable(name="syst", type="string"),
            ],
            output=schemav1.Variable(name="a scale", type="real"),
            data=schemav1.Binning.parse_obj(
                {
                    "nodetype": "binning",
                    "edges": [0, 20, 40],
                    "content": [
                        {
                            "nodetype": "category",
                            "keys": ["blah", "blah2"],
                            "content": [1.1, 2.2],
                        },
                        {
                            "nodetype": "category",
                            "keys": ["blah2", "blah3"],
                            "content": [
                                1.3,
                                {
                                    "expression": "0.25*x + exp(3.1)",
                                    "parser": "TFormula",
                                    "parameters": [0],
                                },
                            ],
                        },
                    ],
                }
            ),
        )
    )
    assert set(cset) == {"test corr"}
    sf = cset["test corr"]
    assert sf.version == 2
    assert sf.description == ""

    # Wrong argument counts (too many, or too few) must raise.
    with pytest.raises(RuntimeError):
        sf.evaluate(0, 1.2, 35.0, 0.01)

    with pytest.raises(RuntimeError):
        sf.evaluate(1.2)

    with pytest.raises(RuntimeError):
        sf.evaluate(5)

    with pytest.raises(RuntimeError):
        sf.evaluate("asdf")

    # pt=12.0 lies between edges 0 and 20 (first bin); "blah" selects 1.1.
    assert sf.evaluate(12.0, "blah") == 1.1

    # pt=31.0 lies between edges 20 and 40 (second bin); "blah3" selects the
    # formula. The evaluator's exp need not round identically to Python's, so
    # compare with a tight relative tolerance instead of exact equality.
    assert sf.evaluate(31.0, "blah3") == pytest.approx(
        0.25 * 31.0 + math.exp(3.1), rel=1e-12
    )
0109
0110
def test_tformula():
    """Check that TFormula expressions evaluate like their Python references.

    Each formula becomes one leaf of a category node keyed by its list index.
    The correction takes ``(index, x)``, and every formula is evaluated at
    several ``x`` values, including a negative one that exercises the
    ``max(x, 0.1)`` clamp inside the log.
    """
    # (TFormula expression, equivalent Python reference implementation)
    formulas = [
        ("23.*x", lambda x: 23.0 * x),
        ("23.*log(max(x, 0.1))", lambda x: 23.0 * math.log(max(x, 0.1))),
    ]
    cset = {
        "schema_version": 1,
        "corrections": [
            {
                "name": "test",
                "version": 1,
                "inputs": [
                    {"name": "index", "type": "int"},
                    {"name": "x", "type": "real"},
                ],
                "output": {"name": "f", "type": "real"},
                "data": {
                    "nodetype": "category",
                    "keys": list(range(len(formulas))),
                    # "parameters": [1] binds the formula variable x to input 1 (x).
                    "content": [
                        {"expression": expr, "parser": "TFormula", "parameters": [1]}
                        for expr, _ in formulas
                    ],
                },
            }
        ],
    }
    # Validate the hand-written dict against the schemav1 model first; only
    # the side effect (raising on an invalid document) matters here.
    schemav1.CorrectionSet.parse_obj(cset)
    corr = core.CorrectionSet.from_string(json.dumps(cset))["test"]

    test_values = [1.0, 32.0, -3.0, 1550.0]
    for index, (_, reference) in enumerate(formulas):
        for x in test_values:
            # The evaluator's math functions need not round identically to
            # Python's, so compare with a tight relative tolerance rather
            # than exact float equality.
            assert corr.evaluate(index, x) == pytest.approx(reference(x), rel=1e-12)