Warning, /RecoTracker/LSTCore/standalone/analysis/DNN/train_pT3_DNN.ipynb is written in an unsupported language. File is not indexed.
0001 {
0002 "cells": [
0003 {
0004 "cell_type": "code",
0005 "execution_count": 1,
0006 "metadata": {},
0007 "outputs": [],
0008 "source": [
0009 "import os\n",
0010 "import uproot\n",
0011 "import numpy as np\n",
0012 "\n",
0013 "def load_root_file(file_path, branches=None, print_branches=False):\n",
0014 " all_branches = {}\n",
0015 " with uproot.open(file_path) as file:\n",
0016 " tree = file[\"tree\"]\n",
0017 " # Load all ROOT branches into array if not specified\n",
0018 " if branches is None:\n",
0019 " branches = tree.keys()\n",
0020 " # Option to print the branch names\n",
0021 " if print_branches:\n",
0022 " print(\"Branches:\", tree.keys())\n",
0023 " # Each branch is added to the dictionary\n",
0024 " for branch in branches:\n",
0025 " try:\n",
0026 " all_branches[branch] = (tree[branch].array(library=\"np\"))\n",
0027 " except uproot.KeyInFileError as e:\n",
0028 " print(f\"KeyInFileError: {e}\")\n",
0029 " # Number of events in file\n",
0030 " all_branches['event'] = tree.num_entries\n",
0031 " return all_branches\n",
0032 "\n",
0033 "branches_list = [\n",
0034 " 'sim_pT3_matched',\n",
0035 " 'pT3_pt',\n",
0036 " 'pT3_isFake',\n",
0037 " 'pT3_isDuplicate',\n",
0038 " 'pT3_eta',\n",
0039 " 'pT3_phi',\n",
0040 " 'pT3_score',\n",
0041 " 'pT3_foundDuplicate',\n",
0042 " 'pT3_matched_simIdx',\n",
0043 " 'pT3_hitIdxs',\n",
0044 " 'pT3_pixelRadius',\n",
0045 " 'pT3_pixelRadiusError',\n",
0046 " 'pT3_tripletRadius',\n",
0047 " 'pT3_rPhiChiSquared',\n",
0048 " 'pT3_rPhiChiSquaredInwards',\n",
0049 " 'pT3_rzChiSquared',\n",
0050 " 'pT3_layer_binary',\n",
0051 " 'pT3_moduleType_binary'\n",
0052 "]\n",
0053 "\n",
0054 "file_path = \"pt3_500_fixed_new_2.root\"\n",
0055 "branches = load_root_file(file_path, branches_list)"
0056 ]
0057 },
0058 {
0059 "cell_type": "code",
0060 "execution_count": 2,
0061 "metadata": {},
0062 "outputs": [],
0063 "source": [
0064 "eta_max = 2.5\n",
0065 "phi_max = np.pi"
0066 ]
0067 },
0068 {
0069 "cell_type": "code",
0070 "execution_count": 3,
0071 "metadata": {},
0072 "outputs": [],
0073 "source": [
0074 "n_events = branches['event']\n",
0075 "\n",
0076 "# Concatenate the pt3 branches over all events\n",
0077 "pt3_rPhiChiSquared = np.concatenate([branches['pT3_rPhiChiSquared'][evt] for evt in range(n_events)])\n",
0078 "pt3_rPhiChiSquaredInwards = np.concatenate([branches['pT3_rPhiChiSquaredInwards'][evt] for evt in range(n_events)])\n",
0079 "pt3_rzChiSquared = np.concatenate([branches['pT3_rzChiSquared'][evt] for evt in range(n_events)])\n",
0080 "pt3_eta = np.abs(np.concatenate([branches['pT3_eta'][evt] for evt in range(n_events)]))\n",
0081 "pt3_trip_rad = np.abs(np.concatenate([branches['pT3_tripletRadius'][evt] for evt in range(n_events)]))\n",
0082 "pt3_pix_rad = np.abs(np.concatenate([branches['pT3_pixelRadius'][evt] for evt in range(n_events)]))\n",
0083 "pt3_pixRadError = np.abs(np.concatenate([branches['pT3_pixelRadiusError'][evt] for evt in range(n_events)]))\n",
0084 "\n",
0085 "# Build the features array using the helper functions\n",
0086 "features = np.array([\n",
0087 " np.log10(pt3_rPhiChiSquared),\n",
0088 " np.log10(pt3_trip_rad),\n",
0089 " np.log10(pt3_pix_rad),\n",
0090 " np.log10(pt3_pixRadError),\n",
0091 " np.log10(pt3_rzChiSquared),\n",
0092 " np.abs(pt3_eta)/eta_max\n",
0093 "])\n",
0094 "\n",
0095 "eta_list = np.array([pt3_eta])"
0096 ]
0097 },
0098 {
0099 "cell_type": "code",
0100 "execution_count": 4,
0101 "metadata": {},
0102 "outputs": [
0103 {
0104 "name": "stdout",
0105 "output_type": "stream",
0106 "text": [
0107 "Using device: cuda\n",
0108 "Initial dataset size: 229144\n",
0109 "Class distribution before downsampling - Real: 192431.0, Fake: 36713.0\n",
0110 "Class distribution after downsampling - Real: 192431.0, Fake: 36713.0\n",
0111 "Epoch [1/200], Train Loss: 0.2689, Test Loss: 0.2072\n",
0112 "Epoch [2/200], Train Loss: 0.1906, Test Loss: 0.1789\n",
0113 "Epoch [3/200], Train Loss: 0.1737, Test Loss: 0.1718\n",
0114 "Epoch [4/200], Train Loss: 0.1693, Test Loss: 0.1684\n",
0115 "Epoch [5/200], Train Loss: 0.1665, Test Loss: 0.1673\n",
0116 "Epoch [6/200], Train Loss: 0.1653, Test Loss: 0.1661\n",
0117 "Epoch [7/200], Train Loss: 0.1661, Test Loss: 0.1631\n",
0118 "Epoch [8/200], Train Loss: 0.1633, Test Loss: 0.1639\n",
0119 "Epoch [9/200], Train Loss: 0.1628, Test Loss: 0.1792\n",
0120 "Epoch [10/200], Train Loss: 0.1621, Test Loss: 0.1624\n",
0121 "Epoch [11/200], Train Loss: 0.1610, Test Loss: 0.1613\n",
0122 "Epoch [12/200], Train Loss: 0.1609, Test Loss: 0.1631\n",
0123 "Epoch [13/200], Train Loss: 0.1617, Test Loss: 0.1628\n",
0124 "Epoch [14/200], Train Loss: 0.1607, Test Loss: 0.1597\n",
0125 "Epoch [15/200], Train Loss: 0.1595, Test Loss: 0.1610\n",
0126 "Epoch [16/200], Train Loss: 0.1584, Test Loss: 0.1641\n",
0127 "Epoch [17/200], Train Loss: 0.1576, Test Loss: 0.1582\n",
0128 "Epoch [18/200], Train Loss: 0.1581, Test Loss: 0.1578\n",
0129 "Epoch [19/200], Train Loss: 0.1581, Test Loss: 0.1707\n",
0130 "Epoch [20/200], Train Loss: 0.1589, Test Loss: 0.1618\n",
0131 "Epoch [21/200], Train Loss: 0.1577, Test Loss: 0.1579\n",
0132 "Epoch [22/200], Train Loss: 0.1571, Test Loss: 0.1579\n",
0133 "Epoch [23/200], Train Loss: 0.1563, Test Loss: 0.1609\n",
0134 "Epoch [24/200], Train Loss: 0.1558, Test Loss: 0.1552\n",
0135 "Epoch [25/200], Train Loss: 0.1579, Test Loss: 0.1723\n",
0136 "Epoch [26/200], Train Loss: 0.1571, Test Loss: 0.1547\n",
0137 "Epoch [27/200], Train Loss: 0.1547, Test Loss: 0.1582\n",
0138 "Epoch [28/200], Train Loss: 0.1549, Test Loss: 0.1542\n",
0139 "Epoch [29/200], Train Loss: 0.1536, Test Loss: 0.1556\n",
0140 "Epoch [30/200], Train Loss: 0.1528, Test Loss: 0.1554\n",
0141 "Epoch [31/200], Train Loss: 0.1523, Test Loss: 0.1540\n",
0142 "Epoch [32/200], Train Loss: 0.1527, Test Loss: 0.1526\n",
0143 "Epoch [33/200], Train Loss: 0.1526, Test Loss: 0.1558\n",
0144 "Epoch [34/200], Train Loss: 0.1533, Test Loss: 0.1523\n",
0145 "Epoch [35/200], Train Loss: 0.1516, Test Loss: 0.1517\n",
0146 "Epoch [36/200], Train Loss: 0.1510, Test Loss: 0.1535\n",
0147 "Epoch [37/200], Train Loss: 0.1509, Test Loss: 0.1519\n",
0148 "Epoch [38/200], Train Loss: 0.1498, Test Loss: 0.1517\n",
0149 "Epoch [39/200], Train Loss: 0.1513, Test Loss: 0.1525\n",
0150 "Epoch [40/200], Train Loss: 0.1509, Test Loss: 0.1496\n",
0151 "Epoch [41/200], Train Loss: 0.1491, Test Loss: 0.1498\n",
0152 "Epoch [42/200], Train Loss: 0.1497, Test Loss: 0.1517\n",
0153 "Epoch [43/200], Train Loss: 0.1484, Test Loss: 0.1499\n",
0154 "Epoch [44/200], Train Loss: 0.1475, Test Loss: 0.1513\n",
0155 "Epoch [45/200], Train Loss: 0.1470, Test Loss: 0.1488\n",
0156 "Epoch [46/200], Train Loss: 0.1512, Test Loss: 0.2419\n",
0157 "Epoch [47/200], Train Loss: 0.1649, Test Loss: 0.1603\n",
0158 "Epoch [48/200], Train Loss: 0.1575, Test Loss: 0.1575\n",
0159 "Epoch [49/200], Train Loss: 0.1561, Test Loss: 0.1585\n",
0160 "Epoch [50/200], Train Loss: 0.1551, Test Loss: 0.1543\n",
0161 "Epoch [51/200], Train Loss: 0.1542, Test Loss: 0.1568\n",
0162 "Epoch [52/200], Train Loss: 0.1529, Test Loss: 0.1546\n",
0163 "Epoch [53/200], Train Loss: 0.1536, Test Loss: 0.1562\n",
0164 "Epoch [54/200], Train Loss: 0.1522, Test Loss: 0.1531\n",
0165 "Epoch [55/200], Train Loss: 0.1514, Test Loss: 0.1563\n",
0166 "Epoch [56/200], Train Loss: 0.1505, Test Loss: 0.1525\n",
0167 "Epoch [57/200], Train Loss: 0.1503, Test Loss: 0.1511\n",
0168 "Epoch [58/200], Train Loss: 0.1491, Test Loss: 0.1507\n",
0169 "Epoch [59/200], Train Loss: 0.1495, Test Loss: 0.1497\n",
0170 "Epoch [60/200], Train Loss: 0.1493, Test Loss: 0.1484\n",
0171 "Epoch [61/200], Train Loss: 0.1494, Test Loss: 0.1530\n",
0172 "Epoch [62/200], Train Loss: 0.1481, Test Loss: 0.1469\n",
0173 "Epoch [63/200], Train Loss: 0.1460, Test Loss: 0.1493\n",
0174 "Epoch [64/200], Train Loss: 0.1462, Test Loss: 0.1473\n",
0175 "Epoch [65/200], Train Loss: 0.1460, Test Loss: 0.1468\n",
0176 "Epoch [66/200], Train Loss: 0.1450, Test Loss: 0.1450\n",
0177 "Epoch [67/200], Train Loss: 0.1444, Test Loss: 0.1503\n",
0178 "Epoch [68/200], Train Loss: 0.1473, Test Loss: 0.1676\n",
0179 "Epoch [69/200], Train Loss: 0.1483, Test Loss: 0.1446\n",
0180 "Epoch [70/200], Train Loss: 0.1445, Test Loss: 0.1468\n",
0181 "Epoch [71/200], Train Loss: 0.1438, Test Loss: 0.1471\n",
0182 "Epoch [72/200], Train Loss: 0.1439, Test Loss: 0.1545\n",
0183 "Epoch [73/200], Train Loss: 0.1427, Test Loss: 0.1447\n",
0184 "Epoch [74/200], Train Loss: 0.1432, Test Loss: 0.1433\n",
0185 "Epoch [75/200], Train Loss: 0.1415, Test Loss: 0.1472\n",
0186 "Epoch [76/200], Train Loss: 0.1418, Test Loss: 0.1480\n",
0187 "Epoch [77/200], Train Loss: 0.1411, Test Loss: 0.1423\n",
0188 "Epoch [78/200], Train Loss: 0.1420, Test Loss: 0.1507\n",
0189 "Epoch [79/200], Train Loss: 0.1409, Test Loss: 0.1429\n",
0190 "Epoch [80/200], Train Loss: 0.1404, Test Loss: 0.1422\n",
0191 "Epoch [81/200], Train Loss: 0.1413, Test Loss: 0.1446\n",
0192 "Epoch [82/200], Train Loss: 0.1417, Test Loss: 0.2182\n",
0193 "Epoch [83/200], Train Loss: 0.1450, Test Loss: 0.1410\n",
0194 "Epoch [84/200], Train Loss: 0.1415, Test Loss: 0.1561\n",
0195 "Epoch [85/200], Train Loss: 0.1441, Test Loss: 0.1430\n",
0196 "Epoch [86/200], Train Loss: 0.1410, Test Loss: 0.1409\n",
0197 "Epoch [87/200], Train Loss: 0.1392, Test Loss: 0.1431\n",
0198 "Epoch [88/200], Train Loss: 0.1410, Test Loss: 0.1445\n",
0199 "Epoch [89/200], Train Loss: 0.1387, Test Loss: 0.1401\n",
0200 "Epoch [90/200], Train Loss: 0.1394, Test Loss: 0.1408\n",
0201 "Epoch [91/200], Train Loss: 0.1389, Test Loss: 0.1415\n",
0202 "Epoch [92/200], Train Loss: 0.1393, Test Loss: 0.1410\n",
0203 "Epoch [93/200], Train Loss: 0.1379, Test Loss: 0.1395\n",
0204 "Epoch [94/200], Train Loss: 0.1381, Test Loss: 0.1409\n",
0205 "Epoch [95/200], Train Loss: 0.1404, Test Loss: 0.1427\n",
0206 "Epoch [96/200], Train Loss: 0.1392, Test Loss: 0.1404\n",
0207 "Epoch [97/200], Train Loss: 0.1387, Test Loss: 0.1390\n",
0208 "Epoch [98/200], Train Loss: 0.1384, Test Loss: 0.1421\n",
0209 "Epoch [99/200], Train Loss: 0.1383, Test Loss: 0.1411\n",
0210 "Epoch [100/200], Train Loss: 0.1374, Test Loss: 0.1375\n",
0211 "Epoch [101/200], Train Loss: 0.1365, Test Loss: 0.1389\n",
0212 "Epoch [102/200], Train Loss: 0.1369, Test Loss: 0.1383\n",
0213 "Epoch [103/200], Train Loss: 0.1368, Test Loss: 0.1435\n",
0214 "Epoch [104/200], Train Loss: 0.1369, Test Loss: 0.1376\n",
0215 "Epoch [105/200], Train Loss: 0.1370, Test Loss: 0.1374\n",
0216 "Epoch [106/200], Train Loss: 0.1364, Test Loss: 0.1435\n",
0217 "Epoch [107/200], Train Loss: 0.1360, Test Loss: 0.1390\n",
0218 "Epoch [108/200], Train Loss: 0.1368, Test Loss: 0.1376\n",
0219 "Epoch [109/200], Train Loss: 0.1362, Test Loss: 0.1407\n",
0220 "Epoch [110/200], Train Loss: 0.1360, Test Loss: 0.1387\n",
0221 "Epoch [111/200], Train Loss: 0.1354, Test Loss: 0.1364\n",
0222 "Epoch [112/200], Train Loss: 0.1361, Test Loss: 0.1387\n",
0223 "Epoch [113/200], Train Loss: 0.1362, Test Loss: 0.1387\n",
0224 "Epoch [114/200], Train Loss: 0.1358, Test Loss: 0.1378\n",
0225 "Epoch [115/200], Train Loss: 0.1349, Test Loss: 0.1368\n",
0226 "Epoch [116/200], Train Loss: 0.1354, Test Loss: 0.1367\n",
0227 "Epoch [117/200], Train Loss: 0.1361, Test Loss: 0.1434\n",
0228 "Epoch [118/200], Train Loss: 0.1364, Test Loss: 0.1370\n",
0229 "Epoch [119/200], Train Loss: 0.1375, Test Loss: 0.1418\n",
0230 "Epoch [120/200], Train Loss: 0.1373, Test Loss: 0.1389\n",
0231 "Epoch [121/200], Train Loss: 0.1353, Test Loss: 0.1392\n",
0232 "Epoch [122/200], Train Loss: 0.1354, Test Loss: 0.1377\n",
0233 "Epoch [123/200], Train Loss: 0.1365, Test Loss: 0.1422\n",
0234 "Epoch [124/200], Train Loss: 0.1359, Test Loss: 0.1354\n",
0235 "Epoch [125/200], Train Loss: 0.1343, Test Loss: 0.1370\n",
0236 "Epoch [126/200], Train Loss: 0.1340, Test Loss: 0.1358\n",
0237 "Epoch [127/200], Train Loss: 0.1347, Test Loss: 0.1373\n",
0238 "Epoch [128/200], Train Loss: 0.1352, Test Loss: 0.1367\n",
0239 "Epoch [129/200], Train Loss: 0.1351, Test Loss: 0.1360\n",
0240 "Epoch [130/200], Train Loss: 0.1344, Test Loss: 0.1362\n",
0241 "Epoch [131/200], Train Loss: 0.1360, Test Loss: 0.1722\n",
0242 "Epoch [132/200], Train Loss: 0.1371, Test Loss: 0.1358\n",
0243 "Epoch [133/200], Train Loss: 0.1338, Test Loss: 0.1354\n",
0244 "Epoch [134/200], Train Loss: 0.1335, Test Loss: 0.1359\n",
0245 "Epoch [135/200], Train Loss: 0.1354, Test Loss: 0.1755\n",
0246 "Epoch [136/200], Train Loss: 0.1396, Test Loss: 0.1357\n",
0247 "Epoch [137/200], Train Loss: 0.1346, Test Loss: 0.1348\n",
0248 "Epoch [138/200], Train Loss: 0.1334, Test Loss: 0.1362\n",
0249 "Epoch [139/200], Train Loss: 0.1340, Test Loss: 0.1376\n",
0250 "Epoch [140/200], Train Loss: 0.1339, Test Loss: 0.1346\n",
0251 "Epoch [141/200], Train Loss: 0.1336, Test Loss: 0.1388\n",
0252 "Epoch [142/200], Train Loss: 0.1350, Test Loss: 0.1368\n",
0253 "Epoch [143/200], Train Loss: 0.1359, Test Loss: 0.1358\n",
0254 "Epoch [144/200], Train Loss: 0.1348, Test Loss: 0.1392\n",
0255 "Epoch [145/200], Train Loss: 0.1339, Test Loss: 0.1440\n",
0256 "Epoch [146/200], Train Loss: 0.1344, Test Loss: 0.1351\n",
0257 "Epoch [147/200], Train Loss: 0.1330, Test Loss: 0.1362\n",
0258 "Epoch [148/200], Train Loss: 0.1339, Test Loss: 0.1373\n",
0259 "Epoch [149/200], Train Loss: 0.1343, Test Loss: 0.1495\n",
0260 "Epoch [150/200], Train Loss: 0.1360, Test Loss: 0.1357\n",
0261 "Epoch [151/200], Train Loss: 0.1328, Test Loss: 0.1360\n",
0262 "Epoch [152/200], Train Loss: 0.1329, Test Loss: 0.1373\n",
0263 "Epoch [153/200], Train Loss: 0.1331, Test Loss: 0.1389\n",
0264 "Epoch [154/200], Train Loss: 0.1346, Test Loss: 0.1431\n",
0265 "Epoch [155/200], Train Loss: 0.1355, Test Loss: 0.1465\n",
0266 "Epoch [156/200], Train Loss: 0.1343, Test Loss: 0.1359\n",
0267 "Epoch [157/200], Train Loss: 0.1326, Test Loss: 0.1340\n",
0268 "Epoch [158/200], Train Loss: 0.1327, Test Loss: 0.1340\n",
0269 "Epoch [159/200], Train Loss: 0.1342, Test Loss: 0.1459\n",
0270 "Epoch [160/200], Train Loss: 0.1334, Test Loss: 0.1347\n",
0271 "Epoch [161/200], Train Loss: 0.1332, Test Loss: 0.1379\n",
0272 "Epoch [162/200], Train Loss: 0.1334, Test Loss: 0.1357\n",
0273 "Epoch [163/200], Train Loss: 0.1329, Test Loss: 0.1396\n",
0274 "Epoch [164/200], Train Loss: 0.1341, Test Loss: 0.1364\n",
0275 "Epoch [165/200], Train Loss: 0.1334, Test Loss: 0.1348\n",
0276 "Epoch [166/200], Train Loss: 0.1328, Test Loss: 0.1346\n",
0277 "Epoch [167/200], Train Loss: 0.1327, Test Loss: 0.1343\n",
0278 "Epoch [168/200], Train Loss: 0.1343, Test Loss: 0.1438\n",
0279 "Epoch [169/200], Train Loss: 0.1343, Test Loss: 0.1335\n",
0280 "Epoch [170/200], Train Loss: 0.1325, Test Loss: 0.1337\n",
0281 "Epoch [171/200], Train Loss: 0.1322, Test Loss: 0.1363\n",
0282 "Epoch [172/200], Train Loss: 0.1334, Test Loss: 0.1340\n",
0283 "Epoch [173/200], Train Loss: 0.1336, Test Loss: 0.1366\n",
0284 "Epoch [174/200], Train Loss: 0.1326, Test Loss: 0.1357\n",
0285 "Epoch [175/200], Train Loss: 0.1324, Test Loss: 0.1390\n",
0286 "Epoch [176/200], Train Loss: 0.1331, Test Loss: 0.1346\n",
0287 "Epoch [177/200], Train Loss: 0.1327, Test Loss: 0.1356\n",
0288 "Epoch [178/200], Train Loss: 0.1327, Test Loss: 0.1337\n",
0289 "Epoch [179/200], Train Loss: 0.1326, Test Loss: 0.1346\n",
0290 "Epoch [180/200], Train Loss: 0.1329, Test Loss: 0.1360\n",
0291 "Epoch [181/200], Train Loss: 0.1317, Test Loss: 0.1336\n",
0292 "Epoch [182/200], Train Loss: 0.1315, Test Loss: 0.1345\n",
0293 "Epoch [183/200], Train Loss: 0.1334, Test Loss: 0.1330\n",
0294 "Epoch [184/200], Train Loss: 0.1321, Test Loss: 0.1344\n",
0295 "Epoch [185/200], Train Loss: 0.1321, Test Loss: 0.1395\n",
0296 "Epoch [186/200], Train Loss: 0.1332, Test Loss: 0.1420\n",
0297 "Epoch [187/200], Train Loss: 0.1329, Test Loss: 0.1348\n",
0298 "Epoch [188/200], Train Loss: 0.1313, Test Loss: 0.1343\n",
0299 "Epoch [189/200], Train Loss: 0.1312, Test Loss: 0.1338\n",
0300 "Epoch [190/200], Train Loss: 0.1328, Test Loss: 0.1375\n",
0301 "Epoch [191/200], Train Loss: 0.1319, Test Loss: 0.1368\n",
0302 "Epoch [192/200], Train Loss: 0.1322, Test Loss: 0.1353\n",
0303 "Epoch [193/200], Train Loss: 0.1323, Test Loss: 0.1349\n",
0304 "Epoch [194/200], Train Loss: 0.1318, Test Loss: 0.1340\n",
0305 "Epoch [195/200], Train Loss: 0.1317, Test Loss: 0.1378\n",
0306 "Epoch [196/200], Train Loss: 0.1334, Test Loss: 0.1360\n",
0307 "Epoch [197/200], Train Loss: 0.1337, Test Loss: 0.1331\n",
0308 "Epoch [198/200], Train Loss: 0.1335, Test Loss: 0.1334\n",
0309 "Epoch [199/200], Train Loss: 0.1331, Test Loss: 0.1359\n",
0310 "Epoch [200/200], Train Loss: 0.1324, Test Loss: 0.1337\n"
0311 ]
0312 }
0313 ],
0314 "source": [
0315 "import torch\n",
0316 "import torch.nn as nn\n",
0317 "from torch.optim import Adam\n",
0318 "from torch.utils.data import DataLoader, TensorDataset, random_split\n",
0319 "import numpy as np\n",
0320 "\n",
0321 "# ------------------ Preprocessing ------------------\n",
0322 "input_features_numpy = np.stack(features, axis=-1)\n",
0323 "mask = ~np.isnan(input_features_numpy) & ~np.isinf(input_features_numpy)\n",
0324 "filtered_input_features_numpy = input_features_numpy[np.all(mask, axis=1)]\n",
0325 "t3_isFake_filtered = 1 - (np.concatenate(branches['pT3_isFake']))[np.all(mask, axis=1)]\n",
0326 "\n",
0327 "# Convert to PyTorch tensors.\n",
0328 "input_features_tensor = torch.tensor(filtered_input_features_numpy, dtype=torch.float32)\n",
0329 "labels_tensor = torch.tensor(t3_isFake_filtered, dtype=torch.float32).unsqueeze(1)\n",
0330 "\n",
0331 "# ------------------ Device Setup ------------------\n",
0332 "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
0333 "print(f\"Using device: {device}\")\n",
0334 "\n",
0335 "# ------------------ Neural Network ------------------\n",
0336 "class BinaryClassificationNeuralNetwork(nn.Module):\n",
0337 " def __init__(self, input_dim):\n",
0338 " super(BinaryClassificationNeuralNetwork, self).__init__()\n",
0339 " self.layer1 = nn.Linear(input_dim, 32)\n",
0340 " self.layer2 = nn.Linear(32, 32)\n",
0341 " self.output_layer = nn.Linear(32, 1) # Single output for binary classification\n",
0342 "\n",
0343 " def forward(self, x):\n",
0344 " x = self.layer1(x)\n",
0345 " x = nn.ReLU()(x)\n",
0346 " x = self.layer2(x)\n",
0347 " x = nn.ReLU()(x)\n",
0348 " x = self.output_layer(x)\n",
0349 " return torch.sigmoid(x) # Sigmoid activation for output between 0 and 1\n",
0350 "\n",
0351 "# ------------------ Loss Function ------------------\n",
0352 "class WeightedBinaryCrossEntropyLoss(nn.Module):\n",
0353 " def __init__(self):\n",
0354 " super(WeightedBinaryCrossEntropyLoss, self).__init__()\n",
0355 "\n",
0356 " def forward(self, outputs, targets, weights):\n",
0357 " eps = 1e-7\n",
0358 " loss = -(targets * torch.log(outputs + eps) + (1 - targets) * torch.log(1 - outputs + eps))\n",
0359 " weighted_loss = loss * weights\n",
0360 " return weighted_loss.mean()\n",
0361 "\n",
0362 "# ------------------ Class Weight Calculation ------------------\n",
0363 "def calculate_binary_class_weights(labels):\n",
0364 " total_samples = len(labels)\n",
0365 " count_positive = labels.sum().item()\n",
0366 " count_negative = total_samples - count_positive\n",
0367 " weight_positive = total_samples / (2 * count_positive) if count_positive > 0 else 1.0\n",
0368 " weight_negative = total_samples / (2 * count_negative) if count_negative > 0 else 1.0\n",
0369 " \n",
0370 " sample_weights = torch.zeros(total_samples)\n",
0371 " for i in range(total_samples):\n",
0372 " if labels[i] == 1:\n",
0373 " sample_weights[i] = weight_positive\n",
0374 " else:\n",
0375 " sample_weights[i] = weight_negative\n",
0376 " return sample_weights\n",
0377 "\n",
0378 "# ------------------ Data Preparation ------------------\n",
0379 "print(f\"Initial dataset size: {len(labels_tensor)}\")\n",
0380 "\n",
0381 "# Remove any rows with NaN in the input features (if any remain).\n",
0382 "nan_mask = torch.isnan(input_features_tensor).any(dim=1)\n",
0383 "filtered_inputs = input_features_tensor[~nan_mask]\n",
0384 "filtered_labels = labels_tensor[~nan_mask]\n",
0385 "\n",
0386 "# Print class distribution before downsampling.\n",
0387 "num_real = filtered_labels.sum().item() # label = 1 means real\n",
0388 "num_fake = len(filtered_labels) - num_real # label = 0 means fake\n",
0389 "print(f\"Class distribution before downsampling - Real: {num_real}, Fake: {num_fake}\")\n",
0390 "\n",
0391 "# Option to downsample the majority class.\n",
0392 "downsample_classes = False\n",
0393 "if downsample_classes:\n",
0394 " downsample_ratios = {1: 1.0, 0: 1.0}\n",
0395 " indices_list = []\n",
0396 "\n",
0397 " # Process real class (label 1).\n",
0398 " real_mask = (filtered_labels.squeeze() == 1)\n",
0399 " real_indices = torch.nonzero(real_mask).squeeze()\n",
0400 " num_real = real_indices.numel()\n",
0401 " num_real_to_sample = int(num_real * downsample_ratios[1])\n",
0402 " if num_real_to_sample < 1 and num_real > 0:\n",
0403 " num_real_to_sample = 1\n",
0404 " real_indices_shuffled = real_indices[torch.randperm(num_real)]\n",
0405 " sampled_real_indices = real_indices_shuffled[:num_real_to_sample]\n",
0406 " indices_list.append(sampled_real_indices)\n",
0407 "\n",
0408 " # Process fake class (label 0).\n",
0409 " fake_mask = (filtered_labels.squeeze() == 0)\n",
0410 " fake_indices = torch.nonzero(fake_mask).squeeze()\n",
0411 " num_fake = fake_indices.numel()\n",
0412 " num_fake_to_sample = int(num_fake * downsample_ratios[0])\n",
0413 " if num_fake_to_sample < 1 and num_fake > 0:\n",
0414 " num_fake_to_sample = 1\n",
0415 " fake_indices_shuffled = fake_indices[torch.randperm(num_fake)]\n",
0416 " sampled_fake_indices = fake_indices_shuffled[:num_fake_to_sample]\n",
0417 " indices_list.append(sampled_fake_indices)\n",
0418 "\n",
0419 " # Combine indices from both classes.\n",
0420 " selected_indices = torch.cat(indices_list)\n",
0421 " filtered_inputs = filtered_inputs[selected_indices]\n",
0422 " filtered_labels = filtered_labels[selected_indices]\n",
0423 "\n",
0424 "# Print class distribution after downsampling.\n",
0425 "num_real_after = filtered_labels.sum().item()\n",
0426 "num_fake_after = len(filtered_labels) - num_real_after\n",
0427 "print(f\"Class distribution after downsampling - Real: {num_real_after}, Fake: {num_fake_after}\")\n",
0428 "\n",
0429 "# Calculate sample weights after downsampling.\n",
0430 "sample_weights = calculate_binary_class_weights(filtered_labels)\n",
0431 "filtered_weights = sample_weights\n",
0432 "\n",
0433 "# Create the dataset.\n",
0434 "dataset = TensorDataset(filtered_inputs, filtered_labels, filtered_weights)\n",
0435 "\n",
0436 "# Split into train and test sets.\n",
0437 "train_size = int(0.8 * len(dataset))\n",
0438 "test_size = len(dataset) - train_size\n",
0439 "train_dataset, test_dataset = random_split(dataset, [train_size, test_size])\n",
0440 "\n",
0441 "# Create data loaders.\n",
0442 "train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True, num_workers=10, pin_memory=True)\n",
0443 "test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False, num_workers=10, pin_memory=True)\n",
0444 "\n",
0445 "# ------------------ Model, Loss, and Optimizer ------------------\n",
0446 "input_dim = filtered_inputs.shape[1]\n",
0447 "model = BinaryClassificationNeuralNetwork(input_dim).to(device)\n",
0448 "loss_function = WeightedBinaryCrossEntropyLoss()\n",
0449 "optimizer = Adam(model.parameters(), lr=0.0025)\n",
0450 "\n",
0451 "def evaluate_loss(loader):\n",
0452 " model.eval()\n",
0453 " total_loss = 0\n",
0454 " num_batches = 0\n",
0455 " with torch.no_grad():\n",
0456 " for inputs, targets, weights in loader:\n",
0457 " inputs, targets, weights = inputs.to(device), targets.to(device), weights.to(device)\n",
0458 " outputs = model(inputs)\n",
0459 " loss = loss_function(outputs, targets, weights)\n",
0460 " total_loss += loss.item()\n",
0461 " num_batches += 1\n",
0462 " return total_loss / num_batches\n",
0463 "\n",
0464 "# ------------------ Training Loop ------------------\n",
0465 "num_epochs = 200\n",
0466 "train_loss_log = []\n",
0467 "test_loss_log = []\n",
0468 "\n",
0469 "for epoch in range(num_epochs):\n",
0470 " model.train()\n",
0471 " epoch_loss = 0\n",
0472 " num_batches = 0\n",
0473 "\n",
0474 " for inputs, targets, weights in train_loader:\n",
0475 " inputs, targets, weights = inputs.to(device), targets.to(device), weights.to(device)\n",
0476 " outputs = model(inputs)\n",
0477 " loss = loss_function(outputs, targets, weights)\n",
0478 " epoch_loss += loss.item()\n",
0479 " num_batches += 1\n",
0480 " \n",
0481 " optimizer.zero_grad()\n",
0482 " loss.backward()\n",
0483 " optimizer.step()\n",
0484 " \n",
0485 " train_loss = epoch_loss / num_batches\n",
0486 " test_loss = evaluate_loss(test_loader)\n",
0487 " train_loss_log.append(train_loss)\n",
0488 " test_loss_log.append(test_loss)\n",
0489 " \n",
0490 " print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')"
0491 ]
0492 },
0493 {
0494 "cell_type": "code",
0495 "execution_count": 5,
0496 "metadata": {},
0497 "outputs": [
0498 {
0499 "name": "stdout",
0500 "output_type": "stream",
0501 "text": [
0502 "Baseline accuracy: 0.7410224080085754\n",
0503 "Feature importances:\n",
0504 "Feature 2 importance: 0.1705\n",
0505 "Feature 1 importance: 0.1614\n",
0506 "Feature 3 importance: 0.0985\n",
0507 "Feature 5 importance: 0.0871\n",
0508 "Feature 4 importance: 0.0293\n",
0509 "Feature 0 importance: 0.0219\n"
0510 ]
0511 }
0512 ],
0513 "source": [
0514 "from sklearn.metrics import accuracy_score\n",
0515 "\n",
0516 "# Convert tensors to numpy for simplicity in permutation\n",
0517 "input_features_np = input_features_tensor.numpy()\n",
0518 "labels_np = labels_tensor.numpy()\n",
0519 "\n",
0520 "def model_accuracy(features, labels, model):\n",
0521 " # Move the model to CPU for evaluation\n",
0522 " model.cpu()\n",
0523 " model.eval() # Set to evaluation mode\n",
0524 " with torch.no_grad():\n",
0525 " # Ensure features and labels are on CPU\n",
0526 " inputs = features.to('cpu')\n",
0527 " labels = labels.to('cpu')\n",
0528 " outputs = model(inputs)\n",
0529 " predicted = (outputs.squeeze() > 0.5).float() # Update threshold as necessary\n",
0530 " accuracy = (predicted == labels).float().mean().item()\n",
0531 " return accuracy\n",
0532 "\n",
0533 "# Use the original input_features_tensor and labels_tensor directly\n",
0534 "baseline_accuracy = model_accuracy(input_features_tensor, labels_tensor, model)\n",
0535 "print(f\"Baseline accuracy: {baseline_accuracy}\")\n",
0536 "\n",
0537 "# Initialize an array to store feature importances\n",
0538 "feature_importances = np.zeros(input_features_tensor.shape[1])\n",
0539 "\n",
0540 "# Permute each feature and calculate the drop in accuracy\n",
0541 "for i in range(input_features_tensor.shape[1]):\n",
0542 " permuted_features = input_features_tensor.clone()\n",
0543 " permuted_features[:, i] = permuted_features[torch.randperm(permuted_features.size(0)), i] # Permute feature\n",
0544 "\n",
0545 " permuted_accuracy = model_accuracy(permuted_features, labels_tensor, model)\n",
0546 " feature_importances[i] = baseline_accuracy - permuted_accuracy\n",
0547 "\n",
0548 "# Ranking features by importance\n",
0549 "important_features_indices = np.argsort(feature_importances)[::-1] # Indices of features in descending importance\n",
0550 "important_features_scores = np.sort(feature_importances)[::-1] # Importance scores in descending order\n",
0551 "\n",
0552 "print(\"Feature importances:\")\n",
0553 "for idx, score in zip(important_features_indices, important_features_scores):\n",
0554 " print(f\"Feature {idx} importance: {score:.4f}\")"
0555 ]
0556 },
0557 {
0558 "cell_type": "code",
0559 "execution_count": 6,
0560 "metadata": {},
0561 "outputs": [
0562 {
0563 "name": "stdout",
0564 "output_type": "stream",
0565 "text": [
0566 "ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_layer1[32] = {\n",
0567 "-1.0316417f, -0.5292138f, -1.3838978f, -0.7517025f, -0.8418103f, 0.8087351f, -0.1664009f, 0.2809185f, 1.3140178f, 0.7565588f, 0.0890440f, -0.4908848f, 0.4201532f, -0.4334770f, 0.5002150f, -0.5591785f, 1.2298888f, 0.0346711f, -0.4166603f, -0.0064792f, -0.2969901f, -0.3028315f, 0.0721094f, 0.2584246f, -0.2035742f, -0.2888707f, -0.1322349f, 0.5589037f, 0.4285649f, 0.1511498f, 0.1774099f, -0.9249431f };\n",
0568 "\n",
0569 "ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_layer1[6][32] = {\n",
0570 "{ -0.2303199f, -0.0356287f, 0.1646899f, -0.1455843f, -0.0290584f, 0.2789465f, 0.0559607f, -0.2009920f, -0.0174202f, -0.0237667f, 0.1854273f, -0.0269862f, -0.0826555f, 0.7338054f, 0.0937388f, -0.8827202f, -0.0182344f, 0.0352709f, -0.1331034f, -0.0432344f, 0.0208153f, -0.1124316f, 0.1652907f, -0.1187222f, -0.0506563f, 0.2842667f, -0.7510378f, 0.0503595f, -0.0329987f, -0.8018139f, -0.0368841f, -0.0285996f },\n",
0571 "{ 0.7176192f, -0.1105556f, 0.8051441f, -0.4914712f, -0.0188390f, -0.2385275f, -0.2936262f, 0.3048217f, -0.2883293f, -0.3593645f, -0.4168207f, -0.0530377f, 1.0154706f, -0.5427814f, -0.9457486f, -0.4481009f, -0.4298896f, 0.0229893f, 0.0285739f, 0.1118112f, -0.2768153f, -0.6592747f, 0.5621189f, 0.7928397f, -0.8139476f, -0.1173517f, -0.2300215f, -0.4201365f, 0.2074883f, 0.3603764f, -0.3421177f, 0.1743037f },\n",
0572 "{ 0.0283383f, 0.3209993f, -0.0528482f, 0.9225417f, 0.4743758f, -0.3381493f, -0.1023181f, 0.0761002f, -0.4110529f, 0.0935502f, 0.4090216f, 0.1871852f, -0.7008697f, 0.3572226f, 0.6193281f, -0.3993059f, 0.4011126f, -0.3016730f, 0.1985772f, 0.1066906f, -0.1462973f, 1.0643306f, -0.6703950f, -0.7412036f, 0.8931519f, -0.3644519f, -0.0444721f, 0.7265502f, -0.2158535f, -0.7507963f, -0.2511281f, -0.2659404f },\n",
0573 "{ 0.6280951f, -0.1979905f, 0.2835749f, -0.2547665f, 0.2009394f, 0.5826684f, 0.0070924f, 0.7995070f, -0.1270987f, 0.2988422f, 1.2406983f, 0.5875431f, 0.2251529f, -0.5355389f, -0.2763741f, 0.0566486f, 0.8032280f, 0.1221172f, 0.0441896f, 1.3281344f, 0.0374645f, 0.1209982f, -0.7381684f, -0.0807755f, -0.4638562f, -0.2137405f, -0.1491151f, 0.3022155f, -0.1751741f, -0.0065741f, 0.1841483f, -0.0963122f },\n",
0574 "{ 0.0176033f, -0.1465057f, 0.0613830f, 0.1969667f, 0.3062224f, -0.2499068f, -1.2580094f, -0.1081307f, -0.3870914f, -0.2593474f, 0.4072658f, 0.0381806f, -0.3507286f, 0.3829739f, -0.1012198f, -0.6614094f, 0.0203835f, -0.2897506f, -0.4447211f, -0.4430839f, -0.0040086f, -0.0372864f, -0.4039490f, 0.4331785f, 0.2907649f, -0.1092023f, 0.0977807f, -0.6611776f, 0.7277890f, 0.5314985f, -0.0860426f, 0.0131469f },\n",
0575 "{ 0.0536266f, 0.2549676f, -0.3011957f, -0.4934275f, 0.6024259f, 1.5041729f, 1.3324199f, -0.7268062f, 0.5070686f, 1.0215880f, 0.7595061f, 1.1927116f, -0.6326223f, 0.4896784f, 0.1556831f, 0.4490171f, -1.1741296f, -0.0409376f, 0.3861921f, 0.4442228f, 0.0257449f, -0.3155618f, 0.0957184f, -0.0695736f, 0.1083985f, 0.1015484f, 0.8968495f, -0.1153277f, -0.5456764f, -0.1840984f, 0.3110001f, 1.6959499f },\n",
0576 "};\n",
0577 "\n",
0578 "ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_layer2[32] = {\n",
0579 "-0.1015397f, -0.2844371f, 0.3264084f, 0.1813205f, -0.5612066f, 0.0906685f, 0.0845674f, 0.4616135f, -0.2177648f, -0.1652546f, 0.4002015f, -0.0791563f, 0.2383104f, 0.4796737f, 0.4520915f, -0.1967489f, 0.3534851f, 0.5968352f, 0.5477327f, 0.5137501f, 0.3921396f, -0.3068429f, 0.3759635f, 0.4266470f, -0.0625485f, 0.1195836f, 0.3834727f, -0.1557929f, 0.2742889f, -0.3761551f, 0.1094945f, 0.0651921f };\n",
0580 "\n",
0581 "ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_layer2[32][32] = {\n",
0582 "{ 0.0398679f, -0.3098788f, 0.6160386f, 0.1663042f, 0.0616995f, -0.5468591f, 0.0629986f, 0.9667135f, 0.1226420f, -0.1594982f, 0.8811188f, -0.0122212f, -0.1866431f, -0.0127864f, 0.8248008f, 0.0109573f, 0.6348554f, 0.7169302f, -0.4604602f, 0.7598190f, 0.4718021f, -0.0390644f, 0.8688850f, 0.1144402f, 0.2909767f, -0.0373881f, 0.7398790f, -0.1475801f, 0.1209616f, -0.2457073f, 1.5780616f, -0.9071835f },\n",
0583 "{ -0.1679361f, -0.1377337f, -0.0434273f, -0.0693874f, 0.1072921f, -0.0385521f, -0.1659795f, -0.0850864f, -0.0472408f, 0.0961263f, -0.0083223f, 0.1163423f, -0.4766918f, -0.2465681f, -0.0924990f, -0.1210379f, -0.1053123f, -0.5663610f, -0.0639742f, -0.2473114f, -0.1664381f, -0.2532910f, 0.2278630f, -0.1883525f, 0.1034932f, -0.1848536f, -0.1000912f, -0.0118117f, -0.6502397f, -0.2686550f, -0.0506822f, 0.0806257f },\n",
0584 "{ -0.9741716f, 2.8353660f, 0.3112126f, -2.3845534f, 1.0973041f, -1.0202811f, 0.0325992f, 0.3779927f, 0.0220244f, 0.0176618f, 0.5749296f, 0.0441563f, 0.1922859f, -0.1661071f, 0.3079973f, -0.0444220f, 0.2683854f, 0.6849532f, 0.1716617f, 0.5063406f, 0.7730200f, 0.0671466f, 0.5224924f, -2.8718150f, -1.4443614f, -0.1455812f, 0.0810298f, -0.3216130f, -0.3000851f, 0.7883633f, -0.6158253f, 1.2222605f },\n",
0585 "{ -0.5756851f, -0.0641467f, -0.8851644f, -1.1059726f, -1.4682759f, 0.3964167f, -0.0009879f, -0.6186015f, -0.0061868f, -0.1787315f, -0.4631928f, -0.1590177f, -0.3159569f, -0.5368351f, -0.4180341f, -0.1800581f, -0.4896149f, -1.3214008f, -2.5101123f, -0.8682919f, -0.5497130f, -0.2457729f, -0.3258826f, 0.5641282f, -0.7060823f, -0.1749261f, -0.2885247f, -0.1631423f, 0.0304128f, -0.0907670f, -0.2878378f, -0.6338934f },\n",
0586 "{ -1.0935251f, -0.9992083f, -0.0633413f, -0.1609752f, -0.3885693f, 0.3524578f, 0.0127510f, 0.1163242f, -0.1016347f, 0.0687914f, 0.0947863f, -0.0067241f, -0.1876621f, -0.0223206f, 0.1423765f, -0.1802772f, -0.0291052f, 0.0450566f, -0.4598204f, -0.1165596f, -0.3479767f, -0.0701668f, 0.2881460f, 0.1044018f, -0.0239950f, -0.1551994f, 0.2170918f, -0.2169842f, 0.4257921f, 0.3812013f, 0.4517531f, -0.1721758f },\n",
0587 "{ -0.2656136f, -1.6221433f, 0.0597067f, 0.4315053f, -0.0760466f, -1.1311982f, -0.1951133f, -0.0456391f, -0.3352800f, -0.1049919f, 0.3278280f, 0.0194480f, -0.2923800f, 1.0222012f, 0.0098873f, -0.0653058f, 0.2413959f, 0.7756509f, -0.5178695f, 0.2742799f, 0.2555766f, 0.0098521f, 0.1389581f, 1.2634234f, -0.1592175f, 0.0732114f, -0.1383710f, 0.0594422f, 0.2089978f, 0.4949785f, -1.7492013f, -0.6352664f },\n",
0588 "{ -0.4342200f, 2.3557084f, -0.3712214f, -0.7716209f, -0.4416530f, 0.1379044f, 0.0727059f, 0.5425858f, -0.0196064f, 0.1223867f, -0.0692875f, 0.1327683f, 0.4494519f, 0.0952457f, 0.1334870f, 0.0676249f, 0.3728006f, 0.7255252f, -1.6268474f, 0.8089049f, 0.1800606f, -0.1145329f, -0.1641574f, -3.0665483f, 0.5469566f, -0.0040791f, 0.0806037f, -0.0413218f, -4.4606824f, -1.1090307f, -4.0268488f, -1.1893214f },\n",
0589 "{ -0.2475177f, -0.6829165f, -0.0234359f, -0.0061992f, -0.0379326f, 0.4335875f, 0.0828241f, 0.0840699f, -0.0683035f, -0.1582910f, -0.0177477f, 0.1024636f, -0.0703338f, -0.1498718f, -0.0622758f, -0.0744936f, -0.0358834f, -0.1732212f, -1.1194228f, 0.0808573f, 0.0806586f, -0.1502637f, -0.0854174f, 0.6450295f, 0.1303551f, -0.1633945f, 0.0709481f, 0.0788564f, -0.1573547f, 0.2853298f, 0.3992899f, -0.2722275f },\n",
0590 "{ -0.0024052f, 0.0832400f, -0.1240560f, 0.1827071f, 0.5078100f, -1.9449767f, 0.1166574f, -0.4264094f, -0.3462796f, 0.0000746f, 0.0312783f, -0.1609135f, -0.2486863f, -0.0109304f, -0.3259882f, -0.0161246f, -0.1272729f, 0.3718676f, 0.7143661f, -0.0934052f, 0.2293939f, -0.1234700f, -0.5969369f, -0.0931792f, 0.1510706f, -0.1806178f, -0.3079583f, -0.0738148f, -0.3563280f, -0.2019677f, -2.4623275f, -1.5662863f },\n",
0591 "{ -1.0340914f, -0.0264978f, 0.1637420f, -0.2267127f, -0.0878844f, -0.9779887f, -0.0298158f, -0.0878086f, -0.2519222f, -0.1923419f, 0.1495816f, -0.1703948f, -0.4357837f, 0.2593633f, -0.1618164f, -0.0063047f, 0.1708476f, 0.1812407f, -0.1949099f, -0.0840897f, 0.0365015f, -0.1204635f, -0.1736511f, 0.1234083f, -0.2201642f, -0.2357494f, -0.1014172f, 0.1092978f, 0.1652685f, -0.3343832f, -0.7220409f, -1.6212125f },\n",
0592 "{ -0.2753425f, -1.0450485f, 0.0862796f, 0.5330997f, 0.2215044f, 1.0313069f, -0.1288254f, -0.4092823f, -0.2298029f, -0.1440723f, -0.4615285f, -0.0421680f, 0.2838994f, -0.4270649f, -0.4212369f, -0.0371355f, -0.4017382f, -0.3131837f, 0.0822975f, -0.1237065f, -0.3185453f, 0.1537958f, -0.2095980f, -1.3749454f, 0.8085594f, -0.1836801f, -0.2636165f, 0.0256138f, 0.2858065f, 0.2371098f, -1.3770546f, 0.1130755f },\n",
0593 "{ -0.9280756f, -0.7408227f, -0.2019696f, -0.2414151f, 0.1294307f, 0.5904933f, -0.1836856f, -0.1364715f, -0.1818316f, 0.0698972f, -0.0663465f, -0.0624342f, -0.5138503f, -0.8630795f, -0.2610115f, 0.1374615f, -0.0169794f, -0.5805362f, -0.1026141f, -0.1647637f, -0.2088510f, 0.0510566f, -0.1424105f, -3.4270186f, 0.2712817f, -0.1466969f, -0.1463187f, -0.1640311f, 0.4300542f, 0.0572317f, -0.2324502f, -1.0792063f },\n",
0594 "{ -0.0307543f, -0.0768560f, 0.8240007f, 0.7662915f, 0.1543863f, -0.1713554f, -0.1223227f, 0.4386547f, -0.2540433f, -0.1332257f, 0.7341567f, -0.0580215f, -0.2693185f, 0.5857823f, 0.4592545f, -0.0942196f, 0.5861048f, 0.5916333f, 0.8047288f, 0.6647347f, 0.6680315f, 0.0370881f, 0.5487651f, 0.4163764f, 0.1050351f, -0.0590813f, 0.7615875f, -0.0326647f, -0.3096285f, -0.0770774f, 0.4334740f, 0.0817739f },\n",
0595 "{ -0.3358265f, -1.4673307f, -0.0814484f, -0.1950698f, 0.2350309f, -0.8171171f, -0.0790631f, -0.0033560f, -0.2415340f, -0.1976242f, -0.0537354f, 0.0254026f, 0.6134555f, 0.0668075f, -0.0584858f, 0.0701805f, -0.0682784f, -0.1794098f, -0.0583040f, -0.1624253f, -0.1456280f, -0.2352183f, -0.0143530f, -0.1338016f, -2.4032581f, 0.1073783f, 0.1195123f, 0.0956468f, 0.1293423f, 1.1435740f, 0.3386673f, 0.1772921f },\n",
0596 "{ -1.6261140f, 1.0616504f, -1.4311931f, -2.4396446f, -0.9056627f, 0.4331071f, -0.0649640f, -1.3435462f, -0.0458936f, -0.0711545f, -1.6348053f, -0.1845289f, -0.0306016f, -1.4271052f, -1.3648096f, 0.0524831f, -1.7654743f, -1.1757706f, -0.6089430f, -1.3503740f, -1.7567536f, -0.0602079f, -2.0782607f, -0.3474989f, -1.6194628f, -0.1068030f, -1.7532518f, -0.1791241f, 0.8444587f, -1.6989282f, -0.3120632f, -0.3003300f },\n",
0597 "{ -0.1823801f, 3.2780595f, 0.0046086f, 0.3211953f, 2.8119521f, -2.4034190f, -0.1202310f, 0.0593807f, 0.0990012f, -0.0647286f, 0.1925517f, -0.0071298f, -0.4552369f, -0.0450624f, 0.2038983f, -0.1544811f, 0.2431807f, 2.5275569f, 0.4051182f, 0.1717180f, -4.4179425f, -0.0382397f, 0.1282281f, 1.3789687f, 0.2157716f, 0.0339806f, -0.0445110f, -0.1175756f, -0.2876654f, 0.6454542f, -4.1663256f, -4.9532986f },\n",
0598 "{ -1.5354282f, 1.1725403f, -1.3999028f, 1.4599746f, 0.7842735f, -0.2761331f, -0.1681113f, -1.2318107f, -0.1268664f, -0.1008263f, -1.1240377f, 0.0714637f, 1.0816231f, 0.0337652f, -1.0600177f, -0.0118573f, -0.9776441f, -1.1751970f, -4.1702838f, -1.4578556f, -0.5371085f, -0.1503768f, -0.7648986f, -0.1425661f, -3.3017726f, 0.1108525f, -1.1741784f, -0.0334351f, -0.5934044f, -0.0111621f, -0.9253456f, 0.9225865f },\n",
0599 "{ 0.1424126f, -0.1699037f, 0.0217177f, 0.1220412f, -0.0906500f, 0.0009671f, -0.1639844f, -0.0527305f, 0.0128741f, -0.1168858f, -0.0302848f, -0.1783690f, 0.0221473f, 0.1693649f, 0.2265290f, 0.0589060f, 0.0021857f, 0.0242652f, -0.0775702f, 0.0388727f, 0.0707951f, -0.1567841f, 0.1732654f, 0.0437058f, -0.1407849f, -0.0657614f, -0.1040369f, 0.1652682f, -0.1807958f, 0.1050604f, -0.1204055f, 0.0295164f },\n",
0600 "{ -0.0865003f, -0.5563657f, 0.0262170f, 0.1092112f, 0.1299964f, -0.2665305f, -0.0594379f, -0.3092700f, -0.1140931f, -0.1025264f, -0.0924060f, 0.0661781f, -0.6323586f, -0.1840952f, -0.3743987f, 0.0433477f, 0.1879622f, -0.2726538f, 0.0839276f, 0.0398268f, -0.2266287f, -0.0569720f, 0.1651291f, 0.0004582f, 0.0672146f, -0.0313096f, 0.0879014f, 0.0798023f, -0.6620474f, -0.3207071f, 0.0426169f, -0.2612755f },\n",
0601 "{ -0.0320727f, -1.2532046f, 0.2544250f, -0.2982330f, -0.4790863f, -0.5560963f, -0.1201061f, -0.0375240f, -0.1868488f, 0.0521472f, -0.0850537f, -0.1624364f, 0.1665921f, -0.1470270f, 0.2794432f, -0.1671568f, 0.0176426f, -0.4760583f, -0.1061204f, 0.3590186f, 0.1793246f, -0.0764172f, 0.3327765f, 1.3718973f, -0.6579911f, -0.2003471f, -0.0251827f, -0.0495654f, 0.0083008f, -0.2771839f, -0.8300866f, -3.0943394f },\n",
0602 "{ 0.0621519f, -0.1735040f, 0.1486671f, -0.1150359f, 0.1008294f, -0.0092464f, 0.0799274f, 0.1733976f, 0.0848012f, -0.1365137f, -0.0797485f, 0.1167114f, 0.1286604f, -0.0544559f, 0.0424396f, -0.1382235f, -0.0363623f, 0.0376997f, -0.0940855f, -0.0822849f, 0.1207218f, 0.0484856f, -0.0699596f, 0.1332577f, -0.1111775f, -0.1007963f, -0.0908793f, -0.1231520f, -0.0480829f, 0.0872541f, 0.0143959f, 0.1434527f },\n",
0603 "{ -0.7731760f, -0.1756437f, -0.5544511f, -1.1350045f, -0.8066334f, 0.4560536f, 0.0514501f, -0.5005841f, -0.1125634f, -0.1759978f, -0.5758266f, -0.1267792f, -0.1993223f, -0.3803072f, -0.4573608f, -0.0674755f, -0.4743742f, -0.6225397f, -1.0214270f, -0.5803681f, -1.1637334f, -0.0488763f, -0.4069555f, 0.3787926f, -1.1097326f, -0.1770509f, -0.3172881f, 0.0409949f, -0.0639247f, -0.3521400f, 0.4321971f, 0.0638231f },\n",
0604 "{ 1.1334474f, 0.2912367f, 0.9589775f, 1.3626057f, -0.5325903f, 0.0380656f, -0.1852983f, 0.4992394f, 0.0062829f, 0.1194302f, 0.6635179f, -0.1749409f, 0.9010410f, 0.4779315f, 0.6941727f, -0.0554726f, 0.6099629f, 0.6952612f, -0.6230863f, 0.5901819f, 0.8190935f, -0.3441652f, 0.6224667f, 0.4739881f, 0.6536815f, 0.0891643f, 0.6636505f, -0.1713116f, -0.1173943f, 1.0364656f, 0.1431172f, 0.4249419f },\n",
0605 "{ 0.8033506f, 0.0300558f, 0.6656652f, 1.2101313f, -0.8615839f, 0.5132746f, -0.0265759f, 0.5377895f, 0.0152243f, 0.0114079f, 0.6099907f, -0.1249313f, 0.2723404f, 0.3610747f, 0.8165271f, -0.1239832f, 0.4873835f, 0.8119840f, 1.0779978f, 0.7688931f, 1.2414374f, 0.0522347f, 0.5015374f, 0.3038165f, 0.7020112f, -0.5665528f, 0.5736545f, -0.2835588f, 0.2212313f, 0.1219948f, 0.3844957f, 0.3401546f },\n",
0606 "{ -0.7439671f, 0.1166124f, -0.5117782f, -0.8712475f, 0.5832991f, 0.8349549f, -0.0009121f, -0.2292912f, -0.1560434f, 0.1009018f, -0.3865246f, 0.0456842f, -0.0389204f, -0.4212565f, -0.4042930f, -0.0817643f, -0.4844481f, -0.5429400f, -0.0883068f, -0.2933508f, -1.4754379f, -0.0470565f, -0.3384019f, -0.2567133f, -1.3822908f, -0.4623447f, -0.3862867f, -0.1395572f, 0.4184268f, -0.4738576f, 0.5919805f, 0.4029315f },\n",
0607 "{ 0.0951967f, -0.0264262f, 0.0745322f, 0.1399388f, 0.1192624f, -0.0836915f, 0.1539033f, -0.1881727f, 0.1124343f, 0.1572933f, -0.0205816f, -0.0527924f, -0.0939525f, -0.0599392f, 0.0312916f, -0.1276211f, -0.0861412f, -0.2229238f, 0.0862552f, 0.1315805f, 0.1222533f, 0.1245137f, 0.0553789f, -0.0982146f, -0.1024362f, -0.0534966f, -0.1330877f, -0.0255273f, -0.0186730f, -0.0708270f, 0.1567561f, -0.0844505f },\n",
0608 "{ -0.0146923f, 0.2591865f, -0.2036884f, -0.0632321f, 1.3047949f, -0.5502592f, -0.1978513f, -0.0660495f, 0.0529672f, -0.0396357f, -0.5039385f, 0.0635856f, 0.2173319f, -0.2249199f, -0.2624365f, 0.0801869f, -0.5361095f, -0.7585190f, 0.2618334f, -0.3435789f, -0.0586956f, -0.1631703f, -0.4265099f, -0.0447214f, -1.0177643f, 0.0323080f, -0.3818924f, -0.1746079f, 0.5966771f, 0.6934303f, -0.2669153f, -0.4290505f },\n",
0609 "{ -0.7751231f, -0.3161191f, -0.2155289f, -0.6773937f, 0.1730236f, -0.4401509f, -0.1782488f, -0.2299091f, -0.2481384f, -0.1676966f, -0.2118478f, -0.1280043f, -0.3889951f, -0.2434384f, -0.3576782f, -0.1191592f, 0.0488413f, -0.3035202f, -0.1206791f, -0.3636490f, -0.4272992f, -0.0747663f, 0.1205972f, 0.3637502f, -0.4497322f, -0.0998606f, -0.2597706f, -0.0395944f, -0.1047684f, -0.1457169f, -0.3183976f, -0.4862959f },\n",
0610 "{ -0.3920153f, 0.0499827f, 0.0303043f, 0.1732283f, -5.3812666f, 0.5284207f, -0.0502357f, 0.4224207f, 0.0811522f, 0.0420046f, -0.0458250f, -0.1272357f, 0.0563048f, 0.0414863f, 0.3250322f, -0.1368608f, 0.2093009f, 0.2632169f, 0.2363917f, 0.2263645f, 0.2523394f, -0.0778029f, 0.0646684f, 0.4740721f, 0.3888938f, 0.0414786f, 0.1371806f, 0.0576984f, 0.2861594f, -0.3092689f, 0.3796774f, 0.4443598f },\n",
0611 "{ -0.8123505f, 0.2358683f, 0.1019373f, 0.2282849f, -1.2557510f, 0.1555537f, -0.0974192f, 0.5895486f, 0.0036116f, -0.1097254f, 0.1486044f, -0.0238196f, -0.6786419f, 0.3634615f, 0.6986865f, 0.0832033f, 0.2520329f, -2.2602396f, 0.1090904f, 0.7121069f, -1.6000881f, -0.0334629f, 0.5868925f, -0.1816964f, 0.4748906f, 0.0168051f, 0.6019380f, -0.1252660f, -0.6808432f, 0.4272106f, -0.0474080f, 1.0566397f },\n",
0612 "{ 0.0639796f, -0.0572261f, 0.0528333f, -0.0733780f, -0.0679085f, 0.1477537f, 0.1640870f, -0.0076429f, -0.1674436f, -0.0890600f, -0.0843337f, 0.0499240f, 0.1675422f, 0.0999710f, -0.0753364f, -0.1303664f, 0.0212046f, 0.1698232f, -0.0764421f, 0.0027978f, 0.0988120f, 0.0586733f, -0.0051097f, -0.0913044f, -0.0019131f, 0.1084543f, -0.0796013f, -0.0472390f, 0.1328194f, -0.1429810f, -0.1264632f, -0.1601117f },\n",
0613 "{ 0.6985479f, -0.0771215f, -0.0554999f, 0.6879453f, 1.2087054f, 0.5267281f, -0.0005624f, -0.4339048f, 0.0174754f, -0.0884810f, 0.0628883f, 0.0359151f, 1.9782046f, 0.0839531f, -0.0512894f, 0.1214923f, 0.1118447f, -0.2496938f, 1.8209995f, -0.1433189f, -0.6471683f, -0.3243633f, -0.2157310f, -9.1764698f, 0.9299871f, -0.1122522f, -0.0976612f, 0.0127003f, 0.8123280f, 0.9145623f, -1.1885530f, 3.2752204f },\n",
0614 "};\n",
0615 "\n",
0616 "ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_output_layer[1] = {\n",
0617 "-0.2746492f };\n",
0618 "\n",
0619 "ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_output_layer[32][1] = {\n",
0620 "{ -3.9178145f },\n",
0621 "{ 2.2943201f },\n",
0622 "{ 0.9315882f },\n",
0623 "{ -5.2151937f },\n",
0624 "{ 2.5372679f },\n",
0625 "{ -1.1143198f },\n",
0626 "{ 0.1008911f },\n",
0627 "{ 0.3630357f },\n",
0628 "{ -0.1294091f },\n",
0629 "{ 0.1117750f },\n",
0630 "{ 0.6620483f },\n",
0631 "{ 0.0155139f },\n",
0632 "{ -0.4766000f },\n",
0633 "{ 0.2999536f },\n",
0634 "{ 0.7013811f },\n",
0635 "{ -0.0866368f },\n",
0636 "{ 0.5150933f },\n",
0637 "{ 1.5360959f },\n",
0638 "{ 1.9393219f },\n",
0639 "{ 0.6595656f },\n",
0640 "{ -5.3660374f },\n",
0641 "{ 0.0038123f },\n",
0642 "{ 0.6477750f },\n",
0643 "{ 1.6103860f },\n",
0644 "{ -3.5332921f },\n",
0645 "{ 0.1317881f },\n",
0646 "{ 0.6166227f },\n",
0647 "{ 0.0163189f },\n",
0648 "{ -0.3913481f },\n",
0649 "{ -1.1696485f },\n",
0650 "{ -1.3807020f },\n",
0651 "{ -1.1326467f },\n",
0652 "};\n",
0653 "\n"
0654 ]
0655 }
0656 ],
0657 "source": [
0658 "def print_formatted_weights_biases(weights, biases, layer_name):\n",
0659 " # Print biases\n",
0660 " print(f\"ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_{layer_name}[{len(biases)}] = {{\")\n",
0661 " print(\", \".join(f\"{b:.7f}f\" for b in biases) + \" };\")\n",
0662 " print()\n",
0663 "\n",
0664 " # Print weights\n",
0665 " print(f\"ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_{layer_name}[{len(weights[0])}][{len(weights)}] = {{\")\n",
0666 " for row in weights.T:\n",
0667 " formatted_row = \", \".join(f\"{w:.7f}f\" for w in row)\n",
0668 " print(f\"{{ {formatted_row} }},\")\n",
0669 " print(\"};\")\n",
0670 " print()\n",
0671 "\n",
0672 "def print_model_weights_biases(model):\n",
0673 " # Make sure the model is in evaluation mode\n",
0674 " model.eval()\n",
0675 "\n",
0676 " # Iterate through all named modules in the model\n",
0677 " for name, module in model.named_modules():\n",
0678 " # Check if the module is a linear layer\n",
0679 " if isinstance(module, nn.Linear):\n",
0680 " # Get weights and biases\n",
0681 " weights = module.weight.data.cpu().numpy()\n",
0682 " biases = module.bias.data.cpu().numpy()\n",
0683 "\n",
0684 " # Print formatted weights and biases\n",
0685 " print_formatted_weights_biases(weights, biases, name.replace('.', '_'))\n",
0686 "\n",
0687 "print_model_weights_biases(model)\n"
0688 ]
0689 },
0690 {
0691 "cell_type": "code",
0692 "execution_count": 7,
0693 "metadata": {},
0694 "outputs": [
0695 {
0696 "name": "stderr",
0697 "output_type": "stream",
0698 "text": [
0699 "/tmp/ipykernel_1882258/1646812576.py:7: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
0700 " inputs = torch.tensor(features, dtype=torch.float32).to('cpu')\n"
0701 ]
0702 },
0703 {
0704 "data": {
0705 "image/png": "",
0706 "text/plain": [
0707 "<Figure size 640x480 with 1 Axes>"
0708 ]
0709 },
0710 "metadata": {},
0711 "output_type": "display_data"
0712 }
0713 ],
0714 "source": [
0715 "from sklearn.metrics import roc_curve, auc\n",
0716 "import matplotlib.pyplot as plt\n",
0717 "\n",
0718 "def model_outputs(features, model):\n",
0719 " model.eval() # Set the model to evaluation mode\n",
0720 " with torch.no_grad():\n",
0721 " inputs = torch.tensor(features, dtype=torch.float32).to('cpu')\n",
0722 " outputs = model(inputs).squeeze().cpu().numpy()\n",
0723 " return outputs\n",
0724 "\n",
0725 "# Calculate model outputs\n",
0726 "probabilities = model_outputs(filtered_inputs, model)\n",
0727 "\n",
0728 "# Calculate ROC curve and AUC\n",
0729 "fpr, tpr, thresholds = roc_curve(filtered_labels, probabilities)\n",
0730 "roc_auc = auc(fpr, tpr)\n",
0731 "\n",
0732 "# Plot ROC curve\n",
0733 "plt.figure()\n",
0734 "lw = 2 # Line width\n",
0735 "plt.plot(fpr, tpr, color='darkorange', lw=lw, label='ROC curve (area = %0.3f)' % roc_auc)\n",
0736 "plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')\n",
0737 "plt.xlim([0.0, 1.0])\n",
0738 "plt.ylim([0.0, 1.05])\n",
0739 "plt.xlabel('False Positive Rate')\n",
0740 "plt.ylabel('True Positive Rate')\n",
0741 "plt.title('Receiver Operating Characteristic')\n",
0742 "plt.legend(loc=\"lower right\")\n",
0743 "plt.show()"
0744 ]
0745 },
0746 {
0747 "cell_type": "code",
0748 "execution_count": 8,
0749 "metadata": {},
0750 "outputs": [],
0751 "source": [
0752 "# Ensure input_features_tensor is moved to the appropriate device\n",
0753 "input_features_tensor = input_features_tensor.to('cpu')\n",
0754 "\n",
0755 "# Make predictions\n",
0756 "with torch.no_grad():\n",
0757 " model.eval()\n",
0758 " outputs = model(input_features_tensor)\n",
0759 " predictions = outputs.squeeze().cpu().numpy()\n",
0760 "\n",
0761 "full_tracks = (np.concatenate(branches['pT3_isFake']) == 0)"
0762 ]
0763 },
0764 {
0765 "cell_type": "code",
0766 "execution_count": 11,
0767 "metadata": {},
0768 "outputs": [
0769 {
0770 "data": {
0771 "image/png": "",
0772 "text/plain": [
0773 "<Figure size 1000x600 with 2 Axes>"
0774 ]
0775 },
0776 "metadata": {},
0777 "output_type": "display_data"
0778 },
0779 {
0780 "name": "stdout",
0781 "output_type": "stream",
0782 "text": [
0783 "\n",
0784 "pt: 0 to 5\n",
0785 "90% Retention Cut: {0.9204, 0.9421, 0.9441, 0.9532, 0.961, 0.9352, 0.8842, 0.8606, 0.8345, 0.7997} Mean: 0.9035\n",
0786 "93% Retention Cut: {0.8482, 0.8907, 0.8972, 0.9145, 0.943, 0.9027, 0.8286, 0.8234, 0.7881, 0.7249} Mean: 0.8561\n",
0787 "95% Retention Cut: {0.7417, 0.8096, 0.8166, 0.8495, 0.9192, 0.8586, 0.7761, 0.7821, 0.7431, 0.6526} Mean: 0.7949\n",
0788 "98% Retention Cut: {0.4325, 0.449, 0.5454, 0.6257, 0.8199, 0.6123, 0.5456, 0.5901, 0.5441, 0.393} Mean: 0.5558\n",
0789 "99% Retention Cut: {0.2897, 0.2835, 0.3589, 0.4254, 0.6923, 0.4273, 0.3172, 0.248, 0.4203, 0.2651} Mean: 0.3728\n",
0790 "99.5% Retention Cut: {0.189, 0.1805, 0.2267, 0.3104, 0.4719, 0.3159, 0.1372, 0.1571, 0.3198, 0.186} Mean: 0.2495\n"
0791 ]
0792 },
0793 {
0794 "data": {
0795 "image/png": "",
0796 "text/plain": [
0797 "<Figure size 1000x600 with 2 Axes>"
0798 ]
0799 },
0800 "metadata": {},
0801 "output_type": "display_data"
0802 },
0803 {
0804 "name": "stdout",
0805 "output_type": "stream",
0806 "text": [
0807 "\n",
0808 "pt: 5 to inf\n",
0809 "90% Retention Cut: {0.5279} Mean: 0.5279\n",
0810 "93% Retention Cut: {0.3964} Mean: 0.3964\n",
0811 "95% Retention Cut: {0.2557} Mean: 0.2557\n",
0812 "98% Retention Cut: {0.0473} Mean: 0.0473\n",
0813 "99% Retention Cut: {0.0091} Mean: 0.0091\n",
0814 "99.5% Retention Cut: {0.0024} Mean: 0.0024\n"
0815 ]
0816 }
0817 ],
0818 "source": [
0819 "import numpy as np\n",
0820 "from matplotlib import pyplot as plt\n",
0821 "from matplotlib.colors import LogNorm\n",
0822 "\n",
0823 "def plot_for_pt_bin(pt_min, pt_max, percentiles, eta_bin_edges, eta_list, predictions, full_tracks, branches):\n",
0824 " \"\"\"\n",
0825 " Calculate and plot cut values for specified percentiles in a given pt bin\n",
0826 " \n",
0827 " Parameters:\n",
0828 " -----------\n",
0829 " pt_min : float\n",
0830 " Minimum pt value for the bin\n",
0831 " pt_max : float\n",
0832 " Maximum pt value for the bin\n",
0833 " percentiles : list\n",
0834 " List of percentiles to calculate (e.g., [92.5, 96.7, 99])\n",
0835 " eta_bin_edges : array\n",
0836 " Edges of the eta bins\n",
0837 " eta_list : list\n",
0838 " List of eta values\n",
0839 " predictions : array\n",
0840 " Array of DNN predictions\n",
0841 " full_tracks : array\n",
0842 " Boolean array for track selection\n",
0843 " branches : dict\n",
0844 " Dictionary containing branch data\n",
0845 " \"\"\"\n",
0846 " # Filter data based on pt bin\n",
0847 " abs_eta = eta_list[0][full_tracks & (np.concatenate(branches['pT3_pt']) > pt_min) & \n",
0848 " (np.concatenate(branches['pT3_pt']) <= pt_max)]\n",
0849 " predictions_filtered = predictions[full_tracks & (np.concatenate(branches['pT3_pt']) > pt_min) & \n",
0850 " (np.concatenate(branches['pT3_pt']) <= pt_max)]\n",
0851 " \n",
0852 " # Dictionary to store cut values for different percentiles\n",
0853 " cut_values = {p: [] for p in percentiles}\n",
0854 "\n",
0855 " # Loop through each eta bin\n",
0856 " for i in range(len(eta_bin_edges) - 1):\n",
0857 " # Get indices of tracks within the current eta bin\n",
0858 " bin_indices = (abs_eta >= eta_bin_edges[i]) & (abs_eta < eta_bin_edges[i + 1])\n",
0859 " \n",
0860 " # Get the corresponding DNN prediction scores\n",
0861 " bin_predictions = predictions_filtered[bin_indices]\n",
0862 " \n",
0863 " # Calculate the percentile cut values for the current bin\n",
0864 " for percentile in percentiles:\n",
0865 " cut_value = np.percentile(bin_predictions, 100 - percentile) # Convert retention to percentile\n",
0866 " cut_values[percentile].append(cut_value)\n",
0867 "\n",
0868 " # Plot 2D histogram\n",
0869 " plt.figure(figsize=(10, 6))\n",
0870 " plt.hist2d(abs_eta, predictions_filtered, bins=[eta_bin_edges, 50], norm=LogNorm())\n",
0871 " plt.colorbar(label='Counts')\n",
0872 " plt.xlabel(\"Absolute Eta\")\n",
0873 " plt.ylabel(\"DNN Prediction Score\")\n",
0874 " plt.title(f\"DNN Score vs. Abs Eta for 100% Matched Tracks (pt: {pt_min} to {pt_max})\")\n",
0875 "\n",
0876 " # Plot the cut values with different colors\n",
0877 " cut_x = eta_bin_edges[:-1] + (eta_bin_edges[1] - eta_bin_edges[0]) / 2 # Mid-points of the bins\n",
0878 " colors = plt.cm.rainbow(np.linspace(0, 1, len(percentiles))) # Generate distinct colors\n",
0879 " \n",
0880 " for percentile, color in zip(percentiles, colors):\n",
0881 " plt.plot(cut_x, cut_values[percentile], '-', color=color, marker='o', \n",
0882 " label=f'{percentile}% Retention Cut')\n",
0883 " \n",
0884 " plt.legend()\n",
0885 " plt.grid(True, alpha=0.3)\n",
0886 " plt.show()\n",
0887 " \n",
0888 " # Print the cut values\n",
0889 " print(f\"\\npt: {pt_min} to {pt_max}\")\n",
0890 " for percentile in percentiles:\n",
0891 " values = cut_values[percentile]\n",
0892 " print(f\"{percentile}% Retention Cut:\", \n",
0893 " '{' + ', '.join(str(x) for x in np.round(values, 4)) + '}',\n",
0894 " \"Mean:\", np.round(np.mean(values), 4))\n",
0895 "\n",
0896 "# Example usage:\n",
0897 "def analyze_pt_bins(pt_bins, percentiles, eta_bin_edges, eta_list, predictions, full_tracks, branches):\n",
0898 " \"\"\"\n",
0899 " Analyze and plot for multiple pt bins and percentiles\n",
0900 " \n",
0901 " Parameters:\n",
0902 " -----------\n",
0903 " pt_bins : list\n",
0904 " List of pt bin edges\n",
0905 " percentiles : list\n",
0906 " List of percentiles to calculate\n",
0907 " Other parameters same as plot_for_pt_bin function\n",
0908 " \"\"\"\n",
0909 " for i in range(len(pt_bins) - 1):\n",
0910 " plot_for_pt_bin(pt_bins[i], pt_bins[i + 1], percentiles, eta_bin_edges, \n",
0911 " eta_list, predictions, full_tracks, branches)\n",
0912 "\n",
0913 "percentiles = [90, 93, 95, 98, 99, 99.5]\n",
0914 "\n",
0915 "# For pt <= 5 using multiple eta bins\n",
0916 "pt_bins_low = [0, 5]\n",
0917 "analyze_pt_bins(pt_bins_low, percentiles, np.arange(0, 2.75, 0.25), eta_list, predictions, full_tracks, branches)\n",
0918 "\n",
0919 "# For pt > 5 using a single eta bin\n",
0920 "pt_bins_high = [5, np.inf]\n",
0921 "single_eta_bin = np.array([0, 2.75])\n",
0922 "analyze_pt_bins(pt_bins_high, percentiles, single_eta_bin, eta_list, predictions, full_tracks, branches)"
0923 ]
0924 },
0925 {
0926 "cell_type": "code",
0927 "execution_count": null,
0928 "metadata": {},
0929 "outputs": [],
0930 "source": []
0931 }
0932 ],
0933 "metadata": {
0934 "kernelspec": {
0935 "display_name": "analysisenv",
0936 "language": "python",
0937 "name": "python3"
0938 },
0939 "language_info": {
0940 "codemirror_mode": {
0941 "name": "ipython",
0942 "version": 3
0943 },
0944 "file_extension": ".py",
0945 "mimetype": "text/x-python",
0946 "name": "python",
0947 "nbconvert_exporter": "python",
0948 "pygments_lexer": "ipython3",
0949 "version": "3.11.7"
0950 }
0951 },
0952 "nbformat": 4,
0953 "nbformat_minor": 2
0954 }