{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import uproot\n",
"import numpy as np\n",
"\n",
"def load_root_file(file_path, branches=None, print_branches=False):\n",
"    all_branches = {}\n",
"    with uproot.open(file_path) as file:\n",
"        tree = file[\"tree\"]\n",
"        # Load all ROOT branches if none are specified\n",
"        if branches is None:\n",
"            branches = tree.keys()\n",
"        # Option to print the branch names\n",
"        if print_branches:\n",
"            print(\"Branches:\", tree.keys())\n",
"        # Each branch is added to the dictionary\n",
"        for branch in branches:\n",
"            try:\n",
"                all_branches[branch] = tree[branch].array(library=\"np\")\n",
"            except uproot.KeyInFileError as e:\n",
"                print(f\"KeyInFileError: {e}\")\n",
"        # Number of events in the file\n",
"        all_branches['event'] = tree.num_entries\n",
"    return all_branches\n",
"\n",
"branches_list = [\n",
"    # Core T3 properties from TripletsSoA\n",
"    't3_betaIn',\n",
"    't3_centerX',\n",
"    't3_centerY',\n",
"    't3_radius',\n",
"    't3_partOfPT5',\n",
"    't3_partOfT5',\n",
"    't3_partOfPT3',\n",
"    't3_layer_binary',\n",
"    't3_pMatched',\n",
"    't3_matched_simIdx',\n",
"    't3_sim_vxy',\n",
"    't3_sim_vz'\n",
"]\n",
"\n",
"# Hit-dependent branches\n",
"suffixes = ['r', 'z', 'eta', 'phi', 'layer']\n",
"branches_list += [f't3_hit_{i}_{suffix}' for i in range(6) for suffix in suffixes]\n",
"\n",
"file_path = \"600_t3_dnn_relval_fix.root\"\n",
"branches = load_root_file(file_path, branches_list)"
]
},
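{
"cell_type": "markdown",
"metadata": {},
"source": [
"Quick sanity check, added here for illustration (not part of the original workflow): it assumes each `t3_*` branch loads as a jagged per-event array of per-triplet values, and prints the event count together with the triplet multiplicity of the first few events."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check on the loaded arrays; assumes jagged per-event structure.\n",
"n_events_check = branches['event']\n",
"print(f\"Events loaded: {n_events_check}\")\n",
"for name in branches_list[:3]:\n",
"    counts = [len(branches[name][evt]) for evt in range(min(n_events_check, 3))]\n",
"    print(f\"{name}: triplets in first events -> {counts}\")"
]
},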
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Z max: 224.14950561523438, R max: 98.93299102783203, Eta max: 2.5\n"
]
}
],
"source": [
"z_max = np.max([np.max(event) for event in branches['t3_hit_3_z']])\n",
"r_max = np.max([np.max(event) for event in branches['t3_hit_3_r']])\n",
"eta_max = 2.5\n",
"phi_max = np.pi\n",
"\n",
"print(f'Z max: {z_max}, R max: {r_max}, Eta max: {eta_max}')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def delta_phi(phi1, phi2):\n",
"    delta = phi1 - phi2\n",
"    delta = np.where(delta > np.pi, delta - 2*np.pi, delta)\n",
"    delta = np.where(delta < -np.pi, delta + 2*np.pi, delta)\n",
"    return delta\n",
"\n",
"n_events = branches['event']\n",
"\n",
"all_eta0 = np.abs(np.concatenate([branches['t3_hit_0_eta'][evt] for evt in range(n_events)]))\n",
"all_eta2 = np.abs(np.concatenate([branches['t3_hit_2_eta'][evt] for evt in range(n_events)]))\n",
"all_eta4 = np.abs(np.concatenate([branches['t3_hit_4_eta'][evt] for evt in range(n_events)]))\n",
"\n",
"all_phi0 = np.concatenate([branches['t3_hit_0_phi'][evt] for evt in range(n_events)])\n",
"all_phi2 = np.concatenate([branches['t3_hit_2_phi'][evt] for evt in range(n_events)])\n",
"all_phi4 = np.concatenate([branches['t3_hit_4_phi'][evt] for evt in range(n_events)])\n",
"\n",
"all_z0 = np.abs(np.concatenate([branches['t3_hit_0_z'][evt] for evt in range(n_events)]))\n",
"all_z2 = np.abs(np.concatenate([branches['t3_hit_2_z'][evt] for evt in range(n_events)]))\n",
"all_z4 = np.abs(np.concatenate([branches['t3_hit_4_z'][evt] for evt in range(n_events)]))\n",
"\n",
"all_r0 = np.concatenate([branches['t3_hit_0_r'][evt] for evt in range(n_events)])\n",
"all_r2 = np.concatenate([branches['t3_hit_2_r'][evt] for evt in range(n_events)])\n",
"all_r4 = np.concatenate([branches['t3_hit_4_r'][evt] for evt in range(n_events)])\n",
"\n",
"all_radius = np.concatenate([branches['t3_radius'][evt] for evt in range(n_events)])\n",
"all_betaIn = np.concatenate([branches['t3_betaIn'][evt] for evt in range(n_events)])\n",
"\n",
"features = np.array([\n",
"    all_eta0 / eta_max,                       # Hit 0 |eta|\n",
"    np.abs(all_phi0) / phi_max,               # Hit 0 |phi|\n",
"    all_z0 / z_max,                           # Hit 0 |z|\n",
"    all_r0 / r_max,                           # Hit 0 r\n",
"    (all_eta2 - all_eta0),                    # Difference in |eta|: hit2 - hit0\n",
"    delta_phi(all_phi2, all_phi0) / phi_max,  # Difference in phi: hit2 - hit0\n",
"    (all_z2 - all_z0) / z_max,                # Difference in |z|: hit2 - hit0\n",
"    (all_r2 - all_r0) / r_max,                # Difference in r: hit2 - hit0\n",
"    (all_eta4 - all_eta2),                    # Difference in |eta|: hit4 - hit2\n",
"    delta_phi(all_phi4, all_phi2) / phi_max,  # Difference in phi: hit4 - hit2\n",
"    (all_z4 - all_z2) / z_max,                # Difference in |z|: hit4 - hit2\n",
"    (all_r4 - all_r2) / r_max,                # Difference in r: hit4 - hit2\n",
"    np.log10(all_radius),                     # Circle radius (log10)\n",
"    all_betaIn                                # Beta angle\n",
"])\n",
"\n",
"eta_list = np.array([all_eta0])"
]
},
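{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal check, added for illustration, that `delta_phi` wraps differences into [-pi, pi]: pairs straddling the +/-pi boundary should come out small rather than near 2*pi."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative: the first two pairs straddle the +/-pi boundary and should\n",
"# wrap to ~ -0.083 / +0.083; the last pair stays at 1.0.\n",
"print(delta_phi(np.array([3.1, -3.1, 0.5]), np.array([-3.1, 3.1, -0.5])))"
]
},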
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.optim import Adam\n",
"from torch.utils.data import DataLoader, TensorDataset, random_split\n",
"import numpy as np\n",
"\n",
"# Stack the features and drop any row containing NaN/Inf\n",
"input_features_numpy = np.stack(features, axis=-1)\n",
"mask = ~np.isnan(input_features_numpy) & ~np.isinf(input_features_numpy)\n",
"filtered_input_features_numpy = input_features_numpy[np.all(mask, axis=1)]\n",
"t3_isFake_filtered = (np.concatenate(branches['t3_pMatched']) < 0.75)[np.all(mask, axis=1)]\n",
"t3_sim_vxy_filtered = np.concatenate(branches['t3_sim_vxy'])[np.all(mask, axis=1)]\n",
"\n",
"# Convert to a PyTorch tensor\n",
"input_features_tensor = torch.tensor(filtered_input_features_numpy, dtype=torch.float32)"
]
},
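{
"cell_type": "markdown",
"metadata": {},
"source": [
"Illustrative bookkeeping (not in the original notebook): report how many triplets the NaN/Inf filter removed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# How many rows survived the NaN/Inf filter (illustrative).\n",
"n_total = input_features_numpy.shape[0]\n",
"n_kept = filtered_input_features_numpy.shape[0]\n",
"print(f\"Kept {n_kept} of {n_total} triplets ({n_kept / n_total:.2%})\")"
]
},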
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n",
"Initial dataset size: 55072926\n",
"Class distribution before downsampling - Fake: 49829032.0, Prompt: 4472777.0, Displaced: 771119.0\n",
"Class distribution after downsampling - Fake: 9965806.0, Prompt: 4472777.0, Displaced: 771119.0\n",
"Epoch [1/400], Train Loss: 0.6515, Test Loss: 0.5908\n",
"Epoch [2/400], Train Loss: 0.5771, Test Loss: 0.5659\n",
"Epoch [3/400], Train Loss: 0.5647, Test Loss: 0.5549\n",
"Epoch [4/400], Train Loss: 0.5588, Test Loss: 0.5627\n",
"Epoch [5/400], Train Loss: 0.5553, Test Loss: 0.5541\n",
"Epoch [6/400], Train Loss: 0.5528, Test Loss: 0.5664\n",
"Epoch [7/400], Train Loss: 0.5508, Test Loss: 0.5574\n",
"Epoch [8/400], Train Loss: 0.5492, Test Loss: 0.5503\n",
"Epoch [9/400], Train Loss: 0.5478, Test Loss: 0.5447\n",
"Epoch [10/400], Train Loss: 0.5472, Test Loss: 0.5538\n",
"Epoch [11/400], Train Loss: 0.5459, Test Loss: 0.5534\n",
"Epoch [12/400], Train Loss: 0.5454, Test Loss: 0.5487\n",
"Epoch [13/400], Train Loss: 0.5445, Test Loss: 0.5366\n",
"Epoch [14/400], Train Loss: 0.5441, Test Loss: 0.5387\n",
"Epoch [15/400], Train Loss: 0.5434, Test Loss: 0.5421\n",
"Epoch [16/400], Train Loss: 0.5426, Test Loss: 0.5397\n",
"Epoch [17/400], Train Loss: 0.5420, Test Loss: 0.5486\n",
"Epoch [18/400], Train Loss: 0.5412, Test Loss: 0.5398\n",
"Epoch [19/400], Train Loss: 0.5409, Test Loss: 0.5421\n",
"Epoch [20/400], Train Loss: 0.5405, Test Loss: 0.5499\n",
"Epoch [21/400], Train Loss: 0.5399, Test Loss: 0.5573\n",
"Epoch [22/400], Train Loss: 0.5396, Test Loss: 0.5388\n",
"Epoch [23/400], Train Loss: 0.5393, Test Loss: 0.5399\n",
"Epoch [24/400], Train Loss: 0.5388, Test Loss: 0.5391\n",
"Epoch [25/400], Train Loss: 0.5383, Test Loss: 0.5375\n",
"Epoch [26/400], Train Loss: 0.5381, Test Loss: 0.5386\n",
"Epoch [27/400], Train Loss: 0.5380, Test Loss: 0.5454\n",
"Epoch [28/400], Train Loss: 0.5376, Test Loss: 0.5350\n",
"Epoch [29/400], Train Loss: 0.5376, Test Loss: 0.5572\n",
"Epoch [30/400], Train Loss: 0.5371, Test Loss: 0.5486\n",
"Epoch [31/400], Train Loss: 0.5364, Test Loss: 0.5331\n",
"Epoch [32/400], Train Loss: 0.5368, Test Loss: 0.5498\n",
"Epoch [33/400], Train Loss: 0.5367, Test Loss: 0.5363\n",
"Epoch [34/400], Train Loss: 0.5363, Test Loss: 0.5377\n",
"Epoch [35/400], Train Loss: 0.5360, Test Loss: 0.5413\n",
"Epoch [36/400], Train Loss: 0.5361, Test Loss: 0.5385\n",
"Epoch [37/400], Train Loss: 0.5357, Test Loss: 0.5433\n",
"Epoch [38/400], Train Loss: 0.5353, Test Loss: 0.5404\n",
"Epoch [39/400], Train Loss: 0.5353, Test Loss: 0.5328\n",
"Epoch [40/400], Train Loss: 0.5349, Test Loss: 0.5363\n",
"Epoch [41/400], Train Loss: 0.5348, Test Loss: 0.5371\n",
"Epoch [42/400], Train Loss: 0.5347, Test Loss: 0.5349\n",
"Epoch [43/400], Train Loss: 0.5344, Test Loss: 0.5356\n",
"Epoch [44/400], Train Loss: 0.5342, Test Loss: 0.5355\n",
"Epoch [45/400], Train Loss: 0.5338, Test Loss: 0.5346\n",
"Epoch [46/400], Train Loss: 0.5337, Test Loss: 0.5345\n",
"Epoch [47/400], Train Loss: 0.5336, Test Loss: 0.5323\n",
"Epoch [48/400], Train Loss: 0.5333, Test Loss: 0.5298\n",
"Epoch [49/400], Train Loss: 0.5332, Test Loss: 0.5388\n",
"Epoch [50/400], Train Loss: 0.5331, Test Loss: 0.5312\n",
"Epoch [51/400], Train Loss: 0.5329, Test Loss: 0.5305\n",
"Epoch [52/400], Train Loss: 0.5328, Test Loss: 0.5325\n",
"Epoch [53/400], Train Loss: 0.5325, Test Loss: 0.5333\n",
"Epoch [54/400], Train Loss: 0.5325, Test Loss: 0.5285\n",
"Epoch [55/400], Train Loss: 0.5325, Test Loss: 0.5400\n",
"Epoch [56/400], Train Loss: 0.5323, Test Loss: 0.5324\n",
"Epoch [57/400], Train Loss: 0.5320, Test Loss: 0.5298\n",
"Epoch [58/400], Train Loss: 0.5319, Test Loss: 0.5408\n",
"Epoch [59/400], Train Loss: 0.5319, Test Loss: 0.5294\n",
"Epoch [60/400], Train Loss: 0.5315, Test Loss: 0.5293\n",
"Epoch [61/400], Train Loss: 0.5316, Test Loss: 0.5381\n",
"Epoch [62/400], Train Loss: 0.5315, Test Loss: 0.5302\n",
"Epoch [63/400], Train Loss: 0.5316, Test Loss: 0.5329\n",
"Epoch [64/400], Train Loss: 0.5313, Test Loss: 0.5341\n",
"Epoch [65/400], Train Loss: 0.5311, Test Loss: 0.5333\n",
"Epoch [66/400], Train Loss: 0.5311, Test Loss: 0.5379\n",
"Epoch [67/400], Train Loss: 0.5310, Test Loss: 0.5314\n",
"Epoch [68/400], Train Loss: 0.5308, Test Loss: 0.5377\n",
"Epoch [69/400], Train Loss: 0.5310, Test Loss: 0.5325\n",
"Epoch [70/400], Train Loss: 0.5307, Test Loss: 0.5307\n",
"Epoch [71/400], Train Loss: 0.5305, Test Loss: 0.5321\n",
"Epoch [72/400], Train Loss: 0.5304, Test Loss: 0.5328\n",
"Epoch [73/400], Train Loss: 0.5304, Test Loss: 0.5355\n",
"Epoch [74/400], Train Loss: 0.5301, Test Loss: 0.5298\n",
"Epoch [75/400], Train Loss: 0.5300, Test Loss: 0.5363\n",
"Epoch [76/400], Train Loss: 0.5303, Test Loss: 0.5331\n",
"Epoch [77/400], Train Loss: 0.5298, Test Loss: 0.5270\n",
"Epoch [78/400], Train Loss: 0.5300, Test Loss: 0.5324\n",
"Epoch [79/400], Train Loss: 0.5300, Test Loss: 0.5336\n",
"Epoch [80/400], Train Loss: 0.5297, Test Loss: 0.5283\n",
"Epoch [81/400], Train Loss: 0.5297, Test Loss: 0.5285\n",
"Epoch [82/400], Train Loss: 0.5295, Test Loss: 0.5286\n",
"Epoch [83/400], Train Loss: 0.5295, Test Loss: 0.5277\n",
"Epoch [84/400], Train Loss: 0.5294, Test Loss: 0.5300\n",
"Epoch [85/400], Train Loss: 0.5295, Test Loss: 0.5317\n",
"Epoch [86/400], Train Loss: 0.5292, Test Loss: 0.5288\n",
"Epoch [87/400], Train Loss: 0.5293, Test Loss: 0.5295\n",
"Epoch [88/400], Train Loss: 0.5291, Test Loss: 0.5273\n",
"Epoch [89/400], Train Loss: 0.5292, Test Loss: 0.5289\n",
"Epoch [90/400], Train Loss: 0.5292, Test Loss: 0.5273\n",
"Epoch [91/400], Train Loss: 0.5289, Test Loss: 0.5370\n",
"Epoch [92/400], Train Loss: 0.5288, Test Loss: 0.5263\n",
"Epoch [93/400], Train Loss: 0.5288, Test Loss: 0.5338\n",
"Epoch [94/400], Train Loss: 0.5287, Test Loss: 0.5326\n",
"Epoch [95/400], Train Loss: 0.5286, Test Loss: 0.5300\n",
"Epoch [96/400], Train Loss: 0.5286, Test Loss: 0.5280\n",
"Epoch [97/400], Train Loss: 0.5284, Test Loss: 0.5291\n",
"Epoch [98/400], Train Loss: 0.5284, Test Loss: 0.5310\n",
"Epoch [99/400], Train Loss: 0.5283, Test Loss: 0.5275\n",
"Epoch [100/400], Train Loss: 0.5280, Test Loss: 0.5294\n",
"Epoch [101/400], Train Loss: 0.5279, Test Loss: 0.5290\n",
"Epoch [102/400], Train Loss: 0.5273, Test Loss: 0.5264\n",
"Epoch [103/400], Train Loss: 0.5270, Test Loss: 0.5289\n",
"Epoch [104/400], Train Loss: 0.5265, Test Loss: 0.5294\n",
"Epoch [105/400], Train Loss: 0.5263, Test Loss: 0.5290\n",
"Epoch [106/400], Train Loss: 0.5261, Test Loss: 0.5302\n",
"Epoch [107/400], Train Loss: 0.5258, Test Loss: 0.5309\n",
"Epoch [108/400], Train Loss: 0.5258, Test Loss: 0.5254\n",
"Epoch [109/400], Train Loss: 0.5256, Test Loss: 0.5234\n",
"Epoch [110/400], Train Loss: 0.5255, Test Loss: 0.5307\n",
"Epoch [111/400], Train Loss: 0.5256, Test Loss: 0.5250\n",
"Epoch [112/400], Train Loss: 0.5252, Test Loss: 0.5300\n",
"Epoch [113/400], Train Loss: 0.5253, Test Loss: 0.5328\n",
"Epoch [114/400], Train Loss: 0.5252, Test Loss: 0.5347\n",
"Epoch [115/400], Train Loss: 0.5251, Test Loss: 0.5263\n",
"Epoch [116/400], Train Loss: 0.5250, Test Loss: 0.5312\n",
"Epoch [117/400], Train Loss: 0.5250, Test Loss: 0.5313\n",
"Epoch [118/400], Train Loss: 0.5248, Test Loss: 0.5291\n",
"Epoch [119/400], Train Loss: 0.5249, Test Loss: 0.5314\n",
"Epoch [120/400], Train Loss: 0.5249, Test Loss: 0.5246\n",
"Epoch [121/400], Train Loss: 0.5246, Test Loss: 0.5271\n",
"Epoch [122/400], Train Loss: 0.5244, Test Loss: 0.5286\n",
"Epoch [123/400], Train Loss: 0.5241, Test Loss: 0.5361\n",
"Epoch [124/400], Train Loss: 0.5240, Test Loss: 0.5229\n",
"Epoch [125/400], Train Loss: 0.5239, Test Loss: 0.5268\n",
"Epoch [126/400], Train Loss: 0.5239, Test Loss: 0.5233\n",
"Epoch [127/400], Train Loss: 0.5238, Test Loss: 0.5254\n",
"Epoch [128/400], Train Loss: 0.5236, Test Loss: 0.5271\n",
"Epoch [129/400], Train Loss: 0.5235, Test Loss: 0.5219\n",
"Epoch [130/400], Train Loss: 0.5234, Test Loss: 0.5273\n",
"Epoch [131/400], Train Loss: 0.5232, Test Loss: 0.5241\n",
"Epoch [132/400], Train Loss: 0.5230, Test Loss: 0.5234\n",
"Epoch [133/400], Train Loss: 0.5229, Test Loss: 0.5232\n",
"Epoch [134/400], Train Loss: 0.5229, Test Loss: 0.5288\n",
"Epoch [135/400], Train Loss: 0.5229, Test Loss: 0.5261\n",
"Epoch [136/400], Train Loss: 0.5230, Test Loss: 0.5271\n",
"Epoch [137/400], Train Loss: 0.5227, Test Loss: 0.5287\n",
"Epoch [138/400], Train Loss: 0.5228, Test Loss: 0.5216\n",
"Epoch [139/400], Train Loss: 0.5227, Test Loss: 0.5263\n",
"Epoch [140/400], Train Loss: 0.5224, Test Loss: 0.5274\n",
"Epoch [141/400], Train Loss: 0.5225, Test Loss: 0.5234\n",
"Epoch [142/400], Train Loss: 0.5226, Test Loss: 0.5251\n",
"Epoch [143/400], Train Loss: 0.5221, Test Loss: 0.5224\n",
"Epoch [144/400], Train Loss: 0.5223, Test Loss: 0.5222\n",
"Epoch [145/400], Train Loss: 0.5224, Test Loss: 0.5275\n",
"Epoch [146/400], Train Loss: 0.5223, Test Loss: 0.5203\n",
"Epoch [147/400], Train Loss: 0.5223, Test Loss: 0.5218\n",
"Epoch [148/400], Train Loss: 0.5222, Test Loss: 0.5256\n",
"Epoch [149/400], Train Loss: 0.5221, Test Loss: 0.5227\n",
"Epoch [150/400], Train Loss: 0.5219, Test Loss: 0.5210\n",
"Epoch [151/400], Train Loss: 0.5221, Test Loss: 0.5239\n",
"Epoch [152/400], Train Loss: 0.5221, Test Loss: 0.5218\n",
"Epoch [153/400], Train Loss: 0.5219, Test Loss: 0.5305\n",
"Epoch [154/400], Train Loss: 0.5219, Test Loss: 0.5248\n",
"Epoch [155/400], Train Loss: 0.5218, Test Loss: 0.5247\n",
"Epoch [156/400], Train Loss: 0.5218, Test Loss: 0.5222\n",
"Epoch [157/400], Train Loss: 0.5216, Test Loss: 0.5332\n",
"Epoch [158/400], Train Loss: 0.5217, Test Loss: 0.5230\n",
"Epoch [159/400], Train Loss: 0.5217, Test Loss: 0.5237\n",
"Epoch [160/400], Train Loss: 0.5216, Test Loss: 0.5205\n",
"Epoch [161/400], Train Loss: 0.5215, Test Loss: 0.5208\n",
"Epoch [162/400], Train Loss: 0.5216, Test Loss: 0.5242\n",
"Epoch [163/400], Train Loss: 0.5216, Test Loss: 0.5254\n",
"Epoch [164/400], Train Loss: 0.5214, Test Loss: 0.5229\n",
"Epoch [165/400], Train Loss: 0.5214, Test Loss: 0.5260\n",
"Epoch [166/400], Train Loss: 0.5213, Test Loss: 0.5193\n",
"Epoch [167/400], Train Loss: 0.5212, Test Loss: 0.5225\n",
"Epoch [168/400], Train Loss: 0.5211, Test Loss: 0.5240\n",
"Epoch [169/400], Train Loss: 0.5213, Test Loss: 0.5220\n",
"Epoch [170/400], Train Loss: 0.5213, Test Loss: 0.5276\n",
"Epoch [171/400], Train Loss: 0.5211, Test Loss: 0.5203\n",
"Epoch [172/400], Train Loss: 0.5214, Test Loss: 0.5202\n",
"Epoch [173/400], Train Loss: 0.5210, Test Loss: 0.5213\n",
"Epoch [174/400], Train Loss: 0.5212, Test Loss: 0.5215\n",
"Epoch [175/400], Train Loss: 0.5211, Test Loss: 0.5242\n",
"Epoch [176/400], Train Loss: 0.5210, Test Loss: 0.5217\n",
"Epoch [177/400], Train Loss: 0.5209, Test Loss: 0.5231\n",
"Epoch [178/400], Train Loss: 0.5210, Test Loss: 0.5225\n",
"Epoch [179/400], Train Loss: 0.5210, Test Loss: 0.5229\n",
"Epoch [180/400], Train Loss: 0.5208, Test Loss: 0.5235\n",
"Epoch [181/400], Train Loss: 0.5207, Test Loss: 0.5244\n",
"Epoch [182/400], Train Loss: 0.5208, Test Loss: 0.5224\n",
"Epoch [183/400], Train Loss: 0.5209, Test Loss: 0.5264\n",
"Epoch [184/400], Train Loss: 0.5207, Test Loss: 0.5220\n",
"Epoch [185/400], Train Loss: 0.5206, Test Loss: 0.5202\n",
"Epoch [186/400], Train Loss: 0.5208, Test Loss: 0.5187\n",
"Epoch [187/400], Train Loss: 0.5206, Test Loss: 0.5270\n",
"Epoch [188/400], Train Loss: 0.5207, Test Loss: 0.5196\n",
"Epoch [189/400], Train Loss: 0.5205, Test Loss: 0.5270\n",
"Epoch [190/400], Train Loss: 0.5207, Test Loss: 0.5241\n",
"Epoch [191/400], Train Loss: 0.5206, Test Loss: 0.5226\n",
"Epoch [192/400], Train Loss: 0.5205, Test Loss: 0.5289\n",
"Epoch [193/400], Train Loss: 0.5205, Test Loss: 0.5204\n",
"Epoch [194/400], Train Loss: 0.5204, Test Loss: 0.5215\n",
"Epoch [195/400], Train Loss: 0.5205, Test Loss: 0.5205\n",
"Epoch [196/400], Train Loss: 0.5204, Test Loss: 0.5236\n",
"Epoch [197/400], Train Loss: 0.5205, Test Loss: 0.5209\n",
"Epoch [198/400], Train Loss: 0.5202, Test Loss: 0.5225\n",
"Epoch [199/400], Train Loss: 0.5204, Test Loss: 0.5219\n",
"Epoch [200/400], Train Loss: 0.5203, Test Loss: 0.5217\n",
"Epoch [201/400], Train Loss: 0.5204, Test Loss: 0.5237\n",
"Epoch [202/400], Train Loss: 0.5201, Test Loss: 0.5186\n",
"Epoch [203/400], Train Loss: 0.5203, Test Loss: 0.5228\n",
"Epoch [204/400], Train Loss: 0.5202, Test Loss: 0.5213\n",
"Epoch [205/400], Train Loss: 0.5200, Test Loss: 0.5197\n",
"Epoch [206/400], Train Loss: 0.5202, Test Loss: 0.5209\n",
"Epoch [207/400], Train Loss: 0.5200, Test Loss: 0.5250\n",
"Epoch [208/400], Train Loss: 0.5203, Test Loss: 0.5183\n",
"Epoch [209/400], Train Loss: 0.5201, Test Loss: 0.5181\n",
"Epoch [210/400], Train Loss: 0.5200, Test Loss: 0.5235\n",
"Epoch [211/400], Train Loss: 0.5201, Test Loss: 0.5209\n",
"Epoch [212/400], Train Loss: 0.5200, Test Loss: 0.5203\n",
"Epoch [213/400], Train Loss: 0.5202, Test Loss: 0.5235\n",
"Epoch [214/400], Train Loss: 0.5201, Test Loss: 0.5184\n",
"Epoch [215/400], Train Loss: 0.5199, Test Loss: 0.5275\n",
"Epoch [216/400], Train Loss: 0.5199, Test Loss: 0.5200\n",
"Epoch [217/400], Train Loss: 0.5199, Test Loss: 0.5216\n",
"Epoch [218/400], Train Loss: 0.5199, Test Loss: 0.5230\n",
"Epoch [219/400], Train Loss: 0.5200, Test Loss: 0.5193\n",
"Epoch [220/400], Train Loss: 0.5199, Test Loss: 0.5217\n",
"Epoch [221/400], Train Loss: 0.5200, Test Loss: 0.5234\n",
"Epoch [222/400], Train Loss: 0.5197, Test Loss: 0.5226\n",
"Epoch [223/400], Train Loss: 0.5198, Test Loss: 0.5242\n",
"Epoch [224/400], Train Loss: 0.5198, Test Loss: 0.5226\n",
"Epoch [225/400], Train Loss: 0.5199, Test Loss: 0.5172\n",
"Epoch [226/400], Train Loss: 0.5197, Test Loss: 0.5206\n",
"Epoch [227/400], Train Loss: 0.5197, Test Loss: 0.5211\n",
"Epoch [228/400], Train Loss: 0.5197, Test Loss: 0.5199\n",
"Epoch [229/400], Train Loss: 0.5197, Test Loss: 0.5194\n",
"Epoch [230/400], Train Loss: 0.5197, Test Loss: 0.5212\n",
"Epoch [231/400], Train Loss: 0.5197, Test Loss: 0.5235\n",
"Epoch [232/400], Train Loss: 0.5199, Test Loss: 0.5180\n",
"Epoch [233/400], Train Loss: 0.5197, Test Loss: 0.5186\n",
"Epoch [234/400], Train Loss: 0.5198, Test Loss: 0.5192\n",
"Epoch [235/400], Train Loss: 0.5197, Test Loss: 0.5232\n",
"Epoch [236/400], Train Loss: 0.5195, Test Loss: 0.5267\n",
"Epoch [237/400], Train Loss: 0.5198, Test Loss: 0.5203\n",
"Epoch [238/400], Train Loss: 0.5196, Test Loss: 0.5195\n",
"Epoch [239/400], Train Loss: 0.5197, Test Loss: 0.5210\n",
"Epoch [240/400], Train Loss: 0.5196, Test Loss: 0.5273\n",
"Epoch [241/400], Train Loss: 0.5195, Test Loss: 0.5247\n",
"Epoch [242/400], Train Loss: 0.5197, Test Loss: 0.5184\n",
"Epoch [243/400], Train Loss: 0.5196, Test Loss: 0.5188\n",
"Epoch [244/400], Train Loss: 0.5198, Test Loss: 0.5201\n",
"Epoch [245/400], Train Loss: 0.5195, Test Loss: 0.5242\n",
"Epoch [246/400], Train Loss: 0.5196, Test Loss: 0.5204\n",
"Epoch [247/400], Train Loss: 0.5196, Test Loss: 0.5232\n",
"Epoch [248/400], Train Loss: 0.5194, Test Loss: 0.5268\n",
"Epoch [249/400], Train Loss: 0.5196, Test Loss: 0.5205\n",
"Epoch [250/400], Train Loss: 0.5195, Test Loss: 0.5255\n",
"Epoch [251/400], Train Loss: 0.5195, Test Loss: 0.5211\n",
"Epoch [252/400], Train Loss: 0.5195, Test Loss: 0.5200\n",
"Epoch [253/400], Train Loss: 0.5196, Test Loss: 0.5217\n",
"Epoch [254/400], Train Loss: 0.5196, Test Loss: 0.5208\n",
"Epoch [255/400], Train Loss: 0.5194, Test Loss: 0.5209\n",
"Epoch [256/400], Train Loss: 0.5195, Test Loss: 0.5252\n",
"Epoch [257/400], Train Loss: 0.5194, Test Loss: 0.5209\n",
"Epoch [258/400], Train Loss: 0.5195, Test Loss: 0.5247\n",
"Epoch [259/400], Train Loss: 0.5196, Test Loss: 0.5201\n",
"Epoch [260/400], Train Loss: 0.5192, Test Loss: 0.5191\n",
"Epoch [261/400], Train Loss: 0.5194, Test Loss: 0.5202\n",
"Epoch [262/400], Train Loss: 0.5194, Test Loss: 0.5209\n",
"Epoch [263/400], Train Loss: 0.5193, Test Loss: 0.5258\n",
"Epoch [264/400], Train Loss: 0.5194, Test Loss: 0.5226\n",
"Epoch [265/400], Train Loss: 0.5195, Test Loss: 0.5263\n",
"Epoch [266/400], Train Loss: 0.5195, Test Loss: 0.5223\n",
"Epoch [267/400], Train Loss: 0.5193, Test Loss: 0.5236\n",
"Epoch [268/400], Train Loss: 0.5192, Test Loss: 0.5274\n",
"Epoch [269/400], Train Loss: 0.5194, Test Loss: 0.5193\n",
"Epoch [270/400], Train Loss: 0.5191, Test Loss: 0.5189\n",
"Epoch [271/400], Train Loss: 0.5193, Test Loss: 0.5257\n",
"Epoch [272/400], Train Loss: 0.5193, Test Loss: 0.5191\n",
"Epoch [273/400], Train Loss: 0.5194, Test Loss: 0.5192\n",
"Epoch [274/400], Train Loss: 0.5193, Test Loss: 0.5215\n",
"Epoch [275/400], Train Loss: 0.5192, Test Loss: 0.5199\n",
"Epoch [276/400], Train Loss: 0.5193, Test Loss: 0.5231\n",
"Epoch [277/400], Train Loss: 0.5192, Test Loss: 0.5210\n",
"Epoch [278/400], Train Loss: 0.5192, Test Loss: 0.5203\n",
"Epoch [279/400], Train Loss: 0.5194, Test Loss: 0.5225\n",
"Epoch [280/400], Train Loss: 0.5193, Test Loss: 0.5242\n",
"Epoch [281/400], Train Loss: 0.5192, Test Loss: 0.5270\n",
"Epoch [282/400], Train Loss: 0.5192, Test Loss: 0.5226\n",
"Epoch [283/400], Train Loss: 0.5192, Test Loss: 0.5221\n",
"Epoch [284/400], Train Loss: 0.5193, Test Loss: 0.5171\n",
"Epoch [285/400], Train Loss: 0.5191, Test Loss: 0.5211\n",
"Epoch [286/400], Train Loss: 0.5191, Test Loss: 0.5178\n",
"Epoch [287/400], Train Loss: 0.5190, Test Loss: 0.5173\n",
"Epoch [288/400], Train Loss: 0.5192, Test Loss: 0.5277\n",
"Epoch [289/400], Train Loss: 0.5190, Test Loss: 0.5196\n",
"Epoch [290/400], Train Loss: 0.5192, Test Loss: 0.5200\n",
"Epoch [291/400], Train Loss: 0.5190, Test Loss: 0.5186\n",
"Epoch [292/400], Train Loss: 0.5192, Test Loss: 0.5211\n",
"Epoch [293/400], Train Loss: 0.5192, Test Loss: 0.5249\n",
"Epoch [294/400], Train Loss: 0.5191, Test Loss: 0.5196\n",
"Epoch [295/400], Train Loss: 0.5191, Test Loss: 0.5215\n",
"Epoch [296/400], Train Loss: 0.5192, Test Loss: 0.5223\n",
"Epoch [297/400], Train Loss: 0.5192, Test Loss: 0.5233\n",
"Epoch [298/400], Train Loss: 0.5191, Test Loss: 0.5223\n",
"Epoch [299/400], Train Loss: 0.5189, Test Loss: 0.5212\n",
"Epoch [300/400], Train Loss: 0.5189, Test Loss: 0.5199\n",
"Epoch [301/400], Train Loss: 0.5190, Test Loss: 0.5197\n",
"Epoch [302/400], Train Loss: 0.5190, Test Loss: 0.5289\n",
"Epoch [303/400], Train Loss: 0.5189, Test Loss: 0.5220\n",
"Epoch [304/400], Train Loss: 0.5190, Test Loss: 0.5296\n",
"Epoch [305/400], Train Loss: 0.5189, Test Loss: 0.5185\n",
"Epoch [306/400], Train Loss: 0.5190, Test Loss: 0.5191\n",
"Epoch [307/400], Train Loss: 0.5189, Test Loss: 0.5196\n",
"Epoch [308/400], Train Loss: 0.5190, Test Loss: 0.5204\n",
"Epoch [309/400], Train Loss: 0.5188, Test Loss: 0.5200\n",
"Epoch [310/400], Train Loss: 0.5190, Test Loss: 0.5254\n",
"Epoch [311/400], Train Loss: 0.5190, Test Loss: 0.5213\n",
"Epoch [312/400], Train Loss: 0.5189, Test Loss: 0.5206\n",
"Epoch [313/400], Train Loss: 0.5188, Test Loss: 0.5239\n",
"Epoch [314/400], Train Loss: 0.5189, Test Loss: 0.5198\n",
"Epoch [315/400], Train Loss: 0.5190, Test Loss: 0.5172\n",
"Epoch [316/400], Train Loss: 0.5187, Test Loss: 0.5199\n",
"Epoch [317/400], Train Loss: 0.5189, Test Loss: 0.5190\n",
"Epoch [318/400], Train Loss: 0.5188, Test Loss: 0.5195\n",
"Epoch [319/400], Train Loss: 0.5187, Test Loss: 0.5191\n",
"Epoch [320/400], Train Loss: 0.5188, Test Loss: 0.5213\n",
"Epoch [321/400], Train Loss: 0.5190, Test Loss: 0.5191\n",
"Epoch [322/400], Train Loss: 0.5188, Test Loss: 0.5250\n",
"Epoch [323/400], Train Loss: 0.5187, Test Loss: 0.5187\n",
"Epoch [324/400], Train Loss: 0.5188, Test Loss: 0.5219\n",
"Epoch [325/400], Train Loss: 0.5186, Test Loss: 0.5200\n",
"Epoch [326/400], Train Loss: 0.5187, Test Loss: 0.5229\n",
"Epoch [327/400], Train Loss: 0.5187, Test Loss: 0.5191\n",
"Epoch [328/400], Train Loss: 0.5187, Test Loss: 0.5193\n",
"Epoch [329/400], Train Loss: 0.5188, Test Loss: 0.5193\n",
"Epoch [330/400], Train Loss: 0.5187, Test Loss: 0.5209\n",
"Epoch [331/400], Train Loss: 0.5187, Test Loss: 0.5222\n",
"Epoch [332/400], Train Loss: 0.5187, Test Loss: 0.5207\n",
"Epoch [333/400], Train Loss: 0.5187, Test Loss: 0.5180\n",
"Epoch [334/400], Train Loss: 0.5186, Test Loss: 0.5229\n",
"Epoch [335/400], Train Loss: 0.5186, Test Loss: 0.5184\n",
"Epoch [336/400], Train Loss: 0.5187, Test Loss: 0.5188\n",
"Epoch [337/400], Train Loss: 0.5187, Test Loss: 0.5204\n",
"Epoch [338/400], Train Loss: 0.5185, Test Loss: 0.5292\n",
"Epoch [339/400], Train Loss: 0.5186, Test Loss: 0.5214\n",
"Epoch [340/400], Train Loss: 0.5187, Test Loss: 0.5210\n",
"Epoch [341/400], Train Loss: 0.5187, Test Loss: 0.5220\n",
"Epoch [342/400], Train Loss: 0.5186, Test Loss: 0.5202\n",
"Epoch [343/400], Train Loss: 0.5185, Test Loss: 0.5311\n",
"Epoch [344/400], Train Loss: 0.5186, Test Loss: 0.5209\n",
"Epoch [345/400], Train Loss: 0.5187, Test Loss: 0.5205\n",
"Epoch [346/400], Train Loss: 0.5185, Test Loss: 0.5178\n",
"Epoch [347/400], Train Loss: 0.5185, Test Loss: 0.5186\n",
"Epoch [348/400], Train Loss: 0.5186, Test Loss: 0.5228\n",
"Epoch [349/400], Train Loss: 0.5184, Test Loss: 0.5222\n",
"Epoch [350/400], Train Loss: 0.5186, Test Loss: 0.5246\n",
"Epoch [351/400], Train Loss: 0.5185, Test Loss: 0.5195\n",
"Epoch [352/400], Train Loss: 0.5185, Test Loss: 0.5189\n",
"Epoch [353/400], Train Loss: 0.5184, Test Loss: 0.5212\n",
"Epoch [354/400], Train Loss: 0.5186, Test Loss: 0.5267\n",
"Epoch [355/400], Train Loss: 0.5184, Test Loss: 0.5227\n",
"Epoch [356/400], Train Loss: 0.5183, Test Loss: 0.5213\n",
"Epoch [357/400], Train Loss: 0.5183, Test Loss: 0.5214\n",
"Epoch [358/400], Train Loss: 0.5183, Test Loss: 0.5228\n",
"Epoch [359/400], Train Loss: 0.5185, Test Loss: 0.5195\n",
"Epoch [360/400], Train Loss: 0.5183, Test Loss: 0.5234\n",
"Epoch [361/400], Train Loss: 0.5185, Test Loss: 0.5198\n",
"Epoch [362/400], Train Loss: 0.5183, Test Loss: 0.5193\n",
"Epoch [363/400], Train Loss: 0.5185, Test Loss: 0.5208\n",
"Epoch [364/400], Train Loss: 0.5185, Test Loss: 0.5217\n",
"Epoch [365/400], Train Loss: 0.5184, Test Loss: 0.5258\n",
"Epoch [366/400], Train Loss: 0.5185, Test Loss: 0.5199\n",
"Epoch [367/400], Train Loss: 0.5183, Test Loss: 0.5197\n",
"Epoch [368/400], Train Loss: 0.5182, Test Loss: 0.5189\n",
"Epoch [369/400], Train Loss: 0.5184, Test Loss: 0.5191\n",
"Epoch [370/400], Train Loss: 0.5183, Test Loss: 0.5198\n",
"Epoch [371/400], Train Loss: 0.5182, Test Loss: 0.5230\n",
"Epoch [372/400], Train Loss: 0.5183, Test Loss: 0.5178\n",
"Epoch [373/400], Train Loss: 0.5183, Test Loss: 0.5184\n",
"Epoch [374/400], Train Loss: 0.5182, Test Loss: 0.5187\n",
"Epoch [375/400], Train Loss: 0.5183, Test Loss: 0.5224\n",
"Epoch [376/400], Train Loss: 0.5184, Test Loss: 0.5203\n",
"Epoch [377/400], Train Loss: 0.5183, Test Loss: 0.5197\n",
"Epoch [378/400], Train Loss: 0.5182, Test Loss: 0.5196\n",
"Epoch [379/400], Train Loss: 0.5181, Test Loss: 0.5200\n",
"Epoch [380/400], Train Loss: 0.5182, Test Loss: 0.5264\n",
"Epoch [381/400], Train Loss: 0.5181, Test Loss: 0.5202\n",
"Epoch [382/400], Train Loss: 0.5182, Test Loss: 0.5219\n",
"Epoch [383/400], Train Loss: 0.5183, Test Loss: 0.5219\n",
"Epoch [384/400], Train Loss: 0.5183, Test Loss: 0.5196\n",
"Epoch [385/400], Train Loss: 0.5181, Test Loss: 0.5192\n",
"Epoch [386/400], Train Loss: 0.5180, Test Loss: 0.5206\n",
"Epoch [387/400], Train Loss: 0.5181, Test Loss: 0.5264\n",
"Epoch [388/400], Train Loss: 0.5181, Test Loss: 0.5191\n",
"Epoch [389/400], Train Loss: 0.5183, Test Loss: 0.5240\n",
"Epoch [390/400], Train Loss: 0.5182, Test Loss: 0.5196\n",
"Epoch [391/400], Train Loss: 0.5182, Test Loss: 0.5176\n",
"Epoch [392/400], Train Loss: 0.5181, Test Loss: 0.5219\n",
"Epoch [393/400], Train Loss: 0.5181, Test Loss: 0.5198\n",
"Epoch [394/400], Train Loss: 0.5180, Test Loss: 0.5251\n",
"Epoch [395/400], Train Loss: 0.5181, Test Loss: 0.5182\n",
"Epoch [396/400], Train Loss: 0.5182, Test Loss: 0.5212\n",
"Epoch [397/400], Train Loss: 0.5180, Test Loss: 0.5217\n",
"Epoch [398/400], Train Loss: 0.5181, Test Loss: 0.5213\n",
"Epoch [399/400], Train Loss: 0.5182, Test Loss: 0.5180\n",
"Epoch [400/400], Train Loss: 0.5180, Test Loss: 0.5205\n"
]
}
],
0571 "source": [
0572 "import torch\n",
0573 "import torch.nn as nn\n",
0574 "from torch.utils.data import TensorDataset, random_split, DataLoader\n",
0575 "from torch.optim import Adam\n",
0576 "\n",
0577 "# Set device\n",
0578 "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
0579 "print(f\"Using device: {device}\")\n",
0580 "\n",
0581 "# Create multi-class labels\n",
0582 "def create_multiclass_labels(t3_isFake, t3_sim_vxy, displacement_threshold=0.1):\n",
0583 " num_samples = len(t3_isFake)\n",
0584 " labels = torch.zeros((num_samples, 3))\n",
0585 " \n",
0586 " # Fake tracks (class 0)\n",
0587 " fake_mask = t3_isFake\n",
0588 " labels[fake_mask, 0] = 1\n",
0589 " \n",
0590 " # Real tracks\n",
0591 " real_mask = ~fake_mask\n",
0592 " \n",
0593 " # Split real tracks into prompt (class 1) and displaced (class 2)\n",
0594 " prompt_mask = (t3_sim_vxy <= displacement_threshold) & real_mask\n",
0595 " displaced_mask = (t3_sim_vxy > displacement_threshold) & real_mask\n",
0596 " \n",
0597 " labels[prompt_mask, 1] = 1\n",
0598 " labels[displaced_mask, 2] = 1\n",
0599 " \n",
0600 " return labels\n",
0601 "\n",
0602 "# Create labels tensor\n",
0603 "labels_tensor = create_multiclass_labels(\n",
0604 " t3_isFake_filtered,\n",
0605 " t3_sim_vxy_filtered\n",
0606 ")\n",
0607 "\n",
0608 "# Neural network for multi-class classification\n",
0609 "class MultiClassNeuralNetwork(nn.Module):\n",
0610 " def __init__(self):\n",
0611 " super(MultiClassNeuralNetwork, self).__init__()\n",
0612 " self.layer1 = nn.Linear(input_features_numpy.shape[1], 32)\n",
0613 " self.layer2 = nn.Linear(32, 32)\n",
0614 " self.output_layer = nn.Linear(32, 3)\n",
0615 " \n",
0616 " def forward(self, x):\n",
0617 " x = self.layer1(x)\n",
0618 " x = nn.ReLU()(x)\n",
0619 " x = self.layer2(x)\n",
0620 " x = nn.ReLU()(x)\n",
0621 " x = self.output_layer(x)\n",
0622 " return nn.functional.softmax(x, dim=1)\n",
0623 "\n",
0624 "# Weighted loss function for multi-class\n",
0625 "class WeightedCrossEntropyLoss(nn.Module):\n",
0626 " def __init__(self):\n",
0627 " super(WeightedCrossEntropyLoss, self).__init__()\n",
0628 " \n",
0629 " def forward(self, outputs, targets, weights):\n",
0630 " eps = 1e-7\n",
0631 " log_probs = torch.log(outputs + eps)\n",
0632 " losses = -weights * torch.sum(targets * log_probs, dim=1)\n",
0633 " return losses.mean()\n",
0634 "\n",
0635 "# Calculate class weights (each sample gets a weight to equalize class contributions)\n",
0636 "def calculate_class_weights(labels):\n",
0637 " class_counts = torch.sum(labels, dim=0)\n",
0638 " total_samples = len(labels)\n",
0639 " class_weights = total_samples / (3 * class_counts) # Normalize across 3 classes\n",
0640 " \n",
0641 " sample_weights = torch.zeros(len(labels))\n",
0642 " for i in range(3):\n",
0643 " sample_weights[labels[:, i] == 1] = class_weights[i]\n",
0644 " \n",
0645 " return sample_weights\n",
0646 "\n",
0647 "# Print initial dataset size\n",
0648 "print(f\"Initial dataset size: {len(labels_tensor)}\")\n",
0649 "\n",
0650 "# Remove rows with NaN and update everything accordingly\n",
0651 "nan_mask = torch.isnan(input_features_tensor).any(dim=1)\n",
0652 "filtered_inputs = input_features_tensor[~nan_mask]\n",
0653 "filtered_labels = labels_tensor[~nan_mask]\n",
0654 "\n",
0655 "# Print class distribution before downsampling\n",
0656 "class_counts_before = torch.sum(filtered_labels, dim=0)\n",
0657 "print(f\"Class distribution before downsampling - Fake: {class_counts_before[0]}, Prompt: {class_counts_before[1]}, Displaced: {class_counts_before[2]}\")\n",
0658 "\n",
0659 "# Option to downsample each class\n",
0660 "downsample_classes = True # Set to False to disable downsampling\n",
0661 "if downsample_classes:\n",
0662 " # Define downsampling ratios for each class:\n",
0663 " # For example, downsample fakes (class 0) to 20% and keep prompt (class 1) and displaced (class 2) at 100%\n",
0664 " downsample_ratios = {0: 0.2, 1: 1.0, 2: 1.0}\n",
0665 " indices_list = []\n",
0666 " for cls in range(3):\n",
0667 " # Find indices for the current class\n",
0668 " cls_mask = (filtered_labels[:, cls] == 1)\n",
0669 " cls_indices = torch.nonzero(cls_mask).squeeze()\n",
0670 " ratio = downsample_ratios.get(cls, 1.0)\n",
0671 " num_cls = cls_indices.numel()\n",
0672 " num_to_sample = int(num_cls * ratio)\n",
0673 " # Ensure at least one sample is kept if available\n",
0674 " if num_to_sample < 1 and num_cls > 0:\n",
0675 " num_to_sample = 1\n",
0676 " # Shuffle and select the desired number of samples\n",
0677 " cls_indices_shuffled = cls_indices[torch.randperm(num_cls)]\n",
0678 " sampled_cls_indices = cls_indices_shuffled[:num_to_sample]\n",
0679 " indices_list.append(sampled_cls_indices)\n",
0680 " \n",
0681 " # Combine the indices from all classes\n",
0682 " selected_indices = torch.cat(indices_list)\n",
0683 " filtered_inputs = filtered_inputs[selected_indices]\n",
0684 " filtered_labels = filtered_labels[selected_indices]\n",
0685 "\n",
0686 "# Print class distribution after downsampling\n",
0687 "class_counts_after = torch.sum(filtered_labels, dim=0)\n",
0688 "print(f\"Class distribution after downsampling - Fake: {class_counts_after[0]}, Prompt: {class_counts_after[1]}, Displaced: {class_counts_after[2]}\")\n",
0689 "\n",
0690 "# Recalculate sample weights after downsampling (equal weighting per class based on new counts)\n",
0691 "sample_weights = calculate_class_weights(filtered_labels)\n",
0692 "filtered_weights = sample_weights\n",
0693 "\n",
0694 "# Create dataset with weights\n",
0695 "dataset = TensorDataset(filtered_inputs, filtered_labels, filtered_weights)\n",
0696 "\n",
0697 "# Split into train and test sets\n",
0698 "train_size = int(0.8 * len(dataset))\n",
0699 "test_size = len(dataset) - train_size\n",
0700 "train_dataset, test_dataset = random_split(dataset, [train_size, test_size])\n",
0701 "\n",
0702 "# Create data loaders\n",
0703 "train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True, num_workers=10, pin_memory=True)\n",
0704 "test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False, num_workers=10, pin_memory=True)\n",
0705 "\n",
0706 "# Initialize model and optimizer\n",
0707 "model = MultiClassNeuralNetwork().to(device)\n",
0708 "loss_function = WeightedCrossEntropyLoss()\n",
0709 "optimizer = Adam(model.parameters(), lr=0.0025)\n",
0710 "\n",
0711 "def evaluate_loss(loader):\n",
0712 " model.eval()\n",
0713 " total_loss = 0\n",
0714 " num_batches = 0\n",
0715 " with torch.no_grad():\n",
0716 " for inputs, targets, weights in loader:\n",
0717 " inputs, targets, weights = inputs.to(device), targets.to(device), weights.to(device)\n",
0718 " outputs = model(inputs)\n",
0719 " loss = loss_function(outputs, targets, weights)\n",
0720 " total_loss += loss.item()\n",
0721 " num_batches += 1\n",
0722 " return total_loss / num_batches\n",
0723 "\n",
0724 "# Training loop\n",
0725 "num_epochs = 400\n",
0726 "train_loss_log = []\n",
0727 "test_loss_log = []\n",
0728 "\n",
0729 "for epoch in range(num_epochs):\n",
0730 " model.train()\n",
0731 " epoch_loss = 0\n",
0732 " num_batches = 0\n",
0733 " \n",
0734 " for inputs, targets, weights in train_loader:\n",
0735 " inputs, targets, weights = inputs.to(device), targets.to(device), weights.to(device)\n",
0736 " \n",
0737 " # Forward pass\n",
0738 " outputs = model(inputs)\n",
0739 " loss = loss_function(outputs, targets, weights)\n",
0740 " epoch_loss += loss.item()\n",
0741 " num_batches += 1\n",
0742 " \n",
0743 " # Backward and optimize\n",
0744 " optimizer.zero_grad()\n",
0745 " loss.backward()\n",
0746 " optimizer.step()\n",
0747 " \n",
0748 " # Calculate average losses\n",
0749 " train_loss = epoch_loss / num_batches\n",
0750 " test_loss = evaluate_loss(test_loader)\n",
0751 " \n",
0752 " train_loss_log.append(train_loss)\n",
0753 " test_loss_log.append(test_loss)\n",
0754 " \n",
0755 " print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')"
0756 ]
0757 },
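{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small sketch (added here; it assumes `matplotlib` is available in the environment) to visualize the `train_loss_log` / `test_loss_log` histories collected during training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative loss-curve plot from the logs kept during training.\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plt.figure(figsize=(8, 5))\n",
"plt.plot(train_loss_log, label='Train loss')\n",
"plt.plot(test_loss_log, label='Test loss')\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Weighted cross-entropy loss')\n",
"plt.legend()\n",
"plt.show()"
]
},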
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Baseline accuracy: 0.8611\n",
"\n",
"Feature importances:\n",
"Feature 0 importance: 0.0541\n",
"Feature 2 importance: 0.0480\n",
"Feature 5 importance: 0.0434\n",
"Feature 7 importance: 0.0242\n",
"Feature 6 importance: 0.0223\n",
"Feature 3 importance: 0.0206\n",
"Feature 11 importance: 0.0167\n",
"Feature 10 importance: 0.0148\n",
"Feature 13 importance: 0.0140\n",
"Feature 12 importance: 0.0128\n",
"Feature 9 importance: 0.0114\n",
"Feature 8 importance: 0.0046\n",
"Feature 4 importance: 0.0016\n",
"Feature 1 importance: 0.0000\n"
]
}
],
"source": [
"import torch\n",
"import numpy as np\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"# Convert tensors to NumPy for any manipulation outside of PyTorch\n",
"input_features_np = input_features_tensor.numpy()\n",
"labels_np = torch.argmax(labels_tensor, dim=1).numpy()  # Convert one-hot to class indices\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"def model_accuracy(features, labels, model):\n",
"    \"\"\"\n",
"    Compute accuracy for a multi-class classification model\n",
"    that outputs probabilities of size [batch_size, num_classes].\n",
"    \"\"\"\n",
"    model.eval()  # Set the model to evaluation mode\n",
"\n",
"    # Move the features and labels to the correct device\n",
"    inputs = features.to(device)\n",
"    labels = labels.to(device)\n",
"\n",
"    with torch.no_grad():\n",
"        outputs = model(inputs)  # shape: [batch_size, num_classes]\n",
"        # For multi-class, the predicted class is the argmax of the probabilities\n",
"        predicted = torch.argmax(outputs, dim=1)\n",
"        # Convert one-hot encoded labels to class indices if needed\n",
"        if len(labels.shape) > 1:\n",
"            labels = torch.argmax(labels, dim=1)\n",
"        # Compute mean accuracy\n",
"        accuracy = (predicted == labels).float().mean().item()\n",
"\n",
"    return accuracy\n",
"\n",
"# Compute baseline accuracy\n",
"baseline_accuracy = model_accuracy(input_features_tensor, labels_tensor, model)\n",
"print(f\"Baseline accuracy: {baseline_accuracy:.4f}\")\n",
"\n",
"# Initialize array to store feature importances\n",
"feature_importances = np.zeros(input_features_tensor.shape[1])\n",
"\n",
"# Iterate over each feature for permutation importance\n",
"for i in range(input_features_tensor.shape[1]):\n",
"    # Create a copy of the original features\n",
"    permuted_features = input_features_tensor.clone()\n",
"\n",
"    # Permute feature i across all examples\n",
"    # by shuffling the rows for that specific column\n",
"    permuted_features[:, i] = permuted_features[torch.randperm(permuted_features.size(0)), i]\n",
"\n",
"    # Compute accuracy after permutation\n",
"    permuted_accuracy = model_accuracy(permuted_features, labels_tensor, model)\n",
"\n",
"    # The drop in accuracy is used as a measure of feature importance\n",
"    feature_importances[i] = baseline_accuracy - permuted_accuracy\n",
"\n",
"# Sort features by descending importance\n",
"important_features_indices = np.argsort(feature_importances)[::-1]\n",
"important_features_scores = np.sort(feature_importances)[::-1]\n",
"\n",
"# Print out the results\n",
"print(\"\\nFeature importances:\")\n",
"for idx, score in zip(important_features_indices, important_features_scores):\n",
"    print(f\"Feature {idx} importance: {score:.4f}\")"
]
},
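{
"cell_type": "markdown",
"metadata": {},
"source": [
"The importance printout refers to features by column index. As a readability aid (added here; the names simply follow the order in which the feature matrix was stacked earlier), the ranking can be re-printed with named features."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Names in the stacking order of the feature matrix built above (illustrative).\n",
"feature_names = [\n",
"    'hit0_|eta|', 'hit0_|phi|', 'hit0_|z|', 'hit0_r',\n",
"    'd|eta|_20', 'dphi_20', 'd|z|_20', 'dr_20',\n",
"    'd|eta|_42', 'dphi_42', 'd|z|_42', 'dr_42',\n",
"    'log10_radius', 'betaIn',\n",
"]\n",
"for idx, score in zip(important_features_indices, important_features_scores):\n",
"    print(f\"{feature_names[idx]:>12s}: {score:.4f}\")"
]
},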
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_layer1[32] = {\n",
"-0.9152892f, 3.2650192f, -0.4164221f, -0.1210157f, -2.4165483f, -1.0984275f, -2.1654966f, -0.8991888f, -0.0503724f, 7.1305695f, -5.2781415f, 3.2997849f, 1.0025330f, -0.5117974f, 0.2957068f, -0.1811045f, -2.7853479f, 1.8040915f, -2.8807588f, -4.6462102f, 1.2869841f, -0.0526987f, 0.4946094f, 2.6554070f, -0.1360572f, 0.2122774f, 4.7361507f, -1.4605266f, 0.1759245f, -0.7966636f, -0.0401897f, -0.2652957f };\n",
"\n",
"ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_layer1[14][32] = {\n",
"{ 0.2570587f, 1.5017653f, -1.8436140f, 1.6314303f, -0.1464428f, 1.2261974f, 2.8629315f, -0.0778951f, -0.0007868f, -2.4665442f, 3.7231014f, -0.4062112f, 5.0222125f, -0.4256854f, -0.8145034f, -0.0993065f, 1.1874412f, 3.7737985f, -2.0898068f, 5.0041976f, -0.4184950f, 0.0133298f, -1.1757115f, 0.8953519f, -0.2589224f, 3.4567924f, -1.0867721f, -0.0325336f, -0.1398652f, 5.9361205f, -0.2938714f, 0.0110872f },\n",
"{ 0.0062326f, 0.0294117f, -0.1038531f, -0.1871421f, 0.0092176f, 0.0194613f, -0.0970159f, 0.0044040f, 0.0040717f, -0.0464808f, -0.0161564f, 0.0660082f, 0.0107912f, 0.0041196f, 0.0076985f, -0.0447547f, -0.0234053f, -0.0128952f, -0.0374210f, 0.0277440f, 0.0126392f, 0.0053390f, -0.0176450f, 0.0422363f, -0.1089868f, 0.0229251f, -0.0632515f, -0.0000745f, -0.0120581f, 0.0553841f, 0.1958316f, -0.2002713f },\n",
"{ -0.3374411f, 1.0331864f, -2.3923049f, -1.7857485f, 5.2973804f, -2.4907997f, -5.4545326f, -0.1601160f, -0.0028237f, 0.3510691f, 1.9307067f, 0.1516920f, -9.8718386f, -0.0893254f, -0.3546923f, 0.1746188f, -2.5511057f, -6.9032016f, 1.1837323f, -1.5620790f, 0.3867095f, -0.0118511f, 0.0633005f, -0.4638650f, -0.1623496f, -6.3706584f, -1.3387657f, -0.3824906f, 0.0013318f, -0.2500904f, 0.0434040f, 0.2572088f },\n",
"{ 0.4586262f, -5.7731452f, 2.1999521f, 1.8049341f, -0.2857220f, 0.9761068f, 7.4085426f, 0.6136439f, -0.0216335f, -2.1852305f, -0.8797144f, -0.0105609f, 2.9077356f, 0.6365089f, 0.4854219f, 0.1170259f, 1.9888018f, 1.5127803f, -1.5444536f, 8.0876036f, 0.2033372f, -0.0132441f, -0.8586000f, -1.9558586f, -0.0361535f, 4.7021170f, 1.0431648f, 2.0264521f, -0.2665041f, 6.2334776f, -0.0008584f, -0.0026126f },\n",
"{ -1.5857489f, 2.3446348f, -14.6416082f, 2.6467063f, 3.0982802f, 14.1958466f, 7.0268164f, -0.4456313f, 0.0005447f, -1.3794490f, -10.3501320f, 1.1612288f, 9.6774111f, -0.1875248f, 21.0353413f, -0.1961844f, 8.0234823f, -4.1653094f, -7.1429234f, 15.8104372f, 2.7506628f, -0.0142524f, 22.4932308f, 0.7682357f, 0.0385473f, 13.5405970f, -4.9976201f, -22.4438667f, -0.1693429f, -18.7507858f, -0.0939464f, 0.2192265f },\n",
"{ 65.2134705f, 1.4352289f, -0.6685436f, 22.9858570f, 1.6736605f, -0.1810708f, 0.5204540f, -53.2590408f, -0.0630155f, 3.6024182f, -3.8777969f, 5.0021510f, 0.0055030f, 23.8294449f, -1.3818942f, -0.2419317f, -10.2253504f, -1.8309352f, 1.7169305f, -0.3938941f, 9.1144180f, -0.0004920f, -3.1774588f, -36.4919891f, -0.1711030f, 6.5288949f, 4.5861993f, 0.8314257f, 0.0305954f, 5.1864023f, 0.2658210f, -0.1748345f },\n",
"{ -0.7423377f, -5.8733273f, 5.0070434f, -6.5473375f, 3.8788877f, 6.8001447f, 3.3014266f, -0.5657389f, -0.0376282f, 6.1230965f, 4.0765481f, 0.1596697f, -7.4254904f, 0.2356068f, 3.5560462f, -0.1223621f, 0.7022666f, -9.2908258f, -4.4684458f, -1.1861130f, 0.1879538f, -0.0337607f, 0.9228330f, 1.2672698f, -0.1690857f, -11.2556086f, 4.1028724f, 0.4850706f, 0.0041372f, 3.7200036f, 0.1445599f, 0.2449260f },\n",
"{ -2.3848257f, -1.7144806f, -0.2987937f, 10.1947727f, 0.7855392f, 7.1466241f, -2.2256043f, -1.2184218f, 0.1233135f, 5.3274498f, 2.9086673f, 1.5096599f, 2.8449385f, -0.2345320f, -2.2044909f, 0.1858539f, -0.7592235f, 3.1651254f, 0.5184333f, -3.7233777f, 1.5772278f, 0.0997663f, -0.0325775f, -13.2207737f, -0.0340279f, -6.4953661f, -18.9173355f, 0.9963044f, 0.1927230f, 4.6532283f, -0.0916147f, -0.0406466f },\n",
"{ 5.0975156f, -0.7078310f, 26.7917671f, -1.9278797f, -2.4459264f, -12.1174421f, -5.1347051f, 2.0090365f, -0.0012259f, -11.6201696f, 24.7306499f, -8.7715597f, 4.8136749f, -0.7106145f, -46.1458054f, 0.1771528f, 22.9087524f, -5.5876012f, 6.9944615f, 29.2786064f, -7.2195830f, 0.0270186f, 14.7860146f, 1.0168871f, 0.1467975f, -19.6260185f, 3.3284237f, 22.5860500f, -0.0160821f, -35.1570702f, 0.1473373f, 0.1500054f },\n",
"{ 53.6092033f, -0.7677985f, 3.3910768f, -27.1046524f, -5.1561117f, 1.6982658f, 0.9386115f, -72.5867996f, 57.8674583f, -8.2815771f, 3.2216380f, -28.9387760f, 3.2793560f, 37.3099365f, 2.2979465f, 0.1827718f, 6.8113675f, 1.0104842f, -1.5407079f, -2.5780137f, 32.4788666f, -61.8150520f, -2.6497467f, -10.6412830f, -0.2596186f, -11.8458385f, 27.5528336f, -1.3142428f, -0.2566442f, -17.8737431f, 0.0261727f, 0.0107839f },\n",
"{ 0.2419503f, -3.8581870f, 6.1144238f, 2.3472424f, 1.9470867f, -3.3741038f, 0.4852004f, 0.6366135f, -0.0736884f, 2.2963598f, -0.4113644f, -2.0738223f, -3.7331092f, 0.7157578f, 1.2168316f, -0.1584605f, -2.3843584f, -6.1547771f, -0.4764193f, 4.6278925f, -1.3195806f, -0.0717061f, 2.5889101f, -3.7769084f, 0.0527631f, 3.8808305f, -0.0672685f, 0.3294500f, -0.1916338f, -2.2346771f, -0.1518883f, 0.0462940f },\n",
"{ 0.7453216f, -0.5999010f, -2.6196895f, -6.5323448f, 0.0482639f, 0.0162446f, -0.1185504f, -2.6736362f, -0.0037108f, -4.2818441f, 4.1449761f, 3.3861248f, 0.1735097f, -4.7952204f, 0.8002076f, 0.0598137f, 0.2611136f, 2.4648571f, 1.3178054f, -12.6864462f, -1.4815618f, -0.0113616f, 0.0697572f, 11.0613089f, -0.1636935f, -1.3598750f, 5.4537063f, 0.4077346f, 0.2614925f, 0.4335817f, -0.1616396f, 0.0372773f },\n",
"{ 0.3573107f, -0.1230106f, -0.0517133f, -0.7743639f, 0.1088887f, -0.0315369f, -0.4702122f, 0.4170579f, 0.0149554f, -2.9016135f, 0.8456838f, -1.4586754f, -0.3096066f, 0.3871936f, -0.0811070f, -0.0972313f, 1.3539683f, -0.4489806f, 2.1791372f, -0.0245313f, -0.5678235f, 0.0153700f, 0.0444115f, -1.1144291f, -0.0992134f, -0.0615626f, -1.5467011f, 0.3384680f, -0.2377923f, -3.0146224f, -0.1680345f, -0.0683730f },\n",
"{ 17.0918045f, 0.3651829f, 0.4609411f, -8.8885517f, -1.3358241f, 0.3141162f, 0.5917582f, -3.1579833f, 18.4088402f, -0.7021288f, -0.1767638f, -20.7704201f, 1.0183893f, -12.4671431f, 0.0741675f, 0.2120477f, -0.4298171f, -0.3993036f, -0.4320501f, 0.1840025f, 35.6576691f, -19.6535892f, -1.1798559f, -4.8292837f, 0.1928347f, -2.1754487f, 5.7580199f, -0.4750445f, -0.0005913f, -3.2222869f, -0.0762974f, -0.0288493f },\n",
"};\n",
"\n",
"ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_layer2[32] = {\n",
"-2.9135964f, 2.8539734f, 0.1411822f, 1.3484665f, 4.0785451f, 3.1302166f, 0.6947064f, -6.6154590f, 0.5603381f, 1.1026498f, -0.0329598f, 5.5717101f, -4.4454126f, 0.8731447f, -0.0039664f, 3.3978446f, 0.2816379f, 1.0174516f, 5.7364573f, -0.2107503f, 1.5612308f, 3.0443959f, -0.5723908f, -2.2100310f, 3.2695763f, 1.2092905f, 0.3386045f, 2.9886062f, 1.6525602f, 2.2572896f, 1.5943902f, 3.8117774f };\n",
"\n",
"ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_layer2[32][32] = {\n",
"{ -34.7392120f, -3.6741450f, 0.9912871f, 2.2665858f, -4.2091088f, -5.7695718f, -0.7361242f, 1.9049225f, -2.0010724f, 0.5481573f, -0.1977228f, -8.8578167f, 0.2730211f, 0.2286723f, -0.2417704f, 3.0146859f, -4.6289725f, -1.9597307f, -3.8090327f, -0.0440762f, -0.1491523f, 7.2579861f, -1.3858653f, 5.7306480f, 1.3374176f, -58.5746918f, -0.7968796f, -7.3152990f, 0.8867928f, -1.0549910f, -1.5279640f, -0.3361593f },\n",
"{ 0.1808980f, -0.3280856f, 0.2590686f, -1.0191135f, -0.4207653f, -0.7095064f, 0.1365926f, 0.4096814f, -10.8506107f, 1.2207726f, -0.1571288f, -0.7078422f, 5.8424225f, 1.6238457f, -0.0862379f, -1.4228019f, 0.4024760f, 1.8243361f, -0.6990436f, 0.0060351f, -0.1591629f, -4.6918707f, -0.1512203f, 0.0915786f, -0.5897132f, -0.2044267f, -0.9305406f, -0.5472224f, -2.8864694f, 1.0543642f, 0.5347654f, -0.5453511f },\n",
"{ 0.0880480f, 0.4925799f, -2.0632777f, 0.7970945f, 0.1767638f, 2.1716123f, -1.6498836f, 2.0251210f, 3.4944487f, 4.9597678f, 0.0147583f, 1.4637951f, -1.9777663f, -0.4427343f, -0.2699612f, 0.2240291f, 1.4522513f, 4.8733187f, 1.2140924f, -0.1049215f, -0.6172752f, -1.0545660f, -0.0994265f, 0.0489118f, 0.1526600f, -0.5299308f, 1.2835877f, 1.1346543f, 3.8537612f, 0.0424094f, 0.6743001f, 3.6908419f },\n",
"{ 13.5742407f, -0.2112840f, -0.1391970f, -0.2532290f, -6.1313200f, -0.0358473f, -1.1182036f, -1.2637638f, -0.9513455f, -1.3301764f, -0.0799558f, -1.3724291f, -1.6102096f, -0.0269820f, -0.2672046f, -4.3722496f, 0.5367125f, -0.8932595f, 2.7968166f, -0.1038471f, 0.9354176f, 11.1879988f, 2.5031428f, -0.5350324f, 0.9403298f, -23.2240810f, -1.7667898f, -0.9982644f, 0.7523674f, 0.1651796f, 0.6430404f, 0.1910793f },\n",
"{ 0.2658246f, 0.7314553f, -0.1396945f, 0.8241134f, -0.0880722f, 0.5978789f, 0.6997685f, -0.6765627f, 2.1706738f, -0.8361687f, -0.0501303f, 0.3658123f, -6.4972825f, -4.6433911f, -0.0610007f, -1.3655313f, 0.8686107f, -0.5963267f, 0.0234031f, -0.2062150f, -0.2529819f, -8.4623156f, -0.4437916f, -1.9638975f, -0.3870836f, -0.5850767f, -4.6490083f, 0.6830147f, -1.2122021f, -2.5658479f, 0.4557288f, -0.0879869f },\n",
"{ -0.0108578f, 0.5193277f, 0.8980315f, 0.6727697f, -0.2764051f, 0.5845009f, 0.2578916f, 0.5809379f, 0.1509607f, 0.6870583f, -0.1276244f, -0.0574215f, -4.9936194f, 0.5157762f, -0.2816216f, -1.3928800f, 0.4075674f, 1.1707770f, 0.0359891f, -0.0638050f, -0.7958111f, -13.7506857f, 0.9970800f, 0.3291980f, -0.0093937f, -0.8198650f, 2.4517658f, 0.0728102f, 0.5025729f, 1.1547782f, -2.2815955f, 0.7980155f },\n",
"{ -39.5117416f, -1.9099823f, -0.5987001f, -0.1347036f, 0.2869921f, -6.9140315f, 1.1130731f, 1.5080519f, 2.2398031f, -0.0088237f, -0.0412525f, -6.1471343f, 2.8918502f, -0.3072237f, -0.1006490f, -1.7464429f, -0.8411779f, -42.2267303f, 0.7191703f, -0.1617323f, -0.2711441f, -0.8011407f, 0.5901257f, 1.3762408f, 0.3728105f, 0.5847842f, -0.7151731f, -5.5479650f, 0.4756111f, 0.1937995f, -0.2577415f, 1.4683112f },\n",
"{ 5.5796323f, -1.9758234f, -3.9855742f, -1.1933511f, -6.1229320f, -3.4628909f, 0.8108729f, 1.1214327f, 0.2925960f, -0.3177399f, -0.0674356f, -7.5650120f, -0.2327600f, -0.1809894f, -0.2001229f, -3.7608955f, -1.6911368f, -1.2986628f, -10.9993305f, -0.0333699f, 2.1555903f, 4.0816026f, -1.7761034f, 2.7839565f, 1.7967551f, -1.9129136f, 0.3337919f, -5.5511317f, -0.3778784f, -1.7885938f, 0.5358973f, 0.3659674f },\n",
"{ -0.4947478f, -1.1247686f, -0.5239698f, -0.2116186f, 0.9495451f, -1.7275617f, 3.2058396f, 5.8543334f, 1.1329687f, 2.0362375f, -0.0113123f, -3.6117337f, -2.6473372f, -101.7068253f, 0.0933132f, -9.5351706f, 2.7397399f, -6.7515664f, 6.8541646f, -0.1772990f, -8.0234919f, -11.7267113f, -79.2402649f, 1.0556825f, -1.2873408f, -204.0993347f, 2.9009213f, -1.5086528f, -0.3790632f, -16.6444530f, -0.7742234f, -2.5006552f },\n",
"{ -3.5879157f, 1.2350843f, 0.9988346f, 0.8656718f, -0.2288211f, 1.5041313f, 1.2635980f, -0.9575849f, 1.4612056f, -0.6694647f, -0.2607012f, 2.2113328f, -0.1756403f, 0.4053871f, -0.2502097f, -0.9793193f, 1.1052762f, 0.1973925f, -0.6196574f, -0.1616422f, -0.0878458f, -33.2044487f, -0.6067256f, -2.6461327f, -0.8652852f, 0.4787520f, 0.0987720f, 2.0803480f, -2.9416194f, 0.6944812f, -0.7692280f, -0.2723561f },\n",
"{ 0.9679711f, -0.8438777f, 1.6422297f, 0.1470843f, -0.0370339f, -1.7256715f, -3.5553441f, 0.9303662f, -3.7400429f, -0.0415591f, -0.0421884f, -0.4996702f, 0.2560562f, -1.1471199f, -0.1010781f, 2.9399052f, -2.7095137f, 0.1029265f, -0.9041355f, -0.0961986f, -1.1163988f, 2.2019930f, 1.1136653f, 2.6281602f, 0.1038788f, 0.8405868f, -1.7726818f, -1.4164388f, 2.5241690f, -0.4696356f, -0.6558685f, -0.3892841f },\n",
"{ 0.3616795f, -0.2340108f, 1.3857211f, -0.1612080f, 0.9073119f, 2.0731826f, -2.9476249f, -0.8346572f, -0.2136685f, -1.5312109f, -0.2896508f, 3.2546318f, -0.8377571f, -0.1631008f, -0.2456917f, 2.3952994f, 0.7205132f, 0.3009275f, -0.8928477f, 0.0297560f, -0.9989902f, -5.8525624f, -1.1516922f, -1.4669514f, -0.9353167f, 0.0053586f, -0.2916538f, 2.6023183f, -0.3268352f, 1.1982346f, 0.5936528f, 0.2427792f },\n",
"{ -8.4515934f, 1.6914611f, 1.8740683f, -3.0301368f, 1.1796944f, 1.6180993f, 1.5193824f, -1.3537951f, 1.8400136f, -2.8482721f, -0.1556847f, 1.9412404f, 4.7879019f, 0.7858805f, 0.0379962f, -2.2685356f, 2.4616067f, -3.1357207f, -3.1058917f, -0.0751680f, -6.3824587f, -6.4688945f, -0.4577174f, -4.3322401f, -5.3993649f, -0.0399034f, 0.6397164f, 2.3432400f, -3.4454954f, 0.4422087f, 2.4481769f, -2.0406303f },\n",
"{ 2.0977514f, 2.0378633f, -5.0659881f, -3.1632454f, -3.7596478f, 1.1832924f, 2.6347756f, -0.7441098f, 0.9281505f, 0.2330403f, -0.1830614f, -0.7371528f, 0.9002755f, 0.2897577f, 0.0216177f, -6.7740455f, 2.5610964f, 0.7834055f, -7.0665541f, -0.0497221f, 2.1334522f, -4.0962648f, -0.9172248f, -1.9944772f, 0.2762617f, -7.6493464f, 1.3044875f, 0.1891927f, -1.2570183f, -0.9203181f, 1.1330502f, 0.2542208f },\n",
"{ -3.2519467f, 0.1621506f, 1.2959863f, -0.4015787f, -0.1872185f, -1.3774366f, -0.6622453f, -0.9865818f, -0.7389784f, -2.5770009f, -0.1067092f, -1.8156646f, 8.1435080f, 0.1961913f, -0.2026473f, -0.5137898f, -0.3847715f, -2.7479975f, -0.5669084f, -0.0805389f, 0.8162143f, -89.8910904f, -0.0983714f, -0.6288267f, -0.1096447f, 0.3820810f, -0.5419686f, -1.2580637f, -0.1827327f, -0.7081568f, -0.0991779f, -1.8233751f },\n",
"{ -0.1124980f, -0.0672807f, -0.1147610f, 0.0497073f, 0.1218153f, 0.1264220f, 0.0282118f, 0.1016580f, 0.0696730f, -0.0004516f, 0.0812953f, 0.0875243f, 0.0982849f, 0.1660293f, -0.1481024f, -0.0315878f, -0.0644747f, 0.1398649f, 0.0835874f, -0.1440294f, -0.1390193f, -0.0628670f, -0.1517032f, -0.0325693f, -0.1094708f, 0.0963070f, 0.0056602f, -0.0197677f, -0.0068012f, 0.1578562f, -0.0302607f, 0.0684079f },\n",
"{ 0.2561692f, 0.6044238f, -1.0067526f, 3.0207021f, -0.5215638f, -1.5455573f, -2.4320843f, -0.2874290f, -5.5609145f, -2.5270512f, -0.1884816f, -1.4440522f, 1.1501647f, 1.2767099f, -0.2626259f, 1.5462712f, -0.3342260f, -1.4259053f, -0.1591775f, -0.1777169f, -0.8070273f, -0.6262614f, -0.1421982f, 0.4950019f, 0.3899588f, 1.1158458f, -1.8252333f, -1.4090738f, -0.9128270f, 1.2212859f, -0.5060611f, -1.5151979f },\n",
"{ 6.7118378f, -1.9413553f, -0.9765822f, 1.9195900f, -1.6302279f, -3.3607423f, -0.5215842f, 1.6841047f, -3.7323046f, 2.9130237f, -0.2912694f, -1.6349151f, -2.9017780f, 0.8473414f, -0.0895011f, 1.9765211f, -1.6982929f, 1.3711845f, -0.8770422f, -0.0966949f, -4.9838910f, -80.1754532f, 0.5617938f, 5.0638437f, -2.1738896f, 1.1080216f, -1.2562524f, -2.4832511f, 1.9475292f, -0.0768876f, -1.7405595f, 1.2659055f },\n",
"{ -9.3639793f, -0.7770892f, 0.5863696f, -2.4971461f, 0.9785317f, 0.6006397f, -0.7508336f, 1.4561496f, 3.3278019f, -1.3552047f, 0.0016642f, -1.0065420f, -0.9822129f, -0.1398876f, -0.0867651f, -2.2540245f, 1.3095651f, -0.4880162f, -0.9081959f, 0.0172203f, -0.9673725f, -6.6905494f, 1.8820195f, -0.3343536f, -0.9252734f, -0.4198346f, 3.2226353f, 0.0417345f, -0.2720280f, -0.4798162f, 0.8319970f, 0.6051134f },\n",
"{ 0.0971739f, -0.1656290f, -3.0162272f, 1.3674160f, -0.1130774f, -2.0347462f, 0.8287644f, 1.7546188f, -0.8183546f, -0.2517924f, -0.0338358f, -2.3002670f, -1.8175446f, -0.8929357f, 0.0145055f, 0.7129369f, -1.2497267f, -0.8694588f, 0.6608886f, 0.0472852f, 0.7463530f, 2.0970786f, -0.7406388f, 1.6379862f, 0.9597634f, 0.2664887f, -0.9611518f, -2.1256742f, -0.8108851f, -0.7670876f, -0.2143202f, -0.9296412f },\n",
"{ -5.5468359f, 1.8933475f, 0.5377558f, -0.8609743f, 1.6522382f, 3.7070215f, -1.3910099f, -1.7996042f, -1.2547578f, -0.3161051f, -0.1433857f, 5.9167895f, 0.2788936f, 0.6513762f, -0.1890229f, 1.3976518f, 2.5647829f, 1.7091097f, -2.4891980f, 0.0704016f, -1.2354076f, -7.9673457f, 0.5024176f, -3.1194675f, -1.8552034f, 1.4241451f, -0.5721908f, 4.6941881f, -0.4191083f, 1.5897839f, 0.5376836f, -0.3906031f },\n",
"{ -7.3998523f, -1.7208674f, 3.6660385f, 3.2399852f, 2.6726487f, -2.7743144f, 2.6148691f, 5.4286017f, -3.5616024f, 3.8747442f, -0.2854572f, -1.7255206f, 1.5527865f, -95.2269287f, -0.0130005f, -0.3984787f, -0.3650612f, -5.9493575f, 4.3472433f, 0.0598797f, -11.5429420f, -2.9780169f, -69.4482956f, 3.8544486f, -2.9926283f, 3.9207959f, 0.2614698f, -1.5368384f, 1.2026052f, -10.9552374f, -2.5336740f, -4.0654378f },\n",
0907 "{ -7.3998523f, -1.7208674f, 3.6660385f, 3.2399852f, 2.6726487f, -2.7743144f, 2.6148691f, 5.4286017f, -3.5616024f, 3.8747442f, -0.2854572f, -1.7255206f, 1.5527865f, -95.2269287f, -0.0130005f, -0.3984787f, -0.3650612f, -5.9493575f, 4.3472433f, 0.0598797f, -11.5429420f, -2.9780169f, -69.4482956f, 3.8544486f, -2.9926283f, 3.9207959f, 0.2614698f, -1.5368384f, 1.2026052f, -10.9552374f, -2.5336740f, -4.0654378f },\n",
0908 "{ 2.1074338f, -2.0316105f, -0.2519890f, 0.0232255f, -0.2889173f, -0.1693905f, 0.4285150f, 0.6449150f, 4.6293564f, 3.3936121f, 0.0660587f, 0.3134870f, -2.3245549f, -0.9685450f, 0.0889201f, 0.5934311f, -0.9143091f, 2.3384421f, -0.4089483f, -0.1643694f, 2.5919075f, 6.0844655f, 0.2091536f, 2.0565152f, 1.1226704f, -0.2695110f, 0.3927289f, -0.0457220f, 4.0436058f, -1.5475131f, -1.3438802f, 1.6676767f },\n",
0909 "{ 1.7787290f, 0.3969800f, 1.7834123f, 1.3779058f, -0.5738219f, -0.5349790f, -1.4947498f, -0.0787759f, 0.0341407f, 0.4346433f, -0.1981957f, -0.2886125f, -1.0133898f, -0.7178908f, -0.1872994f, 0.7944858f, -1.4787679f, -0.2754479f, -3.5224824f, -0.2090070f, -1.1161381f, -3.6711519f, -1.7022727f, 0.1558363f, 0.4152904f, -2.6727507f, 1.0731899f, -0.3006089f, 0.1950178f, -1.4062581f, -1.4458562f, 0.2443156f },\n",
0910 "{ -0.1570787f, 0.0073413f, -0.1335499f, -0.1712597f, 0.0029127f, 0.1628582f, -0.0816609f, -0.1307715f, -0.1621102f, -0.1200016f, -0.1394555f, 0.0157797f, -0.1572116f, 0.1745119f, -0.1128801f, -0.0566642f, 0.0099119f, -0.1222350f, 0.0299575f, -0.1031234f, -0.1048335f, 0.1707117f, -0.1490631f, -0.0835587f, -0.1712185f, -0.1278749f, 0.1462234f, -0.0081762f, -0.1106477f, -0.1645385f, 0.1268658f, 0.1686065f },\n",
0911 "{ -2.2590189f, -0.1024268f, -1.9020500f, -0.7051241f, 0.1037211f, 0.1701183f, 0.1889226f, 0.9506961f, 1.4137075f, 0.4496680f, -0.2055015f, -0.5990168f, -6.5227470f, -0.1113795f, -0.1070101f, 0.0105921f, -0.1653819f, 0.8838411f, -0.4713951f, 0.0250525f, 0.5694079f, -63.6874619f, 0.3740432f, 0.2925327f, 0.2328388f, -0.9265066f, 0.3290201f, -0.3581912f, 0.8044130f, -0.0143339f, 0.6609334f, -0.6653876f },\n",
0912 "{ 1.4302264f, 0.2180429f, 0.9684587f, 1.0369793f, 0.1597588f, -0.7066790f, -1.7150335f, 0.1960071f, -0.1694691f, 0.8381965f, -0.0181254f, -1.8366945f, -1.8840518f, -0.3109443f, -0.0058080f, 2.0794308f, -1.7936089f, -0.4478118f, -1.2889514f, -0.0300996f, -0.5915933f, -0.8868528f, 1.2223560f, 0.6542705f, 0.0814525f, -1.3704894f, -0.1875549f, -1.6079675f, -0.2744675f, 0.0382733f, -0.9821799f, -1.1006635f },\n",
0913 "{ -0.3911386f, 0.0468989f, 1.9009087f, -1.6725038f, 0.4506643f, -1.9519551f, 0.8855276f, -1.5861524f, 0.3190332f, -3.1724985f, -0.0278042f, -1.2427157f, 1.6820471f, 0.1633015f, -0.0449006f, -1.6101545f, -0.1007412f, -2.7659879f, -0.5162025f, -0.1431058f, 0.8236251f, -0.9194202f, -0.1490582f, -1.6231275f, -0.5467592f, 0.1333764f, -0.4865468f, -0.8269796f, -0.9018597f, 0.0288408f, -1.0994427f, -2.7987468f },\n",
0914 "{ 0.1278387f, -0.0134571f, 0.0448454f, -0.1556552f, -0.1247998f, -0.1196313f, -0.1611872f, -0.0630336f, -0.1410615f, 0.1682803f, 0.0263861f, 0.0619852f, 0.0423437f, 0.0982059f, 0.0784064f, 0.1412098f, 0.0331818f, -0.1537199f, -0.1165343f, -0.0441304f, 0.0197925f, -0.1256299f, 0.0694126f, -0.0137156f, 0.1587864f, 0.1131037f, -0.0722358f, 0.1287198f, -0.0683723f, -0.1212666f, 0.0847685f, 0.1469466f },\n",
0915 "{ -1.7708958f, 0.2500555f, -1.0356978f, -1.4820993f, -0.9565629f, 2.2127695f, 1.3409303f, -0.6528702f, 1.4306690f, 1.7529922f, 0.0491593f, 0.8595024f, -1.1016617f, -0.1608696f, -0.1200257f, -1.9610568f, 2.6189950f, 1.8707203f, 0.5241567f, -0.2288505f, 0.1528303f, -127.8296814f, -166.2449646f, -1.0174949f, -0.6682033f, -0.9169813f, 1.5756097f, 1.4574182f, -0.1463246f, -1.8262713f, 0.7517605f, -0.1181977f },\n",
0916 "{ -0.0324178f, -0.0418596f, -0.1287051f, -0.0232098f, -0.0512466f, -0.0905093f, -0.1104402f, -0.0095842f, 0.1413968f, -0.0081470f, -0.0251773f, 0.0667293f, 0.0344667f, 0.0116366f, -0.0908088f, -0.0980062f, 0.1874590f, -0.0381802f, 0.0684232f, 0.0252469f, -0.0681347f, 0.1034415f, 0.0576827f, -0.0557779f, 0.0868192f, -0.0851723f, -0.0868760f, 0.1192429f, 0.1751331f, -0.0323825f, -0.1238438f, -0.0623215f },\n",
0917 "{ 0.1757070f, -0.1212057f, -0.0878934f, 0.0737142f, 0.0712249f, -0.0818311f, -0.0719173f, -0.0561241f, 0.0630706f, -0.1523757f, -0.0048847f, 0.1597463f, -0.0302248f, -0.0096164f, -0.0259278f, -0.0815664f, -0.1283869f, 0.1644790f, -0.1612884f, 0.1505984f, -0.1614616f, -0.0756450f, -0.1680063f, -0.0716024f, -0.1266488f, 0.1165592f, -0.0066066f, 0.0661669f, 0.0148620f, 0.0464089f, 0.1496351f, -0.1720888f },\n",
0918 "};\n",
0919 "\n",
0920 "ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_output_layer[3] = {\n",
0921 "-0.3838706f, -0.0366794f, 0.5841699f };\n",
0922 "\n",
0923 "ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_output_layer[32][3] = {\n",
0924 "{ 0.6237589f, 0.2710748f, 0.5615537f },\n",
0925 "{ -0.1665458f, 0.3942705f, 0.2601272f },\n",
0926 "{ 0.3388835f, 0.1579971f, 0.0178280f },\n",
0927 "{ 0.5823844f, -0.0299621f, 0.1178701f },\n",
0928 "{ 0.5561634f, 0.1805784f, 0.6629463f },\n",
0929 "{ 0.1693098f, -0.8297758f, 0.1556239f },\n",
0930 "{ 0.0062806f, 0.2958559f, 0.2698825f },\n",
0931 "{ -0.3925241f, 0.1489681f, -0.0803940f },\n",
0932 "{ 0.5710047f, 0.1924859f, 0.2375189f },\n",
0933 "{ -0.0372825f, 0.0286687f, 0.2910011f },\n",
0934 "{ -0.0867018f, -0.1508995f, -0.0193411f },\n",
0935 "{ 0.4878173f, -0.9407690f, 0.3869846f },\n",
0936 "{ 0.9613981f, 0.3148000f, 0.2196945f },\n",
0937 "{ 0.5831478f, 1.2141191f, 0.7358299f },\n",
0938 "{ -0.0073579f, -0.0419888f, 0.0338354f },\n",
0939 "{ 0.2477632f, 0.9092489f, 0.7818094f },\n",
0940 "{ 0.3554717f, -0.4452990f, 0.0102171f },\n",
0941 "{ 0.3888267f, 0.7089493f, 0.3766315f },\n",
0942 "{ 0.8450955f, -0.0079020f, 0.5853269f },\n",
0943 "{ 0.0646952f, 0.0271975f, 0.0329916f },\n",
0944 "{ 0.5528679f, 0.0075829f, 0.2414524f },\n",
0945 "{ -1.3869698f, -1.1617719f, -1.1356672f },\n",
0946 "{ 0.0214099f, 0.3563140f, 0.5346315f },\n",
0947 "{ 0.3791857f, -0.2714695f, -0.0823861f },\n",
0948 "{ -0.3221727f, 0.5334318f, 0.1581419f },\n",
0949 "{ 0.6678535f, 0.6672282f, 0.4110478f },\n",
0950 "{ 0.1442596f, 0.0245941f, -0.1659890f },\n",
0951 "{ -0.9674007f, 1.4712439f, -0.8418093f },\n",
0952 "{ 0.5696401f, 0.2636259f, 0.2079044f },\n",
0953 "{ 0.0382360f, 0.2687068f, 0.4462553f },\n",
0954 "{ -0.0957586f, 0.4259349f, 0.3613387f },\n",
0955 "{ -0.0633585f, 0.4451550f, 0.2848748f },\n",
0956 "};\n",
0957 "\n"
0958 ]
0959 }
0960 ],
0961 "source": [
0962 "def print_formatted_weights_biases(weights, biases, layer_name):\n",
0963 " # Print biases\n",
0964 " print(f\"ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_{layer_name}[{len(biases)}] = {{\")\n",
0965 " print(\", \".join(f\"{b:.7f}f\" for b in biases) + \" };\")\n",
0966 " print()\n",
0967 "\n",
0968 " # Print weights\n",
0969 " print(f\"ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_{layer_name}[{len(weights[0])}][{len(weights)}] = {{\")\n",
0970 " for row in weights.T:\n",
0971 " formatted_row = \", \".join(f\"{w:.7f}f\" for w in row)\n",
0972 " print(f\"{{ {formatted_row} }},\")\n",
0973 " print(\"};\")\n",
0974 " print()\n",
0975 "\n",
0976 "def print_model_weights_biases(model):\n",
0977 " # Make sure the model is in evaluation mode\n",
0978 " model.eval()\n",
0979 "\n",
0980 " # Iterate through all named modules in the model\n",
0981 " for name, module in model.named_modules():\n",
0982 " # Check if the module is a linear layer\n",
0983 " if isinstance(module, nn.Linear):\n",
0984 " # Get weights and biases\n",
0985 " weights = module.weight.data.cpu().numpy()\n",
0986 " biases = module.bias.data.cpu().numpy()\n",
0987 "\n",
0988 " # Print formatted weights and biases\n",
0989 " print_formatted_weights_biases(weights, biases, name.replace('.', '_'))\n",
0990 "\n",
0991 "print_model_weights_biases(model)\n"
0992 ]
0993 },
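{
"cell_type": "markdown",
"metadata": {},
"source": [
"The arrays above are written transposed (`wgtT[n_inputs][n_outputs]`), matching a device-side loop that runs over input features. A minimal sanity check of that convention, assuming `model` and `nn` from the earlier cells: for any `nn.Linear`, `y = x @ W.T + b`, which is exactly what the exported `wgtT` and `bias` arrays encode."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"# Sketch of a check on the export convention: for nn.Linear, y = x @ W.T + b,\n",
"# so the transposed weights printed above reproduce each layer exactly.\n",
"with torch.no_grad():\n",
"    for name, module in model.named_modules():\n",
"        if isinstance(module, nn.Linear):\n",
"            W = module.weight.data.cpu()\n",
"            b = module.bias.data.cpu()\n",
"            x = torch.randn(4, W.shape[1], device=module.weight.device)\n",
"            assert torch.allclose(module(x).cpu(), x.cpu() @ W.T + b, atol=1e-5)\n",
"            print(f\"{name}: wgtT[{W.shape[1]}][{W.shape[0]}] verified\")"
]
},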
0994 {
0995 "cell_type": "code",
0996 "execution_count": 9,
0997 "metadata": {},
0998 "outputs": [],
0999 "source": [
1000 "# Ensure input_features_tensor is moved to the appropriate device\n",
1001 "input_features_tensor = input_features_tensor.to(device)\n",
1002 "\n",
1003 "# Make predictions\n",
1004 "with torch.no_grad():\n",
1005 " model.eval()\n",
1006 " outputs = model(input_features_tensor)\n",
1007 " predictions = outputs.squeeze().cpu().numpy()\n",
1008 "\n",
1009 "full_tracks = (np.concatenate(branches['t3_pMatched']) > 0.95)\n",
1010 "\n",
1011 "t3_pt = np.concatenate(branches['t3_radius']) * 2 * (2.99792458e-3 * 3.8) / 2"
1012 ]
1013 },
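{
"cell_type": "markdown",
"metadata": {},
"source": [
"The conversion above is the standard helix relation $p_T\\,[\\mathrm{GeV}] = 0.299792458 \\cdot B\\,[\\mathrm{T}] \\cdot R\\,[\\mathrm{m}]$; with the triplet circle radius in cm and the CMS field $B = 3.8\\,$T this reduces to $p_T = 2.99792458 \\times 10^{-3} \\cdot 3.8 \\cdot R$."
]
},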
1014 {
1015 "cell_type": "code",
1016 "execution_count": 10,
1017 "metadata": {},
1018 "outputs": [
1019 {
1020 "name": "stdout",
1021 "output_type": "stream",
1022 "text": [
1023 "Eta bin 0.00-0.25: 9409714 fakes, 313231 true Prompt\n",
1024 "Eta bin 0.25-0.50: 9242595 fakes, 323051 true Prompt\n",
1025 "Eta bin 0.50-0.75: 7849380 fakes, 410185 true Prompt\n",
1026 "Eta bin 0.75-1.00: 4293980 fakes, 322065 true Prompt\n",
1027 "Eta bin 1.00-1.25: 4343023 fakes, 374215 true Prompt\n",
1028 "Eta bin 1.25-1.50: 2725728 fakes, 351420 true Prompt\n",
1029 "Eta bin 1.50-1.75: 1368266 fakes, 425819 true Prompt\n",
1030 "Eta bin 1.75-2.00: 1413754 fakes, 467604 true Prompt\n",
1031 "Eta bin 2.00-2.25: 448439 fakes, 419450 true Prompt\n",
1032 "Eta bin 2.25-2.50: 124212 fakes, 247704 true Prompt\n"
1033 ]
1034 },
1035 {
1036 "data": {
1037 "text/plain": [
1038 "<Figure size 2000x800 with 3 Axes>"
1039 ]
1040 },
1041 "metadata": {},
1042 "output_type": "display_data"
1043 },
1044 {
1045 "name": "stdout",
1046 "output_type": "stream",
1047 "text": [
1048 "\n",
1049 "Prompt tracks, pt: 0.0 to 5.0 GeV\n",
1050 "Number of true prompt tracks: 3654744\n",
1051 "Number of fake tracks in pt bin: 41219091\n",
1052 "\n",
1053 "80% Retention Cut Values: {0.6326, 0.6415, 0.6506, 0.6526, 0.5321, 0.5651, 0.5747, 0.5765, 0.6243, 0.6320} Mean: 0.6082\n",
1054 "80% Cut Fake Rejections: {99.1, 99.1, 99.1, 99.2, 97.4, 97.1, 96.8, 96.8, 95.6, 92.4} Mean: 97.3%\n",
1055 "\n",
1056 "90% Retention Cut Values: {0.4957, 0.5052, 0.5201, 0.5340, 0.4275, 0.4708, 0.4890, 0.4932, 0.5400, 0.5449} Mean: 0.502\n",
1057 "90% Cut Fake Rejections: {98.5, 98.6, 98.6, 98.7, 95.5, 95.4, 95.0, 95.0, 93.1, 88.9} Mean: 95.7%\n",
1058 "\n",
1059 "93% Retention Cut Values: {0.3874, 0.3978, 0.4201, 0.4502, 0.3750, 0.4117, 0.4394, 0.4466, 0.4916, 0.4899} Mean: 0.431\n",
1060 "93% Cut Fake Rejections: {98.1, 98.2, 98.2, 98.3, 94.5, 94.3, 94.0, 94.1, 91.8, 86.8} Mean: 94.8%\n",
1061 "\n",
1062 "96% Retention Cut Values: {0.1332, 0.1616, 0.2094, 0.2898, 0.2909, 0.3109, 0.3561, 0.3626, 0.4016, 0.3775} Mean: 0.2894\n",
1063 "96% Cut Fake Rejections: {96.3, 96.7, 96.9, 97.1, 92.5, 92.4, 92.2, 92.4, 89.6, 82.9} Mean: 92.9%\n",
1064 "\n",
1065 "97% Retention Cut Values: {0.0894, 0.1137, 0.1595, 0.2333, 0.2507, 0.2569, 0.3081, 0.3120, 0.3466, 0.2932} Mean: 0.2363\n",
1066 "97% Cut Fake Rejections: {94.6, 95.4, 95.8, 95.9, 91.4, 91.1, 91.2, 91.4, 88.3, 79.7} Mean: 91.5%\n",
1067 "\n",
1068 "98% Retention Cut Values: {0.0595, 0.0742, 0.1129, 0.1736, 0.1919, 0.1824, 0.2337, 0.2285, 0.2529, 0.1681} Mean: 0.1678\n",
1069 "98% Cut Fake Rejections: {91.9, 93.1, 93.9, 94.0, 89.5, 89.0, 89.3, 89.5, 85.8, 73.7} Mean: 89.0%\n",
1070 "\n",
1071 "99% Retention Cut Values: {0.0319, 0.0371, 0.0493, 0.0943, 0.0923, 0.0767, 0.1061, 0.1041, 0.1124, 0.0762} Mean: 0.078\n",
1072 "99% Cut Fake Rejections: {86.5, 88.4, 88.8, 89.8, 84.8, 83.9, 84.3, 84.5, 79.6, 62.8} Mean: 83.3%\n",
1073 "\n",
1074 "99.5% Retention Cut Values: {0.0149, 0.0168, 0.0242, 0.0402, 0.0389, 0.0365, 0.0527, 0.0546, 0.0615, 0.0366} Mean: 0.0377\n",
1075 "99.5% Cut Fake Rejections: {78.7, 81.5, 83.3, 84.1, 79.2, 78.5, 78.6, 79.3, 73.6, 49.7} Mean: 76.7%\n",
1076 "Eta bin 0.00-0.25: 9409714 fakes, 43220 true Displaced\n",
1077 "Eta bin 0.25-0.50: 9242595 fakes, 47035 true Displaced\n",
1078 "Eta bin 0.50-0.75: 7849380 fakes, 62690 true Displaced\n",
1079 "Eta bin 0.75-1.00: 4293980 fakes, 52590 true Displaced\n",
1080 "Eta bin 1.00-1.25: 4343023 fakes, 62242 true Displaced\n",
1081 "Eta bin 1.25-1.50: 2725728 fakes, 59777 true Displaced\n",
1082 "Eta bin 1.50-1.75: 1368266 fakes, 76741 true Displaced\n",
1083 "Eta bin 1.75-2.00: 1413754 fakes, 90436 true Displaced\n",
1084 "Eta bin 2.00-2.25: 448439 fakes, 73564 true Displaced\n",
1085 "Eta bin 2.25-2.50: 124212 fakes, 43525 true Displaced\n"
1086 ]
1087 },
1088 {
1089 "data": {
1090 "text/plain": [
1091 "<Figure size 2000x800 with 3 Axes>"
1092 ]
1093 },
1094 "metadata": {},
1095 "output_type": "display_data"
1096 },
1097 {
1098 "name": "stdout",
1099 "output_type": "stream",
1100 "text": [
1101 "\n",
1102 "Displaced tracks, pt: 0.0 to 5.0 GeV\n",
1103 "Number of true displaced tracks: 611820\n",
1104 "Number of fake tracks in pt bin: 41219091\n",
1105 "\n",
1106 "80% Retention Cut Values: {0.2362, 0.2377, 0.2496, 0.2579, 0.2940, 0.3057, 0.3273, 0.3224, 0.2944, 0.2890} Mean: 0.2814\n",
1107 "80% Cut Fake Rejections: {83.1, 81.5, 78.7, 76.4, 75.6, 68.3, 54.0, 49.9, 34.7, 18.3} Mean: 62.1%\n",
1108 "\n",
1109 "90% Retention Cut Values: {0.1809, 0.1855, 0.1998, 0.2098, 0.2173, 0.2267, 0.2391, 0.2507, 0.2356, 0.2317} Mean: 0.2177\n",
1110 "90% Cut Fake Rejections: {79.7, 78.0, 75.4, 73.2, 70.1, 61.5, 46.5, 43.7, 29.3, 12.4} Mean: 57.0%\n",
1111 "\n",
1112 "93% Retention Cut Values: {0.1527, 0.1622, 0.1814, 0.1933, 0.1952, 0.1889, 0.2041, 0.2280, 0.2216, 0.2166} Mean: 0.1944\n",
1113 "93% Cut Fake Rejections: {77.6, 76.1, 74.0, 72.1, 68.4, 58.2, 43.5, 41.8, 28.0, 10.8} Mean: 55.0%\n",
1114 "\n",
1115 "96% Retention Cut Values: {0.1133, 0.1274, 0.1514, 0.1662, 0.1657, 0.1576, 0.1601, 0.2030, 0.2068, 0.2028} Mean: 0.1654\n",
1116 "96% Cut Fake Rejections: {73.9, 72.8, 71.5, 70.0, 65.9, 55.1, 39.2, 39.6, 26.6, 9.6} Mean: 52.4%\n",
1117 "\n",
1118 "97% Retention Cut Values: {0.0956, 0.1131, 0.1362, 0.1509, 0.1549, 0.1459, 0.1479, 0.1924, 0.2012, 0.1980} Mean: 0.1536\n",
1119 "97% Cut Fake Rejections: {71.7, 71.2, 70.0, 68.7, 65.0, 53.9, 37.8, 38.7, 26.1, 9.1} Mean: 51.2%\n",
1120 "\n",
1121 "98% Retention Cut Values: {0.0686, 0.0874, 0.1165, 0.1298, 0.1400, 0.1311, 0.1327, 0.1773, 0.1954, 0.1922} Mean: 0.1371\n",
1122 "98% Cut Fake Rejections: {67.4, 67.7, 68.0, 66.8, 63.6, 52.2, 36.0, 37.4, 25.6, 8.6} Mean: 49.3%\n",
1123 "\n",
1124 "99% Retention Cut Values: {0.0334, 0.0504, 0.0748, 0.0994, 0.1128, 0.1123, 0.1118, 0.1525, 0.1867, 0.1847} Mean: 0.1119\n",
1125 "99% Cut Fake Rejections: {57.9, 60.3, 62.3, 63.5, 60.8, 49.8, 33.2, 35.1, 24.8, 7.9} Mean: 45.6%\n",
1126 "\n",
1127 "99.5% Retention Cut Values: {0.0155, 0.0237, 0.0414, 0.0560, 0.0808, 0.0901, 0.0960, 0.1335, 0.1789, 0.1791} Mean: 0.0895\n",
1128 "99.5% Cut Fake Rejections: {47.6, 50.3, 54.7, 56.5, 56.7, 46.6, 30.8, 33.2, 24.1, 7.5} Mean: 40.8%\n",
1129 "Eta bin 0.00-0.25: 2012526 fakes, 6249 true Prompt\n",
1130 "Eta bin 0.25-0.50: 1972121 fakes, 6496 true Prompt\n",
1131 "Eta bin 0.50-0.75: 1704510 fakes, 6894 true Prompt\n",
1132 "Eta bin 0.75-1.00: 930629 fakes, 5318 true Prompt\n",
1133 "Eta bin 1.00-1.25: 861320 fakes, 9397 true Prompt\n",
1134 "Eta bin 1.25-1.50: 523329 fakes, 14695 true Prompt\n",
1135 "Eta bin 1.50-1.75: 246635 fakes, 24265 true Prompt\n",
1136 "Eta bin 1.75-2.00: 250585 fakes, 15787 true Prompt\n",
1137 "Eta bin 2.00-2.25: 86204 fakes, 6652 true Prompt\n",
1138 "Eta bin 2.25-2.50: 22080 fakes, 3385 true Prompt\n"
1139 ]
1140 },
1141 {
1142 "data": {
1143 "text/plain": [
1144 "<Figure size 2000x800 with 3 Axes>"
1145 ]
1146 },
1147 "metadata": {},
1148 "output_type": "display_data"
1149 },
1150 {
1151 "name": "stdout",
1152 "output_type": "stream",
1153 "text": [
1154 "\n",
1155 "Prompt tracks, pt: 5.0 to inf GeV\n",
1156 "Number of true prompt tracks: 99138\n",
1157 "Number of fake tracks in pt bin: 8609939\n",
1158 "\n",
1159 "80% Retention Cut Values: {0.1353, 0.1569, 0.2083, 0.2407, 0.2910, 0.3552, 0.4221, 0.4197, 0.3174, 0.2681} Mean: 0.2815\n",
1160 "80% Cut Fake Rejections: {98.8, 99.0, 99.3, 99.4, 96.1, 94.8, 93.5, 93.7, 88.5, 74.4} Mean: 93.7%\n",
1161 "\n",
1162 "90% Retention Cut Values: {0.0302, 0.0415, 0.0994, 0.1791, 0.1960, 0.2467, 0.3227, 0.3242, 0.2367, 0.2187} Mean: 0.1895\n",
1163 "90% Cut Fake Rejections: {90.2, 94.5, 98.1, 99.0, 93.4, 91.4, 89.8, 89.7, 80.1, 63.8} Mean: 89.0%\n",
1164 "\n",
1165 "93% Retention Cut Values: {0.0188, 0.0261, 0.0502, 0.1487, 0.1606, 0.2022, 0.2772, 0.2797, 0.2182, 0.2063} Mean: 0.1588\n",
1166 "93% Cut Fake Rejections: {82.0, 89.1, 94.9, 98.7, 91.8, 89.5, 87.9, 87.3, 77.5, 60.9} Mean: 85.9%\n",
1167 "\n",
1168 "96% Retention Cut Values: {0.0095, 0.0138, 0.0243, 0.0620, 0.1096, 0.1450, 0.2184, 0.2276, 0.1949, 0.1769} Mean: 0.1182\n",
1169 "96% Cut Fake Rejections: {67.7, 78.4, 86.6, 95.6, 88.5, 86.0, 84.6, 83.8, 74.2, 54.1} Mean: 79.9%\n",
1170 "\n",
1171 "97% Retention Cut Values: {0.0079, 0.0105, 0.0185, 0.0361, 0.0883, 0.1242, 0.1943, 0.2106, 0.1600, 0.1354} Mean: 0.0986\n",
1172 "97% Cut Fake Rejections: {63.7, 73.1, 82.1, 91.3, 86.5, 84.3, 83.0, 82.5, 69.0, 42.8} Mean: 75.8%\n",
1173 "\n",
1174 "98% Retention Cut Values: {0.0057, 0.0077, 0.0119, 0.0212, 0.0684, 0.0902, 0.1624, 0.1814, 0.1158, 0.1037} Mean: 0.0769\n",
1175 "98% Cut Fake Rejections: {56.9, 66.9, 74.3, 84.8, 84.0, 80.5, 80.6, 79.9, 61.0, 32.8} Mean: 70.2%\n",
1176 "\n",
1177 "99% Retention Cut Values: {0.0027, 0.0045, 0.0074, 0.0119, 0.0325, 0.0393, 0.1192, 0.1170, 0.0683, 0.0819} Mean: 0.0485\n",
1178 "99% Cut Fake Rejections: {43.1, 56.7, 65.6, 76.3, 75.4, 69.5, 76.2, 72.6, 49.8, 25.6} Mean: 61.1%\n",
1179 "\n",
1180 "99.5% Retention Cut Values: {0.0015, 0.0027, 0.0038, 0.0054, 0.0153, 0.0219, 0.0820, 0.0657, 0.0373, 0.0669} Mean: 0.0302\n",
1181 "99.5% Cut Fake Rejections: {34.3, 48.3, 54.4, 63.1, 64.6, 60.9, 70.6, 63.0, 39.3, 20.4} Mean: 51.9%\n",
1182 "Eta bin 0.00-0.25: 2012526 fakes, 2764 true Displaced\n",
1183 "Eta bin 0.25-0.50: 1972121 fakes, 2581 true Displaced\n",
1184 "Eta bin 0.50-0.75: 1704510 fakes, 2477 true Displaced\n",
1185 "Eta bin 0.75-1.00: 930629 fakes, 2122 true Displaced\n",
1186 "Eta bin 1.00-1.25: 861320 fakes, 2780 true Displaced\n",
1187 "Eta bin 1.25-1.50: 523329 fakes, 3481 true Displaced\n",
1188 "Eta bin 1.50-1.75: 246635 fakes, 4701 true Displaced\n",
1189 "Eta bin 1.75-2.00: 250585 fakes, 3009 true Displaced\n",
1190 "Eta bin 2.00-2.25: 86204 fakes, 1579 true Displaced\n",
1191 "Eta bin 2.25-2.50: 22080 fakes, 881 true Displaced\n"
1192 ]
1193 },
1194 {
1195 "data": {
1196 "text/plain": [
1197 "<Figure size 2000x800 with 3 Axes>"
1198 ]
1199 },
1200 "metadata": {},
1201 "output_type": "display_data"
1202 },
1203 {
1204 "name": "stdout",
1205 "output_type": "stream",
1206 "text": [
1207 "\n",
1208 "Displaced tracks, pt: 5.0 to inf GeV\n",
1209 "Number of true displaced tracks: 26375\n",
1210 "Number of fake tracks in pt bin: 8609939\n",
1211 "\n",
1212 "80% Retention Cut Values: {0.5019, 0.4954, 0.5197, 0.5256, 0.4161, 0.3509, 0.3778, 0.3667, 0.4388, 0.4915} Mean: 0.4484\n",
1213 "80% Cut Fake Rejections: {98.6, 98.6, 98.7, 98.7, 97.0, 93.3, 87.6, 86.9, 89.0, 88.0} Mean: 93.6%\n",
1214 "\n",
1215 "90% Retention Cut Values: {0.3265, 0.3967, 0.4493, 0.4656, 0.3459, 0.2910, 0.3486, 0.3356, 0.3693, 0.4448} Mean: 0.3773\n",
1216 "90% Cut Fake Rejections: {97.6, 98.1, 98.2, 98.3, 95.8, 90.8, 85.4, 84.0, 83.3, 82.4} Mean: 91.4%\n",
1217 "\n",
1218 "93% Retention Cut Values: {0.1804, 0.2502, 0.4033, 0.4375, 0.2958, 0.2556, 0.3353, 0.3242, 0.3405, 0.4194} Mean: 0.3242\n",
1219 "93% Cut Fake Rejections: {96.1, 97.0, 97.9, 98.2, 94.6, 89.1, 84.5, 83.1, 80.3, 78.8} Mean: 90.0%\n",
1220 "\n",
1221 "96% Retention Cut Values: {0.0464, 0.0366, 0.2521, 0.3317, 0.2324, 0.2021, 0.3165, 0.2992, 0.3096, 0.3568} Mean: 0.2383\n",
1222 "96% Cut Fake Rejections: {90.2, 88.7, 96.8, 97.5, 92.9, 86.3, 83.3, 81.2, 77.1, 69.3} Mean: 86.3%\n",
1223 "\n",
1224 "97% Retention Cut Values: {0.0270, 0.0225, 0.1590, 0.1765, 0.2111, 0.1570, 0.3039, 0.2869, 0.2999, 0.2103} Mean: 0.1854\n",
1225 "97% Cut Fake Rejections: {85.2, 83.8, 95.5, 96.0, 92.1, 83.3, 82.6, 80.4, 76.2, 43.0} Mean: 81.8%\n",
1226 "\n",
1227 "98% Retention Cut Values: {0.0168, 0.0160, 0.0687, 0.0392, 0.1494, 0.1121, 0.2684, 0.2642, 0.2935, 0.1370} Mean: 0.1365\n",
1228 "98% Cut Fake Rejections: {79.1, 79.4, 92.1, 88.6, 89.4, 79.3, 80.6, 78.7, 75.5, 26.5} Mean: 76.9%\n",
1229 "\n",
1230 "99% Retention Cut Values: {0.0091, 0.0075, 0.0350, 0.0213, 0.0435, 0.0676, 0.1957, 0.1649, 0.1080, 0.1046} Mean: 0.0757\n",
1231 "99% Cut Fake Rejections: {69.6, 67.9, 87.1, 82.3, 77.4, 73.2, 76.2, 70.3, 50.1, 19.1} Mean: 67.3%\n",
1232 "\n",
1233 "99.5% Retention Cut Values: {0.0059, 0.0021, 0.0151, 0.0116, 0.0222, 0.0309, 0.1230, 0.1205, 0.0779, 0.0930} Mean: 0.0502\n",
1234 "99.5% Cut Fake Rejections: {62.0, 47.0, 76.9, 74.2, 69.0, 62.6, 69.7, 64.8, 43.9, 16.5} Mean: 58.7%\n"
1235 ]
1236 }
1237 ],
1238 "source": [
1239 "import numpy as np\n",
1240 "from matplotlib import pyplot as plt\n",
1241 "from matplotlib.colors import LogNorm\n",
1242 "import torch\n",
1243 "\n",
1244 "# Ensure input_features_tensor is on the right device\n",
1245 "input_features_tensor = input_features_tensor.to(device)\n",
1246 "\n",
1247 "# Get model predictions\n",
1248 "with torch.no_grad():\n",
1249 " model.eval()\n",
1250 " outputs = model(input_features_tensor)\n",
1251 " predictions = outputs.cpu().numpy() # Shape will be [n_samples, 3]\n",
1252 "\n",
1253 "# Get track information\n",
1254 "t3_pt = np.concatenate(branches['t3_radius']) * 2 * (2.99792458e-3 * 3.8) / 2\n",
1255 "\n",
1256 "def plot_for_pt_bin(pt_min, pt_max, percentiles, eta_bin_edges, t3_pt, predictions, t3_sim_vxy, eta_list):\n",
1257 " \"\"\"\n",
1258 " Calculate and plot cut values for specified percentiles in a given pt bin, separately for prompt and displaced tracks\n",
1259 " \"\"\"\n",
1260 " # Filter data based on pt bin\n",
1261 " pt_mask = (t3_pt > pt_min) & (t3_pt <= pt_max)\n",
1262 " \n",
1263 " # Get absolute eta values for all tracks in pt bin\n",
1264 " abs_eta = np.abs(eta_list[0][pt_mask])\n",
1265 " \n",
1266 " # Get predictions for all tracks in pt bin\n",
1267 " pred_filtered = predictions[pt_mask]\n",
1268 " \n",
1269 " # Get track types using pMatched and t3_sim_vxy\n",
1270 " matched = (np.concatenate(branches['t3_pMatched']) > 0.95)[pt_mask]\n",
1271 " fake_tracks = (np.concatenate(branches['t3_pMatched']) < 0.75)[pt_mask]\n",
1272 " true_displaced = (t3_sim_vxy[pt_mask] > 0.1) & matched\n",
1273 " true_prompt = ~(t3_sim_vxy[pt_mask] > 0.1) & matched\n",
1274 " \n",
1275 " # Separate plots for prompt and displaced tracks\n",
1276 " for track_type, true_mask, pred_idx, title_suffix in [\n",
1277 " (\"Prompt\", true_prompt, 1, \"Prompt Real Tracks\"),\n",
1278 " (\"Displaced\", true_displaced, 2, \"Displaced Real Tracks\")\n",
1279 " ]:\n",
1280 " # Dictionaries to store values\n",
1281 " cut_values = {p: [] for p in percentiles}\n",
1282 " fake_rejections = {p: [] for p in percentiles}\n",
1283 " \n",
1284 " # Get probabilities for this class\n",
1285 " probs = pred_filtered[:, pred_idx]\n",
1286 " \n",
1287 " # Create two side-by-side plots\n",
1288 " fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))\n",
1289 " \n",
1290 " # Plot probability distribution (only for true tracks of this type)\n",
1291 " h = ax1.hist2d(abs_eta[true_mask], \n",
1292 " probs[true_mask], \n",
1293 " bins=[eta_bin_edges, 50], \n",
1294 " norm=LogNorm())\n",
1295 " plt.colorbar(h[3], ax=ax1, label='Counts')\n",
1296 " \n",
1297 " # For each eta bin\n",
1298 " bin_centers = []\n",
1299 " for i in range(len(eta_bin_edges) - 1):\n",
1300 " eta_min, eta_max = eta_bin_edges[i], eta_bin_edges[i+1]\n",
1301 " bin_center = (eta_min + eta_max) / 2\n",
1302 " bin_centers.append(bin_center)\n",
1303 " \n",
1304 " # Get tracks in this eta bin\n",
1305 " eta_mask = (abs_eta >= eta_min) & (abs_eta < eta_max)\n",
1306 " \n",
1307 " # True tracks of this type in this bin\n",
1308 " true_type_mask = eta_mask & true_mask\n",
1309 " # Fake tracks in this bin\n",
1310 " fake_mask = eta_mask & fake_tracks\n",
1311 " \n",
1312 " print(f\"Eta bin {eta_min:.2f}-{eta_max:.2f}: {np.sum(fake_mask)} fakes, {np.sum(true_type_mask)} true {track_type}\")\n",
1313 " \n",
1314 " if np.sum(true_type_mask) > 0: # If we have true tracks in this bin\n",
1315 " for percentile in percentiles:\n",
1316 " # Calculate cut value to keep desired percentage of true tracks\n",
1317 " cut_value = np.percentile(probs[true_type_mask], 100 - percentile)\n",
1318 " cut_values[percentile].append(cut_value)\n",
1319 " \n",
1320 " # Calculate fake rejection for this cut\n",
1321 " if np.sum(fake_mask) > 0:\n",
1322 " fake_rej = 100 * np.mean(probs[fake_mask] < cut_value)\n",
1323 " fake_rejections[percentile].append(fake_rej)\n",
1324 " else:\n",
1325 " fake_rejections[percentile].append(np.nan)\n",
1326 " else:\n",
1327 " for percentile in percentiles:\n",
1328 " cut_values[percentile].append(np.nan)\n",
1329 " fake_rejections[percentile].append(np.nan)\n",
1330 " \n",
1331 " # Plot cut values and fake rejections\n",
1332 " colors = plt.cm.rainbow(np.linspace(0, 1, len(percentiles)))\n",
1333 " bin_centers = np.array(bin_centers)\n",
1334 " \n",
1335 " for (percentile, color) in zip(percentiles, colors):\n",
1336 " values = np.array(cut_values[percentile])\n",
1337 " mask = ~np.isnan(values)\n",
1338 " if np.any(mask):\n",
1339 " # Plot cut values\n",
1340 " ax1.plot(bin_centers[mask], values[mask], '-', color=color, marker='o',\n",
1341 " label=f'{percentile}% Retention Cut')\n",
1342 " # Plot fake rejections\n",
1343 " rej_values = np.array(fake_rejections[percentile])\n",
1344 " ax2.plot(bin_centers[mask], rej_values[mask], '-', color=color, marker='o',\n",
1345 " label=f'{percentile}% Cut')\n",
1346 " \n",
1347 " # Set plot labels and titles\n",
1348 " ax1.set_xlabel(\"Absolute Eta\")\n",
1349 " ax1.set_ylabel(f\"DNN {track_type} Probability\")\n",
1350 " ax1.set_title(f\"DNN Score vs Eta ({title_suffix})\\npt: {pt_min:.1f} to {pt_max:.1f} GeV\")\n",
1351 " ax1.legend()\n",
1352 " ax1.grid(True, alpha=0.3)\n",
1353 " \n",
1354 " ax2.set_xlabel(\"Absolute Eta\")\n",
1355 " ax2.set_ylabel(\"Fake Rejection (%)\")\n",
1356 " ax2.set_title(f\"Fake Rejection vs Eta\\npt: {pt_min:.1f} to {pt_max:.1f} GeV\")\n",
1357 " ax2.legend()\n",
1358 " ax2.grid(True, alpha=0.3)\n",
1359 " ax2.set_ylim(0, 100)\n",
1360 " \n",
1361 " plt.tight_layout()\n",
1362 " plt.show()\n",
1363 " \n",
1364 " # Print statistics\n",
1365 " print(f\"\\n{track_type} tracks, pt: {pt_min:.1f} to {pt_max:.1f} GeV\")\n",
1366 " print(f\"Number of true {track_type.lower()} tracks: {np.sum(true_mask)}\")\n",
1367 " print(f\"Number of fake tracks in pt bin: {np.sum(fake_tracks)}\")\n",
1368 " \n",
1369 " for percentile in percentiles:\n",
1370 " print(f\"\\n{percentile}% Retention Cut Values:\",\n",
1371 " '{' + ', '.join(f\"{x:.4f}\" if not np.isnan(x) else 'nan' for x in cut_values[percentile]) + '}',\n",
1372 " f\"Mean: {np.round(np.nanmean(cut_values[percentile]), 4)}\")\n",
1373 " print(f\"{percentile}% Cut Fake Rejections:\",\n",
1374 " '{' + ', '.join(f\"{x:.1f}\" if not np.isnan(x) else 'nan' for x in fake_rejections[percentile]) + '}',\n",
1375 " f\"Mean: {np.round(np.nanmean(fake_rejections[percentile]), 1)}%\")\n",
1376 "\n",
1377 "def analyze_pt_bins(pt_bins, percentiles, eta_bin_edges, t3_pt, predictions, t3_sim_vxy, eta_list):\n",
1378 " \"\"\"\n",
1379 " Analyze and plot for multiple pt bins and percentiles\n",
1380 " \"\"\"\n",
1381 " for i in range(len(pt_bins) - 1):\n",
1382 " plot_for_pt_bin(pt_bins[i], pt_bins[i + 1], percentiles, eta_bin_edges,\n",
1383 " t3_pt, predictions, t3_sim_vxy, eta_list)\n",
1384 "\n",
1385 "# Run the analysis with same parameters as before\n",
1386 "percentiles = [80, 90, 93, 96, 97, 98, 99, 99.5]\n",
1387 "pt_bins = [0, 5, np.inf]\n",
1388 "eta_bin_edges = np.arange(0, 2.75, 0.25)\n",
1389 "\n",
1390 "analyze_pt_bins(\n",
1391 " pt_bins=pt_bins,\n",
1392 " percentiles=percentiles,\n",
1393 " eta_bin_edges=eta_bin_edges,\n",
1394 " t3_pt=t3_pt,\n",
1395 " predictions=predictions,\n",
1396 " t3_sim_vxy=np.concatenate(branches['t3_sim_vxy']),\n",
1397 " eta_list=eta_list\n",
1398 ")"
1399 ]
1400 },
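{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the per-eta-bin cut logic above reduces to two steps: take the score quantile that retains the target fraction of true tracks, then quote the fraction of fakes falling below that cut. A self-contained toy sketch of the same recipe (synthetic beta-distributed scores, not the ntuple data):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"true_scores = rng.beta(5, 2, size=100_000)    # true tracks: scores pile up near 1\n",
"fake_scores = rng.beta(2, 5, size=1_000_000)  # fakes: scores pile up near 0\n",
"\n",
"for retention in [80, 90, 96, 99]:\n",
"    # Same rule as plot_for_pt_bin: cut keeping `retention`% of true tracks\n",
"    cut = np.percentile(true_scores, 100 - retention)\n",
"    # Fraction of fakes rejected by that cut\n",
"    rejection = 100 * np.mean(fake_scores < cut)\n",
"    print(f\"{retention}% retention -> cut {cut:.4f}, rejection {rejection:.1f}%\")"
]
},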
1401 {
1402 "cell_type": "code",
1403 "execution_count": null,
1404 "metadata": {},
1405 "outputs": [],
1406 "source": []
1407 }
1408 ],
1409 "metadata": {
1410 "kernelspec": {
1411 "display_name": "analysisenv",
1412 "language": "python",
1413 "name": "python3"
1414 },
1415 "language_info": {
1416 "codemirror_mode": {
1417 "name": "ipython",
1418 "version": 3
1419 },
1420 "file_extension": ".py",
1421 "mimetype": "text/x-python",
1422 "name": "python",
1423 "nbconvert_exporter": "python",
1424 "pygments_lexer": "ipython3",
1425 "version": "3.11.7"
1426 }
1427 },
1428 "nbformat": 4,
1429 "nbformat_minor": 2
1430 }