Warning, /RecoTracker/LSTCore/standalone/analysis/DNN/train_T3_DNN.ipynb is written in an unsupported language. File is not indexed.
0001 {
0002 "cells": [
0003 {
0004 "cell_type": "code",
0005 "execution_count": null,
0006 "metadata": {},
0007 "outputs": [],
0008 "source": [
0009 "# Set every RNG seed up front for reproducibility: Python's `random`, NumPy's global RNG\n",
0010 "# and torch (torch.manual_seed seeds the CPU and all CUDA devices). The trailing ';'\n",
0011 "# suppresses the Generator repr that torch.manual_seed returns.\n",
0012 "import random\n",
0013 "\n",
0014 "import numpy as np\n",
0015 "import torch\n",
0016 "\n",
0017 "SEED = 42\n",
0018 "random.seed(SEED)\n",
0019 "np.random.seed(SEED)\n",
0020 "torch.manual_seed(SEED);"
0012 ]
0013 },
0014 {
0015 "cell_type": "code",
0016 "execution_count": 1,
0017 "metadata": {},
0018 "outputs": [],
0019 "source": [
0020 "import os\n",
0021 "import uproot\n",
0022 "import numpy as np\n",
0023 "\n",
0024 "def load_root_file(file_path, branches=None, print_branches=False):\n",
0025 "    \"\"\"Read branches of the TTree ``tree`` in a ROOT file into NumPy arrays.\n",
0026 "\n",
0027 "    Args:\n",
0028 "        file_path: path of the ROOT file, opened with uproot.\n",
0029 "        branches: names of the branches to read; all branches of the tree if None.\n",
0030 "        print_branches: if True, print the tree's branch names before reading.\n",
0031 "\n",
0032 "    Returns:\n",
0033 "        dict mapping each branch that was read to tree[name].array(library=\"np\"),\n",
0034 "        plus the key 'event' holding the tree's number of entries. A branch missing\n",
0035 "        from the file is reported with a printed KeyInFileError message and skipped.\n",
0036 "    \"\"\"\n",
0037 "    loaded = {}\n",
0038 "    with uproot.open(file_path) as root_file:\n",
0039 "        tree = root_file[\"tree\"]\n",
0040 "        names = tree.keys() if branches is None else branches\n",
0041 "        if print_branches:\n",
0042 "            print(\"Branches:\", tree.keys())\n",
0043 "        for name in names:\n",
0044 "            try:\n",
0045 "                loaded[name] = tree[name].array(library=\"np\")\n",
0046 "            except uproot.KeyInFileError as err:\n",
0047 "                # Best effort: report the missing branch and keep reading the others.\n",
0048 "                print(f\"KeyInFileError: {err}\")\n",
0049 "        # Number of entries (events) in the tree\n",
0050 "        loaded['event'] = tree.num_entries\n",
0051 "    return loaded\n",
0052 "\n",
0053 "# Core T3 properties from TripletsSoA\n",
0054 "branches_list = [\n",
0055 "    't3_betaIn',\n",
0056 "    't3_centerX',\n",
0057 "    't3_centerY',\n",
0058 "    't3_radius',\n",
0059 "    't3_partOfPT5',\n",
0060 "    't3_partOfT5',\n",
0061 "    't3_partOfPT3',\n",
0062 "    't3_layer_binary',\n",
0063 "    't3_pMatched',\n",
0064 "    't3_matched_simIdx',\n",
0065 "    't3_sim_vxy',\n",
0066 "    't3_sim_vz',\n",
0067 "]\n",
0068 "\n",
0069 "# Hit-dependent branches: one per (hit index 0-5, quantity) pair, in hit-major order\n",
0070 "suffixes = ['r', 'z', 'eta', 'phi', 'layer']\n",
0071 "branches_list.extend(f't3_hit_{hit}_{suffix}' for hit in range(6) for suffix in suffixes)\n",
0072 "\n",
0073 "file_path = \"600_t3_dnn_relval_fix.root\"\n",
0074 "branches = load_root_file(file_path, branches_list)"
0066 ]
0067 },
0068 {
0069 "cell_type": "code",
0070 "execution_count": 2,
0071 "metadata": {},
0072 "outputs": [
0073 {
0074 "name": "stdout",
0075 "output_type": "stream",
0076 "text": [
0077 "Z max: 224.14950561523438, R max: 98.93299102783203, Eta max: 2.5\n"
0078 ]
0079 }
0080 ],
0081 "source": [
0082 "# Normalisation constants for the input features.\n",
0083 "# z_max / r_max are the largest z and r of hit 3 over all T3s in the file. Flattening the\n",
0084 "# per-event arrays first avoids the ValueError np.max raises on an event with no T3s.\n",
0085 "# NOTE(review): z_max is the max of signed z while the features divide |z| by it;\n",
0086 "# presumably the z distribution is symmetric so the two agree -- confirm.\n",
0087 "z_max = np.max(np.concatenate(branches['t3_hit_3_z']))\n",
0088 "r_max = np.max(np.concatenate(branches['t3_hit_3_r']))\n",
0089 "eta_max = 2.5\n",
0090 "phi_max = np.pi\n",
0091 "\n",
0092 "print(f'Z max: {z_max}, R max: {r_max}, Eta max: {eta_max}')"
0088 ]
0089 },
0090 {
0091 "cell_type": "code",
0092 "execution_count": 3,
0093 "metadata": {},
0094 "outputs": [],
0095 "source": [
0096 "def delta_phi(phi1, phi2):\n",
0097 "    \"\"\"Element-wise phi1 - phi2, shifted once by +/-2*pi when it falls outside [-pi, pi].\n",
0098 "\n",
0099 "    Only a single correction is applied, so the result lies in [-pi, pi] only when both\n",
0100 "    inputs already lie in [-pi, pi].\n",
0101 "    \"\"\"\n",
0102 "    dphi = phi1 - phi2\n",
0103 "    dphi = np.where(dphi > np.pi, dphi - 2*np.pi, dphi)\n",
0104 "    return np.where(dphi < -np.pi, dphi + 2*np.pi, dphi)\n",
0105 "\n",
0106 "n_events = branches['event']\n",
0107 "\n",
0108 "def _flatten_branch(name):\n",
0109 "    \"\"\"Concatenate the per-event arrays of branch `name` (events 0..n_events-1, in order).\"\"\"\n",
0110 "    return np.concatenate([branches[name][evt] for evt in range(n_events)])\n",
0111 "\n",
0112 "# Hits 0, 2 and 4 of every T3, flattened over all events.\n",
0113 "# Taking |eta| and |z| makes these inputs independent of the sign of z.\n",
0114 "all_eta0, all_eta2, all_eta4 = (np.abs(_flatten_branch(f't3_hit_{i}_eta')) for i in (0, 2, 4))\n",
0115 "all_phi0, all_phi2, all_phi4 = (_flatten_branch(f't3_hit_{i}_phi') for i in (0, 2, 4))\n",
0116 "all_z0, all_z2, all_z4 = (np.abs(_flatten_branch(f't3_hit_{i}_z')) for i in (0, 2, 4))\n",
0117 "all_r0, all_r2, all_r4 = (_flatten_branch(f't3_hit_{i}_r') for i in (0, 2, 4))\n",
0118 "\n",
0119 "all_radius = _flatten_branch('t3_radius')\n",
0120 "all_betaIn = _flatten_branch('t3_betaIn')\n",
0121 "\n",
0122 "# Feature matrix with one row per feature (14) and one column per T3.\n",
0123 "# NOTE(review): order and scaling presumably must match the inference-side code; keep them in sync.\n",
0124 "features = np.array([\n",
0125 "    # Hit 0, normalised by the constants defined above\n",
0126 "    all_eta0 / eta_max,\n",
0127 "    np.abs(all_phi0) / phi_max,\n",
0128 "    all_z0 / z_max,\n",
0129 "    all_r0 / r_max,\n",
0130 "    # Hit 0 -> hit 2 differences (the eta difference is left unnormalised)\n",
0131 "    all_eta2 - all_eta0,\n",
0132 "    delta_phi(all_phi2, all_phi0) / phi_max,\n",
0133 "    (all_z2 - all_z0) / z_max,\n",
0134 "    (all_r2 - all_r0) / r_max,\n",
0135 "    # Hit 2 -> hit 4 differences\n",
0136 "    all_eta4 - all_eta2,\n",
0137 "    delta_phi(all_phi4, all_phi2) / phi_max,\n",
0138 "    (all_z4 - all_z2) / z_max,\n",
0139 "    (all_r4 - all_r2) / r_max,\n",
0140 "    # Whole-triplet quantities\n",
0141 "    np.log10(all_radius),  # circle radius, log scale\n",
0142 "    all_betaIn,            # beta angle\n",
0143 "])\n",
0144 "\n",
0145 "eta_list = np.array([all_eta0])"
0141 ]
0142 },
0143 {
0144 "cell_type": "code",
0145 "execution_count": 4,
0146 "metadata": {},
0147 "outputs": [],
0148 "source": [
0149 "import torch\n",
0150 "from torch import nn\n",
0151 "from torch.optim import Adam\n",
0152 "from torch.utils.data import DataLoader, TensorDataset, random_split\n",
0153 "import numpy as np\n",
0154 "\n",
0155 "# One row per T3, one column per feature: (n_features, n_T3) -> (n_T3, n_features)\n",
0156 "input_features_numpy = np.stack(features, axis=-1)\n",
0157 "\n",
0158 "# Element-wise finiteness (False for NaN and +/-inf)\n",
0159 "mask = np.isfinite(input_features_numpy)\n",
0160 "# Keep T3s whose features are all finite; the same row selection is applied to each per-T3 array below.\n",
0161 "valid_rows = mask.all(axis=1)\n",
0162 "filtered_input_features_numpy = input_features_numpy[valid_rows]\n",
0163 "# A T3 is labelled fake when t3_pMatched (presumably its matched-hit fraction) is below 0.75.\n",
0164 "t3_isFake_filtered = (np.concatenate(branches['t3_pMatched']) < 0.75)[valid_rows]\n",
0165 "t3_sim_vxy_filtered = np.concatenate(branches['t3_sim_vxy'])[valid_rows]\n",
0166 "\n",
0167 "# Convert to a float32 PyTorch tensor\n",
0168 "input_features_tensor = torch.tensor(filtered_input_features_numpy, dtype=torch.float32)"
0164 ]
0165 },
0166 {
0167 "cell_type": "code",
0168 "execution_count": 5,
0169 "metadata": {},
0170 "outputs": [
0171 {
0172 "name": "stdout",
0173 "output_type": "stream",
0174 "text": [
0175 "Using device: cuda\n",
0176 "Initial dataset size: 55072926\n",
0177 "Class distribution before downsampling - Fake: 49829032.0, Prompt: 4472777.0, Displaced: 771119.0\n",
0178 "Class distribution after downsampling - Fake: 9965806.0, Prompt: 4472777.0, Displaced: 771119.0\n",
0179 "Epoch [1/400], Train Loss: 0.6515, Test Loss: 0.5908\n",
0180 "Epoch [2/400], Train Loss: 0.5771, Test Loss: 0.5659\n",
0181 "Epoch [3/400], Train Loss: 0.5647, Test Loss: 0.5549\n",
0182 "Epoch [4/400], Train Loss: 0.5588, Test Loss: 0.5627\n",
0183 "Epoch [5/400], Train Loss: 0.5553, Test Loss: 0.5541\n",
0184 "Epoch [6/400], Train Loss: 0.5528, Test Loss: 0.5664\n",
0185 "Epoch [7/400], Train Loss: 0.5508, Test Loss: 0.5574\n",
0186 "Epoch [8/400], Train Loss: 0.5492, Test Loss: 0.5503\n",
0187 "Epoch [9/400], Train Loss: 0.5478, Test Loss: 0.5447\n",
0188 "Epoch [10/400], Train Loss: 0.5472, Test Loss: 0.5538\n",
0189 "Epoch [11/400], Train Loss: 0.5459, Test Loss: 0.5534\n",
0190 "Epoch [12/400], Train Loss: 0.5454, Test Loss: 0.5487\n",
0191 "Epoch [13/400], Train Loss: 0.5445, Test Loss: 0.5366\n",
0192 "Epoch [14/400], Train Loss: 0.5441, Test Loss: 0.5387\n",
0193 "Epoch [15/400], Train Loss: 0.5434, Test Loss: 0.5421\n",
0194 "Epoch [16/400], Train Loss: 0.5426, Test Loss: 0.5397\n",
0195 "Epoch [17/400], Train Loss: 0.5420, Test Loss: 0.5486\n",
0196 "Epoch [18/400], Train Loss: 0.5412, Test Loss: 0.5398\n",
0197 "Epoch [19/400], Train Loss: 0.5409, Test Loss: 0.5421\n",
0198 "Epoch [20/400], Train Loss: 0.5405, Test Loss: 0.5499\n",
0199 "Epoch [21/400], Train Loss: 0.5399, Test Loss: 0.5573\n",
0200 "Epoch [22/400], Train Loss: 0.5396, Test Loss: 0.5388\n",
0201 "Epoch [23/400], Train Loss: 0.5393, Test Loss: 0.5399\n",
0202 "Epoch [24/400], Train Loss: 0.5388, Test Loss: 0.5391\n",
0203 "Epoch [25/400], Train Loss: 0.5383, Test Loss: 0.5375\n",
0204 "Epoch [26/400], Train Loss: 0.5381, Test Loss: 0.5386\n",
0205 "Epoch [27/400], Train Loss: 0.5380, Test Loss: 0.5454\n",
0206 "Epoch [28/400], Train Loss: 0.5376, Test Loss: 0.5350\n",
0207 "Epoch [29/400], Train Loss: 0.5376, Test Loss: 0.5572\n",
0208 "Epoch [30/400], Train Loss: 0.5371, Test Loss: 0.5486\n",
0209 "Epoch [31/400], Train Loss: 0.5364, Test Loss: 0.5331\n",
0210 "Epoch [32/400], Train Loss: 0.5368, Test Loss: 0.5498\n",
0211 "Epoch [33/400], Train Loss: 0.5367, Test Loss: 0.5363\n",
0212 "Epoch [34/400], Train Loss: 0.5363, Test Loss: 0.5377\n",
0213 "Epoch [35/400], Train Loss: 0.5360, Test Loss: 0.5413\n",
0214 "Epoch [36/400], Train Loss: 0.5361, Test Loss: 0.5385\n",
0215 "Epoch [37/400], Train Loss: 0.5357, Test Loss: 0.5433\n",
0216 "Epoch [38/400], Train Loss: 0.5353, Test Loss: 0.5404\n",
0217 "Epoch [39/400], Train Loss: 0.5353, Test Loss: 0.5328\n",
0218 "Epoch [40/400], Train Loss: 0.5349, Test Loss: 0.5363\n",
0219 "Epoch [41/400], Train Loss: 0.5348, Test Loss: 0.5371\n",
0220 "Epoch [42/400], Train Loss: 0.5347, Test Loss: 0.5349\n",
0221 "Epoch [43/400], Train Loss: 0.5344, Test Loss: 0.5356\n",
0222 "Epoch [44/400], Train Loss: 0.5342, Test Loss: 0.5355\n",
0223 "Epoch [45/400], Train Loss: 0.5338, Test Loss: 0.5346\n",
0224 "Epoch [46/400], Train Loss: 0.5337, Test Loss: 0.5345\n",
0225 "Epoch [47/400], Train Loss: 0.5336, Test Loss: 0.5323\n",
0226 "Epoch [48/400], Train Loss: 0.5333, Test Loss: 0.5298\n",
0227 "Epoch [49/400], Train Loss: 0.5332, Test Loss: 0.5388\n",
0228 "Epoch [50/400], Train Loss: 0.5331, Test Loss: 0.5312\n",
0229 "Epoch [51/400], Train Loss: 0.5329, Test Loss: 0.5305\n",
0230 "Epoch [52/400], Train Loss: 0.5328, Test Loss: 0.5325\n",
0231 "Epoch [53/400], Train Loss: 0.5325, Test Loss: 0.5333\n",
0232 "Epoch [54/400], Train Loss: 0.5325, Test Loss: 0.5285\n",
0233 "Epoch [55/400], Train Loss: 0.5325, Test Loss: 0.5400\n",
0234 "Epoch [56/400], Train Loss: 0.5323, Test Loss: 0.5324\n",
0235 "Epoch [57/400], Train Loss: 0.5320, Test Loss: 0.5298\n",
0236 "Epoch [58/400], Train Loss: 0.5319, Test Loss: 0.5408\n",
0237 "Epoch [59/400], Train Loss: 0.5319, Test Loss: 0.5294\n",
0238 "Epoch [60/400], Train Loss: 0.5315, Test Loss: 0.5293\n",
0239 "Epoch [61/400], Train Loss: 0.5316, Test Loss: 0.5381\n",
0240 "Epoch [62/400], Train Loss: 0.5315, Test Loss: 0.5302\n",
0241 "Epoch [63/400], Train Loss: 0.5316, Test Loss: 0.5329\n",
0242 "Epoch [64/400], Train Loss: 0.5313, Test Loss: 0.5341\n",
0243 "Epoch [65/400], Train Loss: 0.5311, Test Loss: 0.5333\n",
0244 "Epoch [66/400], Train Loss: 0.5311, Test Loss: 0.5379\n",
0245 "Epoch [67/400], Train Loss: 0.5310, Test Loss: 0.5314\n",
0246 "Epoch [68/400], Train Loss: 0.5308, Test Loss: 0.5377\n",
0247 "Epoch [69/400], Train Loss: 0.5310, Test Loss: 0.5325\n",
0248 "Epoch [70/400], Train Loss: 0.5307, Test Loss: 0.5307\n",
0249 "Epoch [71/400], Train Loss: 0.5305, Test Loss: 0.5321\n",
0250 "Epoch [72/400], Train Loss: 0.5304, Test Loss: 0.5328\n",
0251 "Epoch [73/400], Train Loss: 0.5304, Test Loss: 0.5355\n",
0252 "Epoch [74/400], Train Loss: 0.5301, Test Loss: 0.5298\n",
0253 "Epoch [75/400], Train Loss: 0.5300, Test Loss: 0.5363\n",
0254 "Epoch [76/400], Train Loss: 0.5303, Test Loss: 0.5331\n",
0255 "Epoch [77/400], Train Loss: 0.5298, Test Loss: 0.5270\n",
0256 "Epoch [78/400], Train Loss: 0.5300, Test Loss: 0.5324\n",
0257 "Epoch [79/400], Train Loss: 0.5300, Test Loss: 0.5336\n",
0258 "Epoch [80/400], Train Loss: 0.5297, Test Loss: 0.5283\n",
0259 "Epoch [81/400], Train Loss: 0.5297, Test Loss: 0.5285\n",
0260 "Epoch [82/400], Train Loss: 0.5295, Test Loss: 0.5286\n",
0261 "Epoch [83/400], Train Loss: 0.5295, Test Loss: 0.5277\n",
0262 "Epoch [84/400], Train Loss: 0.5294, Test Loss: 0.5300\n",
0263 "Epoch [85/400], Train Loss: 0.5295, Test Loss: 0.5317\n",
0264 "Epoch [86/400], Train Loss: 0.5292, Test Loss: 0.5288\n",
0265 "Epoch [87/400], Train Loss: 0.5293, Test Loss: 0.5295\n",
0266 "Epoch [88/400], Train Loss: 0.5291, Test Loss: 0.5273\n",
0267 "Epoch [89/400], Train Loss: 0.5292, Test Loss: 0.5289\n",
0268 "Epoch [90/400], Train Loss: 0.5292, Test Loss: 0.5273\n",
0269 "Epoch [91/400], Train Loss: 0.5289, Test Loss: 0.5370\n",
0270 "Epoch [92/400], Train Loss: 0.5288, Test Loss: 0.5263\n",
0271 "Epoch [93/400], Train Loss: 0.5288, Test Loss: 0.5338\n",
0272 "Epoch [94/400], Train Loss: 0.5287, Test Loss: 0.5326\n",
0273 "Epoch [95/400], Train Loss: 0.5286, Test Loss: 0.5300\n",
0274 "Epoch [96/400], Train Loss: 0.5286, Test Loss: 0.5280\n",
0275 "Epoch [97/400], Train Loss: 0.5284, Test Loss: 0.5291\n",
0276 "Epoch [98/400], Train Loss: 0.5284, Test Loss: 0.5310\n",
0277 "Epoch [99/400], Train Loss: 0.5283, Test Loss: 0.5275\n",
0278 "Epoch [100/400], Train Loss: 0.5280, Test Loss: 0.5294\n",
0279 "Epoch [101/400], Train Loss: 0.5279, Test Loss: 0.5290\n",
0280 "Epoch [102/400], Train Loss: 0.5273, Test Loss: 0.5264\n",
0281 "Epoch [103/400], Train Loss: 0.5270, Test Loss: 0.5289\n",
0282 "Epoch [104/400], Train Loss: 0.5265, Test Loss: 0.5294\n",
0283 "Epoch [105/400], Train Loss: 0.5263, Test Loss: 0.5290\n",
0284 "Epoch [106/400], Train Loss: 0.5261, Test Loss: 0.5302\n",
0285 "Epoch [107/400], Train Loss: 0.5258, Test Loss: 0.5309\n",
0286 "Epoch [108/400], Train Loss: 0.5258, Test Loss: 0.5254\n",
0287 "Epoch [109/400], Train Loss: 0.5256, Test Loss: 0.5234\n",
0288 "Epoch [110/400], Train Loss: 0.5255, Test Loss: 0.5307\n",
0289 "Epoch [111/400], Train Loss: 0.5256, Test Loss: 0.5250\n",
0290 "Epoch [112/400], Train Loss: 0.5252, Test Loss: 0.5300\n",
0291 "Epoch [113/400], Train Loss: 0.5253, Test Loss: 0.5328\n",
0292 "Epoch [114/400], Train Loss: 0.5252, Test Loss: 0.5347\n",
0293 "Epoch [115/400], Train Loss: 0.5251, Test Loss: 0.5263\n",
0294 "Epoch [116/400], Train Loss: 0.5250, Test Loss: 0.5312\n",
0295 "Epoch [117/400], Train Loss: 0.5250, Test Loss: 0.5313\n",
0296 "Epoch [118/400], Train Loss: 0.5248, Test Loss: 0.5291\n",
0297 "Epoch [119/400], Train Loss: 0.5249, Test Loss: 0.5314\n",
0298 "Epoch [120/400], Train Loss: 0.5249, Test Loss: 0.5246\n",
0299 "Epoch [121/400], Train Loss: 0.5246, Test Loss: 0.5271\n",
0300 "Epoch [122/400], Train Loss: 0.5244, Test Loss: 0.5286\n",
0301 "Epoch [123/400], Train Loss: 0.5241, Test Loss: 0.5361\n",
0302 "Epoch [124/400], Train Loss: 0.5240, Test Loss: 0.5229\n",
0303 "Epoch [125/400], Train Loss: 0.5239, Test Loss: 0.5268\n",
0304 "Epoch [126/400], Train Loss: 0.5239, Test Loss: 0.5233\n",
0305 "Epoch [127/400], Train Loss: 0.5238, Test Loss: 0.5254\n",
0306 "Epoch [128/400], Train Loss: 0.5236, Test Loss: 0.5271\n",
0307 "Epoch [129/400], Train Loss: 0.5235, Test Loss: 0.5219\n",
0308 "Epoch [130/400], Train Loss: 0.5234, Test Loss: 0.5273\n",
0309 "Epoch [131/400], Train Loss: 0.5232, Test Loss: 0.5241\n",
0310 "Epoch [132/400], Train Loss: 0.5230, Test Loss: 0.5234\n",
0311 "Epoch [133/400], Train Loss: 0.5229, Test Loss: 0.5232\n",
0312 "Epoch [134/400], Train Loss: 0.5229, Test Loss: 0.5288\n",
0313 "Epoch [135/400], Train Loss: 0.5229, Test Loss: 0.5261\n",
0314 "Epoch [136/400], Train Loss: 0.5230, Test Loss: 0.5271\n",
0315 "Epoch [137/400], Train Loss: 0.5227, Test Loss: 0.5287\n",
0316 "Epoch [138/400], Train Loss: 0.5228, Test Loss: 0.5216\n",
0317 "Epoch [139/400], Train Loss: 0.5227, Test Loss: 0.5263\n",
0318 "Epoch [140/400], Train Loss: 0.5224, Test Loss: 0.5274\n",
0319 "Epoch [141/400], Train Loss: 0.5225, Test Loss: 0.5234\n",
0320 "Epoch [142/400], Train Loss: 0.5226, Test Loss: 0.5251\n",
0321 "Epoch [143/400], Train Loss: 0.5221, Test Loss: 0.5224\n",
0322 "Epoch [144/400], Train Loss: 0.5223, Test Loss: 0.5222\n",
0323 "Epoch [145/400], Train Loss: 0.5224, Test Loss: 0.5275\n",
0324 "Epoch [146/400], Train Loss: 0.5223, Test Loss: 0.5203\n",
0325 "Epoch [147/400], Train Loss: 0.5223, Test Loss: 0.5218\n",
0326 "Epoch [148/400], Train Loss: 0.5222, Test Loss: 0.5256\n",
0327 "Epoch [149/400], Train Loss: 0.5221, Test Loss: 0.5227\n",
0328 "Epoch [150/400], Train Loss: 0.5219, Test Loss: 0.5210\n",
0329 "Epoch [151/400], Train Loss: 0.5221, Test Loss: 0.5239\n",
0330 "Epoch [152/400], Train Loss: 0.5221, Test Loss: 0.5218\n",
0331 "Epoch [153/400], Train Loss: 0.5219, Test Loss: 0.5305\n",
0332 "Epoch [154/400], Train Loss: 0.5219, Test Loss: 0.5248\n",
0333 "Epoch [155/400], Train Loss: 0.5218, Test Loss: 0.5247\n",
0334 "Epoch [156/400], Train Loss: 0.5218, Test Loss: 0.5222\n",
0335 "Epoch [157/400], Train Loss: 0.5216, Test Loss: 0.5332\n",
0336 "Epoch [158/400], Train Loss: 0.5217, Test Loss: 0.5230\n",
0337 "Epoch [159/400], Train Loss: 0.5217, Test Loss: 0.5237\n",
0338 "Epoch [160/400], Train Loss: 0.5216, Test Loss: 0.5205\n",
0339 "Epoch [161/400], Train Loss: 0.5215, Test Loss: 0.5208\n",
0340 "Epoch [162/400], Train Loss: 0.5216, Test Loss: 0.5242\n",
0341 "Epoch [163/400], Train Loss: 0.5216, Test Loss: 0.5254\n",
0342 "Epoch [164/400], Train Loss: 0.5214, Test Loss: 0.5229\n",
0343 "Epoch [165/400], Train Loss: 0.5214, Test Loss: 0.5260\n",
0344 "Epoch [166/400], Train Loss: 0.5213, Test Loss: 0.5193\n",
0345 "Epoch [167/400], Train Loss: 0.5212, Test Loss: 0.5225\n",
0346 "Epoch [168/400], Train Loss: 0.5211, Test Loss: 0.5240\n",
0347 "Epoch [169/400], Train Loss: 0.5213, Test Loss: 0.5220\n",
0348 "Epoch [170/400], Train Loss: 0.5213, Test Loss: 0.5276\n",
0349 "Epoch [171/400], Train Loss: 0.5211, Test Loss: 0.5203\n",
0350 "Epoch [172/400], Train Loss: 0.5214, Test Loss: 0.5202\n",
0351 "Epoch [173/400], Train Loss: 0.5210, Test Loss: 0.5213\n",
0352 "Epoch [174/400], Train Loss: 0.5212, Test Loss: 0.5215\n",
0353 "Epoch [175/400], Train Loss: 0.5211, Test Loss: 0.5242\n",
0354 "Epoch [176/400], Train Loss: 0.5210, Test Loss: 0.5217\n",
0355 "Epoch [177/400], Train Loss: 0.5209, Test Loss: 0.5231\n",
0356 "Epoch [178/400], Train Loss: 0.5210, Test Loss: 0.5225\n",
0357 "Epoch [179/400], Train Loss: 0.5210, Test Loss: 0.5229\n",
0358 "Epoch [180/400], Train Loss: 0.5208, Test Loss: 0.5235\n",
0359 "Epoch [181/400], Train Loss: 0.5207, Test Loss: 0.5244\n",
0360 "Epoch [182/400], Train Loss: 0.5208, Test Loss: 0.5224\n",
0361 "Epoch [183/400], Train Loss: 0.5209, Test Loss: 0.5264\n",
0362 "Epoch [184/400], Train Loss: 0.5207, Test Loss: 0.5220\n",
0363 "Epoch [185/400], Train Loss: 0.5206, Test Loss: 0.5202\n",
0364 "Epoch [186/400], Train Loss: 0.5208, Test Loss: 0.5187\n",
0365 "Epoch [187/400], Train Loss: 0.5206, Test Loss: 0.5270\n",
0366 "Epoch [188/400], Train Loss: 0.5207, Test Loss: 0.5196\n",
0367 "Epoch [189/400], Train Loss: 0.5205, Test Loss: 0.5270\n",
0368 "Epoch [190/400], Train Loss: 0.5207, Test Loss: 0.5241\n",
0369 "Epoch [191/400], Train Loss: 0.5206, Test Loss: 0.5226\n",
0370 "Epoch [192/400], Train Loss: 0.5205, Test Loss: 0.5289\n",
0371 "Epoch [193/400], Train Loss: 0.5205, Test Loss: 0.5204\n",
0372 "Epoch [194/400], Train Loss: 0.5204, Test Loss: 0.5215\n",
0373 "Epoch [195/400], Train Loss: 0.5205, Test Loss: 0.5205\n",
0374 "Epoch [196/400], Train Loss: 0.5204, Test Loss: 0.5236\n",
0375 "Epoch [197/400], Train Loss: 0.5205, Test Loss: 0.5209\n",
0376 "Epoch [198/400], Train Loss: 0.5202, Test Loss: 0.5225\n",
0377 "Epoch [199/400], Train Loss: 0.5204, Test Loss: 0.5219\n",
0378 "Epoch [200/400], Train Loss: 0.5203, Test Loss: 0.5217\n",
0379 "Epoch [201/400], Train Loss: 0.5204, Test Loss: 0.5237\n",
0380 "Epoch [202/400], Train Loss: 0.5201, Test Loss: 0.5186\n",
0381 "Epoch [203/400], Train Loss: 0.5203, Test Loss: 0.5228\n",
0382 "Epoch [204/400], Train Loss: 0.5202, Test Loss: 0.5213\n",
0383 "Epoch [205/400], Train Loss: 0.5200, Test Loss: 0.5197\n",
0384 "Epoch [206/400], Train Loss: 0.5202, Test Loss: 0.5209\n",
0385 "Epoch [207/400], Train Loss: 0.5200, Test Loss: 0.5250\n",
0386 "Epoch [208/400], Train Loss: 0.5203, Test Loss: 0.5183\n",
0387 "Epoch [209/400], Train Loss: 0.5201, Test Loss: 0.5181\n",
0388 "Epoch [210/400], Train Loss: 0.5200, Test Loss: 0.5235\n",
0389 "Epoch [211/400], Train Loss: 0.5201, Test Loss: 0.5209\n",
0390 "Epoch [212/400], Train Loss: 0.5200, Test Loss: 0.5203\n",
0391 "Epoch [213/400], Train Loss: 0.5202, Test Loss: 0.5235\n",
0392 "Epoch [214/400], Train Loss: 0.5201, Test Loss: 0.5184\n",
0393 "Epoch [215/400], Train Loss: 0.5199, Test Loss: 0.5275\n",
0394 "Epoch [216/400], Train Loss: 0.5199, Test Loss: 0.5200\n",
0395 "Epoch [217/400], Train Loss: 0.5199, Test Loss: 0.5216\n",
0396 "Epoch [218/400], Train Loss: 0.5199, Test Loss: 0.5230\n",
0397 "Epoch [219/400], Train Loss: 0.5200, Test Loss: 0.5193\n",
0398 "Epoch [220/400], Train Loss: 0.5199, Test Loss: 0.5217\n",
0399 "Epoch [221/400], Train Loss: 0.5200, Test Loss: 0.5234\n",
0400 "Epoch [222/400], Train Loss: 0.5197, Test Loss: 0.5226\n",
0401 "Epoch [223/400], Train Loss: 0.5198, Test Loss: 0.5242\n",
0402 "Epoch [224/400], Train Loss: 0.5198, Test Loss: 0.5226\n",
0403 "Epoch [225/400], Train Loss: 0.5199, Test Loss: 0.5172\n",
0404 "Epoch [226/400], Train Loss: 0.5197, Test Loss: 0.5206\n",
0405 "Epoch [227/400], Train Loss: 0.5197, Test Loss: 0.5211\n",
0406 "Epoch [228/400], Train Loss: 0.5197, Test Loss: 0.5199\n",
0407 "Epoch [229/400], Train Loss: 0.5197, Test Loss: 0.5194\n",
0408 "Epoch [230/400], Train Loss: 0.5197, Test Loss: 0.5212\n",
0409 "Epoch [231/400], Train Loss: 0.5197, Test Loss: 0.5235\n",
0410 "Epoch [232/400], Train Loss: 0.5199, Test Loss: 0.5180\n",
0411 "Epoch [233/400], Train Loss: 0.5197, Test Loss: 0.5186\n",
0412 "Epoch [234/400], Train Loss: 0.5198, Test Loss: 0.5192\n",
0413 "Epoch [235/400], Train Loss: 0.5197, Test Loss: 0.5232\n",
0414 "Epoch [236/400], Train Loss: 0.5195, Test Loss: 0.5267\n",
0415 "Epoch [237/400], Train Loss: 0.5198, Test Loss: 0.5203\n",
0416 "Epoch [238/400], Train Loss: 0.5196, Test Loss: 0.5195\n",
0417 "Epoch [239/400], Train Loss: 0.5197, Test Loss: 0.5210\n",
0418 "Epoch [240/400], Train Loss: 0.5196, Test Loss: 0.5273\n",
0419 "Epoch [241/400], Train Loss: 0.5195, Test Loss: 0.5247\n",
0420 "Epoch [242/400], Train Loss: 0.5197, Test Loss: 0.5184\n",
0421 "Epoch [243/400], Train Loss: 0.5196, Test Loss: 0.5188\n",
0422 "Epoch [244/400], Train Loss: 0.5198, Test Loss: 0.5201\n",
0423 "Epoch [245/400], Train Loss: 0.5195, Test Loss: 0.5242\n",
0424 "Epoch [246/400], Train Loss: 0.5196, Test Loss: 0.5204\n",
0425 "Epoch [247/400], Train Loss: 0.5196, Test Loss: 0.5232\n",
0426 "Epoch [248/400], Train Loss: 0.5194, Test Loss: 0.5268\n",
0427 "Epoch [249/400], Train Loss: 0.5196, Test Loss: 0.5205\n",
0428 "Epoch [250/400], Train Loss: 0.5195, Test Loss: 0.5255\n",
0429 "Epoch [251/400], Train Loss: 0.5195, Test Loss: 0.5211\n",
0430 "Epoch [252/400], Train Loss: 0.5195, Test Loss: 0.5200\n",
0431 "Epoch [253/400], Train Loss: 0.5196, Test Loss: 0.5217\n",
0432 "Epoch [254/400], Train Loss: 0.5196, Test Loss: 0.5208\n",
0433 "Epoch [255/400], Train Loss: 0.5194, Test Loss: 0.5209\n",
0434 "Epoch [256/400], Train Loss: 0.5195, Test Loss: 0.5252\n",
0435 "Epoch [257/400], Train Loss: 0.5194, Test Loss: 0.5209\n",
0436 "Epoch [258/400], Train Loss: 0.5195, Test Loss: 0.5247\n",
0437 "Epoch [259/400], Train Loss: 0.5196, Test Loss: 0.5201\n",
0438 "Epoch [260/400], Train Loss: 0.5192, Test Loss: 0.5191\n",
0439 "Epoch [261/400], Train Loss: 0.5194, Test Loss: 0.5202\n",
0440 "Epoch [262/400], Train Loss: 0.5194, Test Loss: 0.5209\n",
0441 "Epoch [263/400], Train Loss: 0.5193, Test Loss: 0.5258\n",
0442 "Epoch [264/400], Train Loss: 0.5194, Test Loss: 0.5226\n",
0443 "Epoch [265/400], Train Loss: 0.5195, Test Loss: 0.5263\n",
0444 "Epoch [266/400], Train Loss: 0.5195, Test Loss: 0.5223\n",
0445 "Epoch [267/400], Train Loss: 0.5193, Test Loss: 0.5236\n",
0446 "Epoch [268/400], Train Loss: 0.5192, Test Loss: 0.5274\n",
0447 "Epoch [269/400], Train Loss: 0.5194, Test Loss: 0.5193\n",
0448 "Epoch [270/400], Train Loss: 0.5191, Test Loss: 0.5189\n",
0449 "Epoch [271/400], Train Loss: 0.5193, Test Loss: 0.5257\n",
0450 "Epoch [272/400], Train Loss: 0.5193, Test Loss: 0.5191\n",
0451 "Epoch [273/400], Train Loss: 0.5194, Test Loss: 0.5192\n",
0452 "Epoch [274/400], Train Loss: 0.5193, Test Loss: 0.5215\n",
0453 "Epoch [275/400], Train Loss: 0.5192, Test Loss: 0.5199\n",
0454 "Epoch [276/400], Train Loss: 0.5193, Test Loss: 0.5231\n",
0455 "Epoch [277/400], Train Loss: 0.5192, Test Loss: 0.5210\n",
0456 "Epoch [278/400], Train Loss: 0.5192, Test Loss: 0.5203\n",
0457 "Epoch [279/400], Train Loss: 0.5194, Test Loss: 0.5225\n",
0458 "Epoch [280/400], Train Loss: 0.5193, Test Loss: 0.5242\n",
0459 "Epoch [281/400], Train Loss: 0.5192, Test Loss: 0.5270\n",
0460 "Epoch [282/400], Train Loss: 0.5192, Test Loss: 0.5226\n",
0461 "Epoch [283/400], Train Loss: 0.5192, Test Loss: 0.5221\n",
0462 "Epoch [284/400], Train Loss: 0.5193, Test Loss: 0.5171\n",
0463 "Epoch [285/400], Train Loss: 0.5191, Test Loss: 0.5211\n",
0464 "Epoch [286/400], Train Loss: 0.5191, Test Loss: 0.5178\n",
0465 "Epoch [287/400], Train Loss: 0.5190, Test Loss: 0.5173\n",
0466 "Epoch [288/400], Train Loss: 0.5192, Test Loss: 0.5277\n",
0467 "Epoch [289/400], Train Loss: 0.5190, Test Loss: 0.5196\n",
0468 "Epoch [290/400], Train Loss: 0.5192, Test Loss: 0.5200\n",
0469 "Epoch [291/400], Train Loss: 0.5190, Test Loss: 0.5186\n",
0470 "Epoch [292/400], Train Loss: 0.5192, Test Loss: 0.5211\n",
0471 "Epoch [293/400], Train Loss: 0.5192, Test Loss: 0.5249\n",
0472 "Epoch [294/400], Train Loss: 0.5191, Test Loss: 0.5196\n",
0473 "Epoch [295/400], Train Loss: 0.5191, Test Loss: 0.5215\n",
0474 "Epoch [296/400], Train Loss: 0.5192, Test Loss: 0.5223\n",
0475 "Epoch [297/400], Train Loss: 0.5192, Test Loss: 0.5233\n",
0476 "Epoch [298/400], Train Loss: 0.5191, Test Loss: 0.5223\n",
0477 "Epoch [299/400], Train Loss: 0.5189, Test Loss: 0.5212\n",
0478 "Epoch [300/400], Train Loss: 0.5189, Test Loss: 0.5199\n",
0479 "Epoch [301/400], Train Loss: 0.5190, Test Loss: 0.5197\n",
0480 "Epoch [302/400], Train Loss: 0.5190, Test Loss: 0.5289\n",
0481 "Epoch [303/400], Train Loss: 0.5189, Test Loss: 0.5220\n",
0482 "Epoch [304/400], Train Loss: 0.5190, Test Loss: 0.5296\n",
0483 "Epoch [305/400], Train Loss: 0.5189, Test Loss: 0.5185\n",
0484 "Epoch [306/400], Train Loss: 0.5190, Test Loss: 0.5191\n",
0485 "Epoch [307/400], Train Loss: 0.5189, Test Loss: 0.5196\n",
0486 "Epoch [308/400], Train Loss: 0.5190, Test Loss: 0.5204\n",
0487 "Epoch [309/400], Train Loss: 0.5188, Test Loss: 0.5200\n",
0488 "Epoch [310/400], Train Loss: 0.5190, Test Loss: 0.5254\n",
0489 "Epoch [311/400], Train Loss: 0.5190, Test Loss: 0.5213\n",
0490 "Epoch [312/400], Train Loss: 0.5189, Test Loss: 0.5206\n",
0491 "Epoch [313/400], Train Loss: 0.5188, Test Loss: 0.5239\n",
0492 "Epoch [314/400], Train Loss: 0.5189, Test Loss: 0.5198\n",
0493 "Epoch [315/400], Train Loss: 0.5190, Test Loss: 0.5172\n",
0494 "Epoch [316/400], Train Loss: 0.5187, Test Loss: 0.5199\n",
0495 "Epoch [317/400], Train Loss: 0.5189, Test Loss: 0.5190\n",
0496 "Epoch [318/400], Train Loss: 0.5188, Test Loss: 0.5195\n",
0497 "Epoch [319/400], Train Loss: 0.5187, Test Loss: 0.5191\n",
0498 "Epoch [320/400], Train Loss: 0.5188, Test Loss: 0.5213\n",
0499 "Epoch [321/400], Train Loss: 0.5190, Test Loss: 0.5191\n",
0500 "Epoch [322/400], Train Loss: 0.5188, Test Loss: 0.5250\n",
0501 "Epoch [323/400], Train Loss: 0.5187, Test Loss: 0.5187\n",
0502 "Epoch [324/400], Train Loss: 0.5188, Test Loss: 0.5219\n",
0503 "Epoch [325/400], Train Loss: 0.5186, Test Loss: 0.5200\n",
0504 "Epoch [326/400], Train Loss: 0.5187, Test Loss: 0.5229\n",
0505 "Epoch [327/400], Train Loss: 0.5187, Test Loss: 0.5191\n",
0506 "Epoch [328/400], Train Loss: 0.5187, Test Loss: 0.5193\n",
0507 "Epoch [329/400], Train Loss: 0.5188, Test Loss: 0.5193\n",
0508 "Epoch [330/400], Train Loss: 0.5187, Test Loss: 0.5209\n",
0509 "Epoch [331/400], Train Loss: 0.5187, Test Loss: 0.5222\n",
0510 "Epoch [332/400], Train Loss: 0.5187, Test Loss: 0.5207\n",
0511 "Epoch [333/400], Train Loss: 0.5187, Test Loss: 0.5180\n",
0512 "Epoch [334/400], Train Loss: 0.5186, Test Loss: 0.5229\n",
0513 "Epoch [335/400], Train Loss: 0.5186, Test Loss: 0.5184\n",
0514 "Epoch [336/400], Train Loss: 0.5187, Test Loss: 0.5188\n",
0515 "Epoch [337/400], Train Loss: 0.5187, Test Loss: 0.5204\n",
0516 "Epoch [338/400], Train Loss: 0.5185, Test Loss: 0.5292\n",
0517 "Epoch [339/400], Train Loss: 0.5186, Test Loss: 0.5214\n",
0518 "Epoch [340/400], Train Loss: 0.5187, Test Loss: 0.5210\n",
0519 "Epoch [341/400], Train Loss: 0.5187, Test Loss: 0.5220\n",
0520 "Epoch [342/400], Train Loss: 0.5186, Test Loss: 0.5202\n",
0521 "Epoch [343/400], Train Loss: 0.5185, Test Loss: 0.5311\n",
0522 "Epoch [344/400], Train Loss: 0.5186, Test Loss: 0.5209\n",
0523 "Epoch [345/400], Train Loss: 0.5187, Test Loss: 0.5205\n",
0524 "Epoch [346/400], Train Loss: 0.5185, Test Loss: 0.5178\n",
0525 "Epoch [347/400], Train Loss: 0.5185, Test Loss: 0.5186\n",
0526 "Epoch [348/400], Train Loss: 0.5186, Test Loss: 0.5228\n",
0527 "Epoch [349/400], Train Loss: 0.5184, Test Loss: 0.5222\n",
0528 "Epoch [350/400], Train Loss: 0.5186, Test Loss: 0.5246\n",
0529 "Epoch [351/400], Train Loss: 0.5185, Test Loss: 0.5195\n",
0530 "Epoch [352/400], Train Loss: 0.5185, Test Loss: 0.5189\n",
0531 "Epoch [353/400], Train Loss: 0.5184, Test Loss: 0.5212\n",
0532 "Epoch [354/400], Train Loss: 0.5186, Test Loss: 0.5267\n",
0533 "Epoch [355/400], Train Loss: 0.5184, Test Loss: 0.5227\n",
0534 "Epoch [356/400], Train Loss: 0.5183, Test Loss: 0.5213\n",
0535 "Epoch [357/400], Train Loss: 0.5183, Test Loss: 0.5214\n",
0536 "Epoch [358/400], Train Loss: 0.5183, Test Loss: 0.5228\n",
0537 "Epoch [359/400], Train Loss: 0.5185, Test Loss: 0.5195\n",
0538 "Epoch [360/400], Train Loss: 0.5183, Test Loss: 0.5234\n",
0539 "Epoch [361/400], Train Loss: 0.5185, Test Loss: 0.5198\n",
0540 "Epoch [362/400], Train Loss: 0.5183, Test Loss: 0.5193\n",
0541 "Epoch [363/400], Train Loss: 0.5185, Test Loss: 0.5208\n",
0542 "Epoch [364/400], Train Loss: 0.5185, Test Loss: 0.5217\n",
0543 "Epoch [365/400], Train Loss: 0.5184, Test Loss: 0.5258\n",
0544 "Epoch [366/400], Train Loss: 0.5185, Test Loss: 0.5199\n",
0545 "Epoch [367/400], Train Loss: 0.5183, Test Loss: 0.5197\n",
0546 "Epoch [368/400], Train Loss: 0.5182, Test Loss: 0.5189\n",
0547 "Epoch [369/400], Train Loss: 0.5184, Test Loss: 0.5191\n",
0548 "Epoch [370/400], Train Loss: 0.5183, Test Loss: 0.5198\n",
0549 "Epoch [371/400], Train Loss: 0.5182, Test Loss: 0.5230\n",
0550 "Epoch [372/400], Train Loss: 0.5183, Test Loss: 0.5178\n",
0551 "Epoch [373/400], Train Loss: 0.5183, Test Loss: 0.5184\n",
0552 "Epoch [374/400], Train Loss: 0.5182, Test Loss: 0.5187\n",
0553 "Epoch [375/400], Train Loss: 0.5183, Test Loss: 0.5224\n",
0554 "Epoch [376/400], Train Loss: 0.5184, Test Loss: 0.5203\n",
0555 "Epoch [377/400], Train Loss: 0.5183, Test Loss: 0.5197\n",
0556 "Epoch [378/400], Train Loss: 0.5182, Test Loss: 0.5196\n",
0557 "Epoch [379/400], Train Loss: 0.5181, Test Loss: 0.5200\n",
0558 "Epoch [380/400], Train Loss: 0.5182, Test Loss: 0.5264\n",
0559 "Epoch [381/400], Train Loss: 0.5181, Test Loss: 0.5202\n",
0560 "Epoch [382/400], Train Loss: 0.5182, Test Loss: 0.5219\n",
0561 "Epoch [383/400], Train Loss: 0.5183, Test Loss: 0.5219\n",
0562 "Epoch [384/400], Train Loss: 0.5183, Test Loss: 0.5196\n",
0563 "Epoch [385/400], Train Loss: 0.5181, Test Loss: 0.5192\n",
0564 "Epoch [386/400], Train Loss: 0.5180, Test Loss: 0.5206\n",
0565 "Epoch [387/400], Train Loss: 0.5181, Test Loss: 0.5264\n",
0566 "Epoch [388/400], Train Loss: 0.5181, Test Loss: 0.5191\n",
0567 "Epoch [389/400], Train Loss: 0.5183, Test Loss: 0.5240\n",
0568 "Epoch [390/400], Train Loss: 0.5182, Test Loss: 0.5196\n",
0569 "Epoch [391/400], Train Loss: 0.5182, Test Loss: 0.5176\n",
0570 "Epoch [392/400], Train Loss: 0.5181, Test Loss: 0.5219\n",
0571 "Epoch [393/400], Train Loss: 0.5181, Test Loss: 0.5198\n",
0572 "Epoch [394/400], Train Loss: 0.5180, Test Loss: 0.5251\n",
0573 "Epoch [395/400], Train Loss: 0.5181, Test Loss: 0.5182\n",
0574 "Epoch [396/400], Train Loss: 0.5182, Test Loss: 0.5212\n",
0575 "Epoch [397/400], Train Loss: 0.5180, Test Loss: 0.5217\n",
0576 "Epoch [398/400], Train Loss: 0.5181, Test Loss: 0.5213\n",
0577 "Epoch [399/400], Train Loss: 0.5182, Test Loss: 0.5180\n",
0578 "Epoch [400/400], Train Loss: 0.5180, Test Loss: 0.5205\n"
0579 ]
0580 }
0581 ],
0582 "source": [
0583 "import torch\n",
0584 "import torch.nn as nn\n",
0585 "from torch.utils.data import TensorDataset, random_split, DataLoader\n",
0586 "from torch.optim import Adam\n",
0587 "\n",
0588 "# Pick the compute device: CUDA when a GPU is visible, otherwise the CPU\n",
0589 "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
0590 "print(f\"Using device: {device}\")\n",
0591 "\n",
0592 "# Create multi-class labels: one-hot rows ordered [fake, prompt, displaced]\n",
0593 "def create_multiclass_labels(t3_isFake, t3_sim_vxy, displacement_threshold=0.1):\n",
0594 "    \"\"\"Build an (N, 3) float one-hot tensor: col 0 = fake, col 1 = prompt real, col 2 = displaced real.\n",
0595 "\n",
0596 "    t3_isFake is cast to a bool tensor so 0/1 integer flags act as a row mask rather than as\n",
0597 "    fancy row indices. A real track is prompt when t3_sim_vxy <= displacement_threshold and\n",
0598 "    displaced above it; a NaN t3_sim_vxy fails both comparisons and leaves an all-zero row.\n",
0599 "    \"\"\"\n",
0600 "    fake_mask = torch.as_tensor(t3_isFake, dtype=torch.bool)\n",
0601 "    sim_vxy = torch.as_tensor(t3_sim_vxy)  # dtype preserved so the threshold comparison is unchanged\n",
0602 "    labels = torch.zeros((len(fake_mask), 3))\n",
0603 "\n",
0604 "    # Fake tracks (class 0)\n",
0605 "    labels[fake_mask, 0] = 1\n",
0606 "\n",
0607 "    # Real tracks split into prompt (class 1) and displaced (class 2)\n",
0608 "    real_mask = ~fake_mask\n",
0609 "    labels[(sim_vxy <= displacement_threshold) & real_mask, 1] = 1\n",
0610 "    labels[(sim_vxy > displacement_threshold) & real_mask, 2] = 1\n",
0611 "    return labels\n",
0612 "\n",
0613 "# One-hot training targets for every filtered T3 (displacement threshold left at its default)\n",
0614 "labels_tensor = create_multiclass_labels(\n",
0615 "    t3_isFake=t3_isFake_filtered,\n",
0616 "    t3_sim_vxy=t3_sim_vxy_filtered,\n",
0617 ")\n",
0618 "\n",
0619 "# Neural network for multi-class classification: softmax probabilities over [fake, prompt, displaced]\n",
0620 "class MultiClassNeuralNetwork(nn.Module):\n",
0621 "    def __init__(self, input_size=None, hidden_size=32, num_classes=3):\n",
0622 "        # input_size=None -> notebook-global feature count; keep layer names, the exported C++ arrays (bias_layer1, ...) appear to follow them\n",
0623 "        super().__init__()\n",
0624 "        if input_size is None:\n",
0625 "            input_size = input_features_numpy.shape[1]\n",
0626 "        self.layer1 = nn.Linear(input_size, hidden_size)\n",
0627 "        self.layer2 = nn.Linear(hidden_size, hidden_size)\n",
0628 "        self.output_layer = nn.Linear(hidden_size, num_classes)\n",
0629 "\n",
0630 "    def forward(self, x):\n",
0631 "        x = torch.relu(self.layer1(x))\n",
0632 "        x = torch.relu(self.layer2(x))\n",
0633 "        return nn.functional.softmax(self.output_layer(x), dim=1)\n",
0634 "\n",
0635 "# Weighted loss for multi-class: per-sample weighted cross-entropy on probability inputs\n",
0636 "class WeightedCrossEntropyLoss(nn.Module):\n",
0637 "    def __init__(self):\n",
0638 "        super().__init__()\n",
0639 "\n",
0640 "    def forward(self, outputs, targets, weights):\n",
0641 "        # expects probabilities (not logits); the 1e-7 offset keeps log() finite where p == 0\n",
0642 "        log_likelihood = (targets * torch.log(outputs + 1e-7)).sum(dim=1)\n",
0643 "        weighted_nll = -weights * log_likelihood\n",
0644 "        return weighted_nll.mean()\n",
0645 "\n",
0646 "# Calculate class weights (each sample gets a weight to equalize class contributions)\n",
0647 "def calculate_class_weights(labels):\n",
0648 "    \"\"\"Return per-sample weights N / (C * n_c) for one-hot `labels` of shape (N, C); rows with no class set get 0.\"\"\"\n",
0649 "    num_classes = labels.shape[1]  # taken from the label width rather than hard-coded to 3\n",
0650 "    class_counts = torch.sum(labels, dim=0)\n",
0651 "    class_weights = len(labels) / (num_classes * class_counts)  # an empty class yields inf, but no row is assigned it\n",
0652 "\n",
0653 "    sample_weights = torch.zeros(len(labels))\n",
0654 "    for i in range(num_classes):\n",
0655 "        sample_weights[labels[:, i] == 1] = class_weights[i]\n",
0656 "    return sample_weights\n",
0657 "\n",
0658 "# Report the raw dataset size before any filtering\n",
0659 "print(f\"Initial dataset size: {len(labels_tensor)}\")\n",
0660 "\n",
0661 "# Drop every T3 whose feature vector contains a NaN; labels are filtered with the same row mask\n",
0662 "nan_mask = torch.isnan(input_features_tensor).any(dim=1)\n",
0663 "keep_mask = ~nan_mask\n",
0664 "filtered_inputs, filtered_labels = input_features_tensor[keep_mask], labels_tensor[keep_mask]\n",
0665 "\n",
0666 "# Per-class totals after NaN removal, before downsampling\n",
0667 "class_counts_before = filtered_labels.sum(dim=0)\n",
0668 "print(f\"Class distribution before downsampling - Fake: {class_counts_before[0]}, Prompt: {class_counts_before[1]}, Displaced: {class_counts_before[2]}\")\n",
0669 "\n",
0670 "# Option to downsample each class\n",
0671 "downsample_classes = True  # Set to False to disable downsampling\n",
0672 "if downsample_classes:\n",
0673 "    # Fraction of each class to keep: fakes (class 0) are cut to 20%, prompt (1) and displaced (2) are kept whole;\n",
0674 "    # a class missing from the dict defaults to 1.0 (keep everything)\n",
0675 "    downsample_ratios = {0: 0.2, 1: 1.0, 2: 1.0}\n",
0676 "    indices_list = []\n",
0677 "    for cls in range(filtered_labels.shape[1]):\n",
0678 "        # as_tuple=True yields a 1-D index tensor even for a single hit (.squeeze() made it 0-d, breaking randperm indexing)\n",
0679 "        cls_mask = (filtered_labels[:, cls] == 1)\n",
0680 "        cls_indices = torch.nonzero(cls_mask, as_tuple=True)[0]\n",
0681 "        ratio = downsample_ratios.get(cls, 1.0)\n",
0682 "        num_cls = cls_indices.numel()\n",
0683 "        num_to_sample = int(num_cls * ratio)\n",
0684 "        # Ensure at least one sample is kept if available\n",
0685 "        if num_to_sample < 1 and num_cls > 0:\n",
0686 "            num_to_sample = 1\n",
0687 "        # Random subset: shuffle with the global torch RNG, then take the head\n",
0688 "        cls_indices_shuffled = cls_indices[torch.randperm(num_cls)]\n",
0689 "        sampled_cls_indices = cls_indices_shuffled[:num_to_sample]\n",
0690 "        indices_list.append(sampled_cls_indices)\n",
0691 "\n",
0692 "    # Combine the indices from all classes (the result is grouped by class, not interleaved)\n",
0693 "    selected_indices = torch.cat(indices_list)\n",
0694 "    filtered_inputs = filtered_inputs[selected_indices]\n",
0695 "    filtered_labels = filtered_labels[selected_indices]\n",
0696 "\n",
0697 "# Print class distribution after downsampling\n",
0698 "class_counts_after = torch.sum(filtered_labels, dim=0)\n",
0699 "print(f\"Class distribution after downsampling - Fake: {class_counts_after[0]}, Prompt: {class_counts_after[1]}, Displaced: {class_counts_after[2]}\")\n",
0700 "\n",
0701 "# Recompute per-sample weights from the post-downsampling class counts so each class contributes equally\n",
0702 "sample_weights = calculate_class_weights(filtered_labels)\n",
0703 "filtered_weights = sample_weights\n",
0704 "# Bundle inputs, one-hot targets and per-sample weights so every batch yields all three\n",
0705 "dataset = TensorDataset(filtered_inputs, filtered_labels, filtered_weights)\n",
0706 "\n",
0707 "# 80/20 random train/test split (draws from the global torch RNG)\n",
0708 "train_size = int(0.8 * len(dataset))\n",
0709 "test_size = len(dataset) - train_size\n",
0710 "train_dataset, test_dataset = random_split(dataset, [train_size, test_size])\n",
0711 "\n",
0712 "# Loader settings shared by both splits; only the training loader reshuffles each epoch\n",
0713 "loader_kwargs = dict(batch_size=1024, num_workers=10, pin_memory=True)\n",
0714 "train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)\n",
0715 "test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)\n",
0716 "\n",
0717 "# Model on the selected device, Adam at a fixed learning rate of 2.5e-3, and the weighted loss\n",
0718 "model = MultiClassNeuralNetwork().to(device)\n",
0719 "optimizer = Adam(params=model.parameters(), lr=2.5e-3)\n",
0720 "loss_function = WeightedCrossEntropyLoss()\n",
0721 "\n",
0722 "def evaluate_loss(loader):\n",
0723 "    \"\"\"Average the weighted loss over the batches of `loader`; the global model is left in eval mode.\"\"\"\n",
0724 "    model.eval()\n",
0725 "    batch_losses = []\n",
0726 "    with torch.no_grad():\n",
0727 "        for inputs, targets, weights in loader:\n",
0728 "            inputs, targets, weights = (t.to(device) for t in (inputs, targets, weights))\n",
0729 "            outputs = model(inputs)\n",
0730 "            batch_losses.append(loss_function(outputs, targets, weights).item())\n",
0731 "    # Plain mean over batches (a short final batch counts as much as a full one);\n",
0732 "    # an empty loader raises ZeroDivisionError\n",
0733 "    return sum(batch_losses) / len(batch_losses)\n",
0734 "\n",
0735 "# Training loop: one optimizer step per batch, then a full pass over the test split each epoch\n",
0736 "num_epochs = 400\n",
0737 "train_loss_log, test_loss_log = [], []\n",
0738 "\n",
0739 "for epoch in range(num_epochs):\n",
0740 "    model.train()  # evaluate_loss() switches to eval mode, so training mode is re-enabled every epoch\n",
0741 "    epoch_loss, num_batches = 0, 0\n",
0742 "\n",
0743 "    for inputs, targets, weights in train_loader:\n",
0744 "        inputs = inputs.to(device)\n",
0745 "        targets = targets.to(device)\n",
0746 "        weights = weights.to(device)\n",
0747 "\n",
0748 "        # Reset gradients, then forward, backward and parameter update\n",
0749 "        optimizer.zero_grad()\n",
0750 "        outputs = model(inputs)\n",
0751 "        loss = loss_function(outputs, targets, weights)\n",
0752 "        loss.backward()\n",
0753 "        optimizer.step()\n",
0754 "\n",
0755 "        # Running totals for the per-epoch mean of batch losses\n",
0756 "        epoch_loss += loss.item()\n",
0757 "        num_batches += 1\n",
0758 "\n",
0759 "    # Epoch summary: mean training batch loss and a fresh evaluation on the test split\n",
0760 "    train_loss = epoch_loss / num_batches\n",
0761 "    test_loss = evaluate_loss(test_loader)\n",
0762 "    train_loss_log.append(train_loss)\n",
0763 "    test_loss_log.append(test_loss)\n",
0764 "\n",
0765 "    # One progress line per epoch\n",
0766 "    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')"
0767 ]
0768 },
0769 {
0770 "cell_type": "code",
0771 "execution_count": 6,
0772 "metadata": {},
0773 "outputs": [
0774 {
0775 "name": "stdout",
0776 "output_type": "stream",
0777 "text": [
0778 "Baseline accuracy: 0.8611\n",
0779 "\n",
0780 "Feature importances:\n",
0781 "Feature 0 importance: 0.0541\n",
0782 "Feature 2 importance: 0.0480\n",
0783 "Feature 5 importance: 0.0434\n",
0784 "Feature 7 importance: 0.0242\n",
0785 "Feature 6 importance: 0.0223\n",
0786 "Feature 3 importance: 0.0206\n",
0787 "Feature 11 importance: 0.0167\n",
0788 "Feature 10 importance: 0.0148\n",
0789 "Feature 13 importance: 0.0140\n",
0790 "Feature 12 importance: 0.0128\n",
0791 "Feature 9 importance: 0.0114\n",
0792 "Feature 8 importance: 0.0046\n",
0793 "Feature 4 importance: 0.0016\n",
0794 "Feature 1 importance: 0.0000\n"
0795 ]
0796 }
0797 ],
0798 "source": [
0799 "import torch\n",
0800 "import numpy as np\n",
0801 "from sklearn.metrics import accuracy_score\n",
0802 "\n",
0803 "# NumPy copies of the features and of the integer class ids (argmax of each one-hot row) for use outside PyTorch\n",
0804 "input_features_np = input_features_tensor.numpy()\n",
0805 "labels_np = labels_tensor.argmax(dim=1).numpy()\n",
0806 "\n",
0807 "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
0808 "\n",
0809 "def model_accuracy(features, labels, model):\n",
0810 "    \"\"\"\n",
0811 "    Return the fraction of rows whose argmax prediction equals the label (a Python float).\n",
0812 "\n",
0813 "    `labels` may hold class indices (1-D) or one-hot rows (2-D). All rows go through the\n",
0814 "    model in a single forward pass on the notebook-global `device`; the model is left in eval mode.\n",
0815 "    \"\"\"\n",
0816 "    model.eval()\n",
0817 "    inputs = features.to(device)\n",
0818 "    targets = labels.to(device)\n",
0819 "\n",
0820 "    with torch.no_grad():\n",
0821 "        # Predicted class = column with the highest probability\n",
0822 "        predicted = model(inputs).argmax(dim=1)\n",
0823 "\n",
0824 "        # One-hot targets are reduced to class indices first\n",
0825 "        if targets.dim() > 1:\n",
0826 "            targets = targets.argmax(dim=1)\n",
0827 "        hits = (predicted == targets).float()\n",
0828 "        accuracy = hits.mean().item()\n",
0829 "\n",
0830 "    return accuracy\n",
0831 "\n",
0832 "# Evaluate on the same NaN-free rows the training cell kept: NaN features give NaN probabilities,\n",
0833 "# and permuting a column that contains NaNs would move them into other rows and distort its score\n",
0834 "valid_rows = ~torch.isnan(input_features_tensor).any(dim=1)\n",
0835 "eval_features = input_features_tensor[valid_rows]\n",
0836 "eval_labels = labels_tensor[valid_rows]\n",
0837 "# NOTE(review): this still includes rows used for training; consider restricting to the test split\n",
0838 "\n",
0839 "# Compute baseline accuracy\n",
0840 "baseline_accuracy = model_accuracy(eval_features, eval_labels, model)\n",
0841 "print(f\"Baseline accuracy: {baseline_accuracy:.4f}\")\n",
0842 "\n",
0843 "# Permutation importance: shuffle one column at a time and record the drop in accuracy\n",
0844 "num_features = eval_features.shape[1]\n",
0845 "feature_importances = np.zeros(num_features)\n",
0846 "for i in range(num_features):\n",
0847 "    # Work on a copy so eval_features stays unpermuted for the next feature\n",
0848 "    permuted_features = eval_features.clone()\n",
0849 "    # Shuffling column i across rows breaks its link to the label while keeping its marginal distribution\n",
0850 "    permuted_features[:, i] = permuted_features[torch.randperm(permuted_features.size(0)), i]\n",
0851 "    permuted_accuracy = model_accuracy(permuted_features, eval_labels, model)\n",
0852 "    feature_importances[i] = baseline_accuracy - permuted_accuracy\n",
0853 "\n",
0854 "# Sort features by descending importance (indices and scores in the same order)\n",
0855 "important_features_indices = np.argsort(feature_importances)[::-1]\n",
0856 "important_features_scores = feature_importances[important_features_indices]\n",
0857 "\n",
0858 "# Print out results\n",
0859 "print(\"\\nFeature importances:\")\n",
0860 "for idx, score in zip(important_features_indices, important_features_scores):\n",
0861 "    print(f\"Feature {idx} importance: {score:.4f}\")"
0862 ]
0863 },
0864 {
0865 "cell_type": "code",
0866 "execution_count": 8,
0867 "metadata": {},
0868 "outputs": [
0869 {
0870 "name": "stdout",
0871 "output_type": "stream",
0872 "text": [
0873 "ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_layer1[32] = {\n",
0874 "-0.9152892f, 3.2650192f, -0.4164221f, -0.1210157f, -2.4165483f, -1.0984275f, -2.1654966f, -0.8991888f, -0.0503724f, 7.1305695f, -5.2781415f, 3.2997849f, 1.0025330f, -0.5117974f, 0.2957068f, -0.1811045f, -2.7853479f, 1.8040915f, -2.8807588f, -4.6462102f, 1.2869841f, -0.0526987f, 0.4946094f, 2.6554070f, -0.1360572f, 0.2122774f, 4.7361507f, -1.4605266f, 0.1759245f, -0.7966636f, -0.0401897f, -0.2652957f };\n",
0875 "\n",
0876 "ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_layer1[14][32] = {\n",
0877 "{ 0.2570587f, 1.5017653f, -1.8436140f, 1.6314303f, -0.1464428f, 1.2261974f, 2.8629315f, -0.0778951f, -0.0007868f, -2.4665442f, 3.7231014f, -0.4062112f, 5.0222125f, -0.4256854f, -0.8145034f, -0.0993065f, 1.1874412f, 3.7737985f, -2.0898068f, 5.0041976f, -0.4184950f, 0.0133298f, -1.1757115f, 0.8953519f, -0.2589224f, 3.4567924f, -1.0867721f, -0.0325336f, -0.1398652f, 5.9361205f, -0.2938714f, 0.0110872f },\n",
0878 "{ 0.0062326f, 0.0294117f, -0.1038531f, -0.1871421f, 0.0092176f, 0.0194613f, -0.0970159f, 0.0044040f, 0.0040717f, -0.0464808f, -0.0161564f, 0.0660082f, 0.0107912f, 0.0041196f, 0.0076985f, -0.0447547f, -0.0234053f, -0.0128952f, -0.0374210f, 0.0277440f, 0.0126392f, 0.0053390f, -0.0176450f, 0.0422363f, -0.1089868f, 0.0229251f, -0.0632515f, -0.0000745f, -0.0120581f, 0.0553841f, 0.1958316f, -0.2002713f },\n",
0879 "{ -0.3374411f, 1.0331864f, -2.3923049f, -1.7857485f, 5.2973804f, -2.4907997f, -5.4545326f, -0.1601160f, -0.0028237f, 0.3510691f, 1.9307067f, 0.1516920f, -9.8718386f, -0.0893254f, -0.3546923f, 0.1746188f, -2.5511057f, -6.9032016f, 1.1837323f, -1.5620790f, 0.3867095f, -0.0118511f, 0.0633005f, -0.4638650f, -0.1623496f, -6.3706584f, -1.3387657f, -0.3824906f, 0.0013318f, -0.2500904f, 0.0434040f, 0.2572088f },\n",
0880 "{ 0.4586262f, -5.7731452f, 2.1999521f, 1.8049341f, -0.2857220f, 0.9761068f, 7.4085426f, 0.6136439f, -0.0216335f, -2.1852305f, -0.8797144f, -0.0105609f, 2.9077356f, 0.6365089f, 0.4854219f, 0.1170259f, 1.9888018f, 1.5127803f, -1.5444536f, 8.0876036f, 0.2033372f, -0.0132441f, -0.8586000f, -1.9558586f, -0.0361535f, 4.7021170f, 1.0431648f, 2.0264521f, -0.2665041f, 6.2334776f, -0.0008584f, -0.0026126f },\n",
0881 "{ -1.5857489f, 2.3446348f, -14.6416082f, 2.6467063f, 3.0982802f, 14.1958466f, 7.0268164f, -0.4456313f, 0.0005447f, -1.3794490f, -10.3501320f, 1.1612288f, 9.6774111f, -0.1875248f, 21.0353413f, -0.1961844f, 8.0234823f, -4.1653094f, -7.1429234f, 15.8104372f, 2.7506628f, -0.0142524f, 22.4932308f, 0.7682357f, 0.0385473f, 13.5405970f, -4.9976201f, -22.4438667f, -0.1693429f, -18.7507858f, -0.0939464f, 0.2192265f },\n",
0882 "{ 65.2134705f, 1.4352289f, -0.6685436f, 22.9858570f, 1.6736605f, -0.1810708f, 0.5204540f, -53.2590408f, -0.0630155f, 3.6024182f, -3.8777969f, 5.0021510f, 0.0055030f, 23.8294449f, -1.3818942f, -0.2419317f, -10.2253504f, -1.8309352f, 1.7169305f, -0.3938941f, 9.1144180f, -0.0004920f, -3.1774588f, -36.4919891f, -0.1711030f, 6.5288949f, 4.5861993f, 0.8314257f, 0.0305954f, 5.1864023f, 0.2658210f, -0.1748345f },\n",
0883 "{ -0.7423377f, -5.8733273f, 5.0070434f, -6.5473375f, 3.8788877f, 6.8001447f, 3.3014266f, -0.5657389f, -0.0376282f, 6.1230965f, 4.0765481f, 0.1596697f, -7.4254904f, 0.2356068f, 3.5560462f, -0.1223621f, 0.7022666f, -9.2908258f, -4.4684458f, -1.1861130f, 0.1879538f, -0.0337607f, 0.9228330f, 1.2672698f, -0.1690857f, -11.2556086f, 4.1028724f, 0.4850706f, 0.0041372f, 3.7200036f, 0.1445599f, 0.2449260f },\n",
0884 "{ -2.3848257f, -1.7144806f, -0.2987937f, 10.1947727f, 0.7855392f, 7.1466241f, -2.2256043f, -1.2184218f, 0.1233135f, 5.3274498f, 2.9086673f, 1.5096599f, 2.8449385f, -0.2345320f, -2.2044909f, 0.1858539f, -0.7592235f, 3.1651254f, 0.5184333f, -3.7233777f, 1.5772278f, 0.0997663f, -0.0325775f, -13.2207737f, -0.0340279f, -6.4953661f, -18.9173355f, 0.9963044f, 0.1927230f, 4.6532283f, -0.0916147f, -0.0406466f },\n",
0885 "{ 5.0975156f, -0.7078310f, 26.7917671f, -1.9278797f, -2.4459264f, -12.1174421f, -5.1347051f, 2.0090365f, -0.0012259f, -11.6201696f, 24.7306499f, -8.7715597f, 4.8136749f, -0.7106145f, -46.1458054f, 0.1771528f, 22.9087524f, -5.5876012f, 6.9944615f, 29.2786064f, -7.2195830f, 0.0270186f, 14.7860146f, 1.0168871f, 0.1467975f, -19.6260185f, 3.3284237f, 22.5860500f, -0.0160821f, -35.1570702f, 0.1473373f, 0.1500054f },\n",
0886 "{ 53.6092033f, -0.7677985f, 3.3910768f, -27.1046524f, -5.1561117f, 1.6982658f, 0.9386115f, -72.5867996f, 57.8674583f, -8.2815771f, 3.2216380f, -28.9387760f, 3.2793560f, 37.3099365f, 2.2979465f, 0.1827718f, 6.8113675f, 1.0104842f, -1.5407079f, -2.5780137f, 32.4788666f, -61.8150520f, -2.6497467f, -10.6412830f, -0.2596186f, -11.8458385f, 27.5528336f, -1.3142428f, -0.2566442f, -17.8737431f, 0.0261727f, 0.0107839f },\n",
0887 "{ 0.2419503f, -3.8581870f, 6.1144238f, 2.3472424f, 1.9470867f, -3.3741038f, 0.4852004f, 0.6366135f, -0.0736884f, 2.2963598f, -0.4113644f, -2.0738223f, -3.7331092f, 0.7157578f, 1.2168316f, -0.1584605f, -2.3843584f, -6.1547771f, -0.4764193f, 4.6278925f, -1.3195806f, -0.0717061f, 2.5889101f, -3.7769084f, 0.0527631f, 3.8808305f, -0.0672685f, 0.3294500f, -0.1916338f, -2.2346771f, -0.1518883f, 0.0462940f },\n",
0888 "{ 0.7453216f, -0.5999010f, -2.6196895f, -6.5323448f, 0.0482639f, 0.0162446f, -0.1185504f, -2.6736362f, -0.0037108f, -4.2818441f, 4.1449761f, 3.3861248f, 0.1735097f, -4.7952204f, 0.8002076f, 0.0598137f, 0.2611136f, 2.4648571f, 1.3178054f, -12.6864462f, -1.4815618f, -0.0113616f, 0.0697572f, 11.0613089f, -0.1636935f, -1.3598750f, 5.4537063f, 0.4077346f, 0.2614925f, 0.4335817f, -0.1616396f, 0.0372773f },\n",
0889 "{ 0.3573107f, -0.1230106f, -0.0517133f, -0.7743639f, 0.1088887f, -0.0315369f, -0.4702122f, 0.4170579f, 0.0149554f, -2.9016135f, 0.8456838f, -1.4586754f, -0.3096066f, 0.3871936f, -0.0811070f, -0.0972313f, 1.3539683f, -0.4489806f, 2.1791372f, -0.0245313f, -0.5678235f, 0.0153700f, 0.0444115f, -1.1144291f, -0.0992134f, -0.0615626f, -1.5467011f, 0.3384680f, -0.2377923f, -3.0146224f, -0.1680345f, -0.0683730f },\n",
0890 "{ 17.0918045f, 0.3651829f, 0.4609411f, -8.8885517f, -1.3358241f, 0.3141162f, 0.5917582f, -3.1579833f, 18.4088402f, -0.7021288f, -0.1767638f, -20.7704201f, 1.0183893f, -12.4671431f, 0.0741675f, 0.2120477f, -0.4298171f, -0.3993036f, -0.4320501f, 0.1840025f, 35.6576691f, -19.6535892f, -1.1798559f, -4.8292837f, 0.1928347f, -2.1754487f, 5.7580199f, -0.4750445f, -0.0005913f, -3.2222869f, -0.0762974f, -0.0288493f },\n",
0891 "};\n",
0892 "\n",
0893 "ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_layer2[32] = {\n",
0894 "-2.9135964f, 2.8539734f, 0.1411822f, 1.3484665f, 4.0785451f, 3.1302166f, 0.6947064f, -6.6154590f, 0.5603381f, 1.1026498f, -0.0329598f, 5.5717101f, -4.4454126f, 0.8731447f, -0.0039664f, 3.3978446f, 0.2816379f, 1.0174516f, 5.7364573f, -0.2107503f, 1.5612308f, 3.0443959f, -0.5723908f, -2.2100310f, 3.2695763f, 1.2092905f, 0.3386045f, 2.9886062f, 1.6525602f, 2.2572896f, 1.5943902f, 3.8117774f };\n",
0895 "\n",
0896 "ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_layer2[32][32] = {\n",
0897 "{ -34.7392120f, -3.6741450f, 0.9912871f, 2.2665858f, -4.2091088f, -5.7695718f, -0.7361242f, 1.9049225f, -2.0010724f, 0.5481573f, -0.1977228f, -8.8578167f, 0.2730211f, 0.2286723f, -0.2417704f, 3.0146859f, -4.6289725f, -1.9597307f, -3.8090327f, -0.0440762f, -0.1491523f, 7.2579861f, -1.3858653f, 5.7306480f, 1.3374176f, -58.5746918f, -0.7968796f, -7.3152990f, 0.8867928f, -1.0549910f, -1.5279640f, -0.3361593f },\n",
0898 "{ 0.1808980f, -0.3280856f, 0.2590686f, -1.0191135f, -0.4207653f, -0.7095064f, 0.1365926f, 0.4096814f, -10.8506107f, 1.2207726f, -0.1571288f, -0.7078422f, 5.8424225f, 1.6238457f, -0.0862379f, -1.4228019f, 0.4024760f, 1.8243361f, -0.6990436f, 0.0060351f, -0.1591629f, -4.6918707f, -0.1512203f, 0.0915786f, -0.5897132f, -0.2044267f, -0.9305406f, -0.5472224f, -2.8864694f, 1.0543642f, 0.5347654f, -0.5453511f },\n",
0899 "{ 0.0880480f, 0.4925799f, -2.0632777f, 0.7970945f, 0.1767638f, 2.1716123f, -1.6498836f, 2.0251210f, 3.4944487f, 4.9597678f, 0.0147583f, 1.4637951f, -1.9777663f, -0.4427343f, -0.2699612f, 0.2240291f, 1.4522513f, 4.8733187f, 1.2140924f, -0.1049215f, -0.6172752f, -1.0545660f, -0.0994265f, 0.0489118f, 0.1526600f, -0.5299308f, 1.2835877f, 1.1346543f, 3.8537612f, 0.0424094f, 0.6743001f, 3.6908419f },\n",
0900 "{ 13.5742407f, -0.2112840f, -0.1391970f, -0.2532290f, -6.1313200f, -0.0358473f, -1.1182036f, -1.2637638f, -0.9513455f, -1.3301764f, -0.0799558f, -1.3724291f, -1.6102096f, -0.0269820f, -0.2672046f, -4.3722496f, 0.5367125f, -0.8932595f, 2.7968166f, -0.1038471f, 0.9354176f, 11.1879988f, 2.5031428f, -0.5350324f, 0.9403298f, -23.2240810f, -1.7667898f, -0.9982644f, 0.7523674f, 0.1651796f, 0.6430404f, 0.1910793f },\n",
0901 "{ 0.2658246f, 0.7314553f, -0.1396945f, 0.8241134f, -0.0880722f, 0.5978789f, 0.6997685f, -0.6765627f, 2.1706738f, -0.8361687f, -0.0501303f, 0.3658123f, -6.4972825f, -4.6433911f, -0.0610007f, -1.3655313f, 0.8686107f, -0.5963267f, 0.0234031f, -0.2062150f, -0.2529819f, -8.4623156f, -0.4437916f, -1.9638975f, -0.3870836f, -0.5850767f, -4.6490083f, 0.6830147f, -1.2122021f, -2.5658479f, 0.4557288f, -0.0879869f },\n",
0902 "{ -0.0108578f, 0.5193277f, 0.8980315f, 0.6727697f, -0.2764051f, 0.5845009f, 0.2578916f, 0.5809379f, 0.1509607f, 0.6870583f, -0.1276244f, -0.0574215f, -4.9936194f, 0.5157762f, -0.2816216f, -1.3928800f, 0.4075674f, 1.1707770f, 0.0359891f, -0.0638050f, -0.7958111f, -13.7506857f, 0.9970800f, 0.3291980f, -0.0093937f, -0.8198650f, 2.4517658f, 0.0728102f, 0.5025729f, 1.1547782f, -2.2815955f, 0.7980155f },\n",
0903 "{ -39.5117416f, -1.9099823f, -0.5987001f, -0.1347036f, 0.2869921f, -6.9140315f, 1.1130731f, 1.5080519f, 2.2398031f, -0.0088237f, -0.0412525f, -6.1471343f, 2.8918502f, -0.3072237f, -0.1006490f, -1.7464429f, -0.8411779f, -42.2267303f, 0.7191703f, -0.1617323f, -0.2711441f, -0.8011407f, 0.5901257f, 1.3762408f, 0.3728105f, 0.5847842f, -0.7151731f, -5.5479650f, 0.4756111f, 0.1937995f, -0.2577415f, 1.4683112f },\n",
0904 "{ 5.5796323f, -1.9758234f, -3.9855742f, -1.1933511f, -6.1229320f, -3.4628909f, 0.8108729f, 1.1214327f, 0.2925960f, -0.3177399f, -0.0674356f, -7.5650120f, -0.2327600f, -0.1809894f, -0.2001229f, -3.7608955f, -1.6911368f, -1.2986628f, -10.9993305f, -0.0333699f, 2.1555903f, 4.0816026f, -1.7761034f, 2.7839565f, 1.7967551f, -1.9129136f, 0.3337919f, -5.5511317f, -0.3778784f, -1.7885938f, 0.5358973f, 0.3659674f },\n",
0905 "{ -0.4947478f, -1.1247686f, -0.5239698f, -0.2116186f, 0.9495451f, -1.7275617f, 3.2058396f, 5.8543334f, 1.1329687f, 2.0362375f, -0.0113123f, -3.6117337f, -2.6473372f, -101.7068253f, 0.0933132f, -9.5351706f, 2.7397399f, -6.7515664f, 6.8541646f, -0.1772990f, -8.0234919f, -11.7267113f, -79.2402649f, 1.0556825f, -1.2873408f, -204.0993347f, 2.9009213f, -1.5086528f, -0.3790632f, -16.6444530f, -0.7742234f, -2.5006552f },\n",
0906 "{ -3.5879157f, 1.2350843f, 0.9988346f, 0.8656718f, -0.2288211f, 1.5041313f, 1.2635980f, -0.9575849f, 1.4612056f, -0.6694647f, -0.2607012f, 2.2113328f, -0.1756403f, 0.4053871f, -0.2502097f, -0.9793193f, 1.1052762f, 0.1973925f, -0.6196574f, -0.1616422f, -0.0878458f, -33.2044487f, -0.6067256f, -2.6461327f, -0.8652852f, 0.4787520f, 0.0987720f, 2.0803480f, -2.9416194f, 0.6944812f, -0.7692280f, -0.2723561f },\n",
0907 "{ 0.9679711f, -0.8438777f, 1.6422297f, 0.1470843f, -0.0370339f, -1.7256715f, -3.5553441f, 0.9303662f, -3.7400429f, -0.0415591f, -0.0421884f, -0.4996702f, 0.2560562f, -1.1471199f, -0.1010781f, 2.9399052f, -2.7095137f, 0.1029265f, -0.9041355f, -0.0961986f, -1.1163988f, 2.2019930f, 1.1136653f, 2.6281602f, 0.1038788f, 0.8405868f, -1.7726818f, -1.4164388f, 2.5241690f, -0.4696356f, -0.6558685f, -0.3892841f },\n",
0908 "{ 0.3616795f, -0.2340108f, 1.3857211f, -0.1612080f, 0.9073119f, 2.0731826f, -2.9476249f, -0.8346572f, -0.2136685f, -1.5312109f, -0.2896508f, 3.2546318f, -0.8377571f, -0.1631008f, -0.2456917f, 2.3952994f, 0.7205132f, 0.3009275f, -0.8928477f, 0.0297560f, -0.9989902f, -5.8525624f, -1.1516922f, -1.4669514f, -0.9353167f, 0.0053586f, -0.2916538f, 2.6023183f, -0.3268352f, 1.1982346f, 0.5936528f, 0.2427792f },\n",
0909 "{ -8.4515934f, 1.6914611f, 1.8740683f, -3.0301368f, 1.1796944f, 1.6180993f, 1.5193824f, -1.3537951f, 1.8400136f, -2.8482721f, -0.1556847f, 1.9412404f, 4.7879019f, 0.7858805f, 0.0379962f, -2.2685356f, 2.4616067f, -3.1357207f, -3.1058917f, -0.0751680f, -6.3824587f, -6.4688945f, -0.4577174f, -4.3322401f, -5.3993649f, -0.0399034f, 0.6397164f, 2.3432400f, -3.4454954f, 0.4422087f, 2.4481769f, -2.0406303f },\n",
0910 "{ 2.0977514f, 2.0378633f, -5.0659881f, -3.1632454f, -3.7596478f, 1.1832924f, 2.6347756f, -0.7441098f, 0.9281505f, 0.2330403f, -0.1830614f, -0.7371528f, 0.9002755f, 0.2897577f, 0.0216177f, -6.7740455f, 2.5610964f, 0.7834055f, -7.0665541f, -0.0497221f, 2.1334522f, -4.0962648f, -0.9172248f, -1.9944772f, 0.2762617f, -7.6493464f, 1.3044875f, 0.1891927f, -1.2570183f, -0.9203181f, 1.1330502f, 0.2542208f },\n",
0911 "{ -3.2519467f, 0.1621506f, 1.2959863f, -0.4015787f, -0.1872185f, -1.3774366f, -0.6622453f, -0.9865818f, -0.7389784f, -2.5770009f, -0.1067092f, -1.8156646f, 8.1435080f, 0.1961913f, -0.2026473f, -0.5137898f, -0.3847715f, -2.7479975f, -0.5669084f, -0.0805389f, 0.8162143f, -89.8910904f, -0.0983714f, -0.6288267f, -0.1096447f, 0.3820810f, -0.5419686f, -1.2580637f, -0.1827327f, -0.7081568f, -0.0991779f, -1.8233751f },\n",
0912 "{ -0.1124980f, -0.0672807f, -0.1147610f, 0.0497073f, 0.1218153f, 0.1264220f, 0.0282118f, 0.1016580f, 0.0696730f, -0.0004516f, 0.0812953f, 0.0875243f, 0.0982849f, 0.1660293f, -0.1481024f, -0.0315878f, -0.0644747f, 0.1398649f, 0.0835874f, -0.1440294f, -0.1390193f, -0.0628670f, -0.1517032f, -0.0325693f, -0.1094708f, 0.0963070f, 0.0056602f, -0.0197677f, -0.0068012f, 0.1578562f, -0.0302607f, 0.0684079f },\n",
0913 "{ 0.2561692f, 0.6044238f, -1.0067526f, 3.0207021f, -0.5215638f, -1.5455573f, -2.4320843f, -0.2874290f, -5.5609145f, -2.5270512f, -0.1884816f, -1.4440522f, 1.1501647f, 1.2767099f, -0.2626259f, 1.5462712f, -0.3342260f, -1.4259053f, -0.1591775f, -0.1777169f, -0.8070273f, -0.6262614f, -0.1421982f, 0.4950019f, 0.3899588f, 1.1158458f, -1.8252333f, -1.4090738f, -0.9128270f, 1.2212859f, -0.5060611f, -1.5151979f },\n",
0914 "{ 6.7118378f, -1.9413553f, -0.9765822f, 1.9195900f, -1.6302279f, -3.3607423f, -0.5215842f, 1.6841047f, -3.7323046f, 2.9130237f, -0.2912694f, -1.6349151f, -2.9017780f, 0.8473414f, -0.0895011f, 1.9765211f, -1.6982929f, 1.3711845f, -0.8770422f, -0.0966949f, -4.9838910f, -80.1754532f, 0.5617938f, 5.0638437f, -2.1738896f, 1.1080216f, -1.2562524f, -2.4832511f, 1.9475292f, -0.0768876f, -1.7405595f, 1.2659055f },\n",
0915 "{ -9.3639793f, -0.7770892f, 0.5863696f, -2.4971461f, 0.9785317f, 0.6006397f, -0.7508336f, 1.4561496f, 3.3278019f, -1.3552047f, 0.0016642f, -1.0065420f, -0.9822129f, -0.1398876f, -0.0867651f, -2.2540245f, 1.3095651f, -0.4880162f, -0.9081959f, 0.0172203f, -0.9673725f, -6.6905494f, 1.8820195f, -0.3343536f, -0.9252734f, -0.4198346f, 3.2226353f, 0.0417345f, -0.2720280f, -0.4798162f, 0.8319970f, 0.6051134f },\n",
0916 "{ 0.0971739f, -0.1656290f, -3.0162272f, 1.3674160f, -0.1130774f, -2.0347462f, 0.8287644f, 1.7546188f, -0.8183546f, -0.2517924f, -0.0338358f, -2.3002670f, -1.8175446f, -0.8929357f, 0.0145055f, 0.7129369f, -1.2497267f, -0.8694588f, 0.6608886f, 0.0472852f, 0.7463530f, 2.0970786f, -0.7406388f, 1.6379862f, 0.9597634f, 0.2664887f, -0.9611518f, -2.1256742f, -0.8108851f, -0.7670876f, -0.2143202f, -0.9296412f },\n",
0917 "{ -5.5468359f, 1.8933475f, 0.5377558f, -0.8609743f, 1.6522382f, 3.7070215f, -1.3910099f, -1.7996042f, -1.2547578f, -0.3161051f, -0.1433857f, 5.9167895f, 0.2788936f, 0.6513762f, -0.1890229f, 1.3976518f, 2.5647829f, 1.7091097f, -2.4891980f, 0.0704016f, -1.2354076f, -7.9673457f, 0.5024176f, -3.1194675f, -1.8552034f, 1.4241451f, -0.5721908f, 4.6941881f, -0.4191083f, 1.5897839f, 0.5376836f, -0.3906031f },\n",
0918 "{ -7.3998523f, -1.7208674f, 3.6660385f, 3.2399852f, 2.6726487f, -2.7743144f, 2.6148691f, 5.4286017f, -3.5616024f, 3.8747442f, -0.2854572f, -1.7255206f, 1.5527865f, -95.2269287f, -0.0130005f, -0.3984787f, -0.3650612f, -5.9493575f, 4.3472433f, 0.0598797f, -11.5429420f, -2.9780169f, -69.4482956f, 3.8544486f, -2.9926283f, 3.9207959f, 0.2614698f, -1.5368384f, 1.2026052f, -10.9552374f, -2.5336740f, -4.0654378f },\n",
0919 "{ 2.1074338f, -2.0316105f, -0.2519890f, 0.0232255f, -0.2889173f, -0.1693905f, 0.4285150f, 0.6449150f, 4.6293564f, 3.3936121f, 0.0660587f, 0.3134870f, -2.3245549f, -0.9685450f, 0.0889201f, 0.5934311f, -0.9143091f, 2.3384421f, -0.4089483f, -0.1643694f, 2.5919075f, 6.0844655f, 0.2091536f, 2.0565152f, 1.1226704f, -0.2695110f, 0.3927289f, -0.0457220f, 4.0436058f, -1.5475131f, -1.3438802f, 1.6676767f },\n",
0920 "{ 1.7787290f, 0.3969800f, 1.7834123f, 1.3779058f, -0.5738219f, -0.5349790f, -1.4947498f, -0.0787759f, 0.0341407f, 0.4346433f, -0.1981957f, -0.2886125f, -1.0133898f, -0.7178908f, -0.1872994f, 0.7944858f, -1.4787679f, -0.2754479f, -3.5224824f, -0.2090070f, -1.1161381f, -3.6711519f, -1.7022727f, 0.1558363f, 0.4152904f, -2.6727507f, 1.0731899f, -0.3006089f, 0.1950178f, -1.4062581f, -1.4458562f, 0.2443156f },\n",
0921 "{ -0.1570787f, 0.0073413f, -0.1335499f, -0.1712597f, 0.0029127f, 0.1628582f, -0.0816609f, -0.1307715f, -0.1621102f, -0.1200016f, -0.1394555f, 0.0157797f, -0.1572116f, 0.1745119f, -0.1128801f, -0.0566642f, 0.0099119f, -0.1222350f, 0.0299575f, -0.1031234f, -0.1048335f, 0.1707117f, -0.1490631f, -0.0835587f, -0.1712185f, -0.1278749f, 0.1462234f, -0.0081762f, -0.1106477f, -0.1645385f, 0.1268658f, 0.1686065f },\n",
0922 "{ -2.2590189f, -0.1024268f, -1.9020500f, -0.7051241f, 0.1037211f, 0.1701183f, 0.1889226f, 0.9506961f, 1.4137075f, 0.4496680f, -0.2055015f, -0.5990168f, -6.5227470f, -0.1113795f, -0.1070101f, 0.0105921f, -0.1653819f, 0.8838411f, -0.4713951f, 0.0250525f, 0.5694079f, -63.6874619f, 0.3740432f, 0.2925327f, 0.2328388f, -0.9265066f, 0.3290201f, -0.3581912f, 0.8044130f, -0.0143339f, 0.6609334f, -0.6653876f },\n",
0923 "{ 1.4302264f, 0.2180429f, 0.9684587f, 1.0369793f, 0.1597588f, -0.7066790f, -1.7150335f, 0.1960071f, -0.1694691f, 0.8381965f, -0.0181254f, -1.8366945f, -1.8840518f, -0.3109443f, -0.0058080f, 2.0794308f, -1.7936089f, -0.4478118f, -1.2889514f, -0.0300996f, -0.5915933f, -0.8868528f, 1.2223560f, 0.6542705f, 0.0814525f, -1.3704894f, -0.1875549f, -1.6079675f, -0.2744675f, 0.0382733f, -0.9821799f, -1.1006635f },\n",
0924 "{ -0.3911386f, 0.0468989f, 1.9009087f, -1.6725038f, 0.4506643f, -1.9519551f, 0.8855276f, -1.5861524f, 0.3190332f, -3.1724985f, -0.0278042f, -1.2427157f, 1.6820471f, 0.1633015f, -0.0449006f, -1.6101545f, -0.1007412f, -2.7659879f, -0.5162025f, -0.1431058f, 0.8236251f, -0.9194202f, -0.1490582f, -1.6231275f, -0.5467592f, 0.1333764f, -0.4865468f, -0.8269796f, -0.9018597f, 0.0288408f, -1.0994427f, -2.7987468f },\n",
0925 "{ 0.1278387f, -0.0134571f, 0.0448454f, -0.1556552f, -0.1247998f, -0.1196313f, -0.1611872f, -0.0630336f, -0.1410615f, 0.1682803f, 0.0263861f, 0.0619852f, 0.0423437f, 0.0982059f, 0.0784064f, 0.1412098f, 0.0331818f, -0.1537199f, -0.1165343f, -0.0441304f, 0.0197925f, -0.1256299f, 0.0694126f, -0.0137156f, 0.1587864f, 0.1131037f, -0.0722358f, 0.1287198f, -0.0683723f, -0.1212666f, 0.0847685f, 0.1469466f },\n",
0926 "{ -1.7708958f, 0.2500555f, -1.0356978f, -1.4820993f, -0.9565629f, 2.2127695f, 1.3409303f, -0.6528702f, 1.4306690f, 1.7529922f, 0.0491593f, 0.8595024f, -1.1016617f, -0.1608696f, -0.1200257f, -1.9610568f, 2.6189950f, 1.8707203f, 0.5241567f, -0.2288505f, 0.1528303f, -127.8296814f, -166.2449646f, -1.0174949f, -0.6682033f, -0.9169813f, 1.5756097f, 1.4574182f, -0.1463246f, -1.8262713f, 0.7517605f, -0.1181977f },\n",
0927 "{ -0.0324178f, -0.0418596f, -0.1287051f, -0.0232098f, -0.0512466f, -0.0905093f, -0.1104402f, -0.0095842f, 0.1413968f, -0.0081470f, -0.0251773f, 0.0667293f, 0.0344667f, 0.0116366f, -0.0908088f, -0.0980062f, 0.1874590f, -0.0381802f, 0.0684232f, 0.0252469f, -0.0681347f, 0.1034415f, 0.0576827f, -0.0557779f, 0.0868192f, -0.0851723f, -0.0868760f, 0.1192429f, 0.1751331f, -0.0323825f, -0.1238438f, -0.0623215f },\n",
0928 "{ 0.1757070f, -0.1212057f, -0.0878934f, 0.0737142f, 0.0712249f, -0.0818311f, -0.0719173f, -0.0561241f, 0.0630706f, -0.1523757f, -0.0048847f, 0.1597463f, -0.0302248f, -0.0096164f, -0.0259278f, -0.0815664f, -0.1283869f, 0.1644790f, -0.1612884f, 0.1505984f, -0.1614616f, -0.0756450f, -0.1680063f, -0.0716024f, -0.1266488f, 0.1165592f, -0.0066066f, 0.0661669f, 0.0148620f, 0.0464089f, 0.1496351f, -0.1720888f },\n",
0929 "};\n",
0930 "\n",
0931 "ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_output_layer[3] = {\n",
0932 "-0.3838706f, -0.0366794f, 0.5841699f };\n",
0933 "\n",
0934 "ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_output_layer[32][3] = {\n",
0935 "{ 0.6237589f, 0.2710748f, 0.5615537f },\n",
0936 "{ -0.1665458f, 0.3942705f, 0.2601272f },\n",
0937 "{ 0.3388835f, 0.1579971f, 0.0178280f },\n",
0938 "{ 0.5823844f, -0.0299621f, 0.1178701f },\n",
0939 "{ 0.5561634f, 0.1805784f, 0.6629463f },\n",
0940 "{ 0.1693098f, -0.8297758f, 0.1556239f },\n",
0941 "{ 0.0062806f, 0.2958559f, 0.2698825f },\n",
0942 "{ -0.3925241f, 0.1489681f, -0.0803940f },\n",
0943 "{ 0.5710047f, 0.1924859f, 0.2375189f },\n",
0944 "{ -0.0372825f, 0.0286687f, 0.2910011f },\n",
0945 "{ -0.0867018f, -0.1508995f, -0.0193411f },\n",
0946 "{ 0.4878173f, -0.9407690f, 0.3869846f },\n",
0947 "{ 0.9613981f, 0.3148000f, 0.2196945f },\n",
0948 "{ 0.5831478f, 1.2141191f, 0.7358299f },\n",
0949 "{ -0.0073579f, -0.0419888f, 0.0338354f },\n",
0950 "{ 0.2477632f, 0.9092489f, 0.7818094f },\n",
0951 "{ 0.3554717f, -0.4452990f, 0.0102171f },\n",
0952 "{ 0.3888267f, 0.7089493f, 0.3766315f },\n",
0953 "{ 0.8450955f, -0.0079020f, 0.5853269f },\n",
0954 "{ 0.0646952f, 0.0271975f, 0.0329916f },\n",
0955 "{ 0.5528679f, 0.0075829f, 0.2414524f },\n",
0956 "{ -1.3869698f, -1.1617719f, -1.1356672f },\n",
0957 "{ 0.0214099f, 0.3563140f, 0.5346315f },\n",
0958 "{ 0.3791857f, -0.2714695f, -0.0823861f },\n",
0959 "{ -0.3221727f, 0.5334318f, 0.1581419f },\n",
0960 "{ 0.6678535f, 0.6672282f, 0.4110478f },\n",
0961 "{ 0.1442596f, 0.0245941f, -0.1659890f },\n",
0962 "{ -0.9674007f, 1.4712439f, -0.8418093f },\n",
0963 "{ 0.5696401f, 0.2636259f, 0.2079044f },\n",
0964 "{ 0.0382360f, 0.2687068f, 0.4462553f },\n",
0965 "{ -0.0957586f, 0.4259349f, 0.3613387f },\n",
0966 "{ -0.0633585f, 0.4451550f, 0.2848748f },\n",
0967 "};\n",
0968 "\n"
0969 ]
0970 }
0971 ],
0972 "source": [
"def print_formatted_weights_biases(weights, biases, layer_name, precision=7):\n",
"    \"\"\"Print one linear layer as Alpaka C++ constant-array definitions.\n",
"\n",
"    Args:\n",
"        weights: 2-D array in nn.Linear layout, shape (out_features, in_features).\n",
"        biases: 1-D array with one value per output feature.\n",
"        layer_name: suffix of the emitted identifiers bias_NAME and wgtT_NAME.\n",
"        precision: decimals written for every float (default 7, the original format).\n",
"\n",
"    The weight matrix is written transposed, as wgtT[in_features][out_features],\n",
"    i.e. one initializer row per input feature.\n",
"    \"\"\"\n",
"    n_out, n_in = weights.shape\n",
"\n",
"    # Biases: a single initializer line\n",
"    print(f\"ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_{layer_name}[{len(biases)}] = {{\")\n",
"    print(\", \".join(f\"{b:.{precision}f}f\" for b in biases) + \" };\")\n",
"    print()\n",
"\n",
"    # Weights: iterate over the transpose so each printed row is one input feature\n",
"    print(f\"ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_{layer_name}[{n_in}][{n_out}] = {{\")\n",
"    for row in weights.T:\n",
"        formatted_row = \", \".join(f\"{w:.{precision}f}f\" for w in row)\n",
"        print(f\"{{ {formatted_row} }},\")\n",
"    print(\"};\")\n",
"    print()\n",
0986 "\n",
"def print_model_weights_biases(model):\n",
"    \"\"\"Print every nn.Linear layer of `model` as Alpaka C++ arrays.\n",
"\n",
"    Switches the model to evaluation mode as a side effect. Layers are visited in\n",
"    `named_modules()` order, and dots in module names are replaced by underscores\n",
"    to form the C++ identifiers. A layer created with bias=False (bias is None) is\n",
"    printed with an all-zero bias array, which is numerically equivalent.\n",
"    \"\"\"\n",
"    model.eval()\n",
"\n",
"    for name, module in model.named_modules():\n",
"        if not isinstance(module, nn.Linear):\n",
"            continue\n",
"\n",
"        weights = module.weight.detach().cpu().numpy()\n",
"        if module.bias is None:\n",
"            # nn.Linear(bias=False) stores no bias tensor; emit zeros instead\n",
"            biases = np.zeros(weights.shape[0], dtype=weights.dtype)\n",
"        else:\n",
"            biases = module.bias.detach().cpu().numpy()\n",
"\n",
"        print_formatted_weights_biases(weights, biases, name.replace('.', '_'))\n",
"\n",
"print_model_weights_biases(model)\n"
1003 ]
1004 },
1005 {
1006 "cell_type": "code",
1007 "execution_count": 9,
1008 "metadata": {},
1009 "outputs": [],
1010 "source": [
"# Move the inputs to `device`\n",
"input_features_tensor = input_features_tensor.to(device)\n",
"\n",
"# Model scores, one row per input sample and one column per class ([n_samples, 3]).\n",
"# No squeeze(): it would drop the sample axis for a one-sample input and break\n",
"# the predictions[:, class] indexing used when plotting.\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    outputs = model(input_features_tensor)\n",
"    predictions = outputs.cpu().numpy()\n",
"\n",
"# T3s whose t3_pMatched exceeds 0.95\n",
"full_tracks = (np.concatenate(branches['t3_pMatched']) > 0.95)\n",
"\n",
"# pt from the T3 circle radius: 2.99792458e-3 * B * R with B = 3.8\n",
"# (presumably R in cm, B in tesla and pt in GeV -- confirm units of t3_radius)\n",
"t3_pt = np.concatenate(branches['t3_radius']) * (2.99792458e-3 * 3.8)"
1023 ]
1024 },
1025 {
1026 "cell_type": "code",
1027 "execution_count": 10,
1028 "metadata": {},
1029 "outputs": [
1030 {
1031 "name": "stdout",
1032 "output_type": "stream",
1033 "text": [
1034 "Eta bin 0.00-0.25: 9409714 fakes, 313231 true Prompt\n",
1035 "Eta bin 0.25-0.50: 9242595 fakes, 323051 true Prompt\n",
1036 "Eta bin 0.50-0.75: 7849380 fakes, 410185 true Prompt\n",
1037 "Eta bin 0.75-1.00: 4293980 fakes, 322065 true Prompt\n",
1038 "Eta bin 1.00-1.25: 4343023 fakes, 374215 true Prompt\n",
1039 "Eta bin 1.25-1.50: 2725728 fakes, 351420 true Prompt\n",
1040 "Eta bin 1.50-1.75: 1368266 fakes, 425819 true Prompt\n",
1041 "Eta bin 1.75-2.00: 1413754 fakes, 467604 true Prompt\n",
1042 "Eta bin 2.00-2.25: 448439 fakes, 419450 true Prompt\n",
1043 "Eta bin 2.25-2.50: 124212 fakes, 247704 true Prompt\n"
1044 ]
1045 },
1046 {
1047 "data": {
1048 "text/plain": [
1049 "<Figure size 2000x800 with 3 Axes>"
1050 ]
1051 },
1052 "metadata": {},
1053 "output_type": "display_data"
1054 },
1055 {
1056 "name": "stdout",
1057 "output_type": "stream",
1058 "text": [
1059 "\n",
1060 "Prompt tracks, pt: 0.0 to 5.0 GeV\n",
1061 "Number of true prompt tracks: 3654744\n",
1062 "Number of fake tracks in pt bin: 41219091\n",
1063 "\n",
1064 "80% Retention Cut Values: {0.6326, 0.6415, 0.6506, 0.6526, 0.5321, 0.5651, 0.5747, 0.5765, 0.6243, 0.6320} Mean: 0.6082\n",
1065 "80% Cut Fake Rejections: {99.1, 99.1, 99.1, 99.2, 97.4, 97.1, 96.8, 96.8, 95.6, 92.4} Mean: 97.3%\n",
1066 "\n",
1067 "90% Retention Cut Values: {0.4957, 0.5052, 0.5201, 0.5340, 0.4275, 0.4708, 0.4890, 0.4932, 0.5400, 0.5449} Mean: 0.502\n",
1068 "90% Cut Fake Rejections: {98.5, 98.6, 98.6, 98.7, 95.5, 95.4, 95.0, 95.0, 93.1, 88.9} Mean: 95.7%\n",
1069 "\n",
1070 "93% Retention Cut Values: {0.3874, 0.3978, 0.4201, 0.4502, 0.3750, 0.4117, 0.4394, 0.4466, 0.4916, 0.4899} Mean: 0.431\n",
1071 "93% Cut Fake Rejections: {98.1, 98.2, 98.2, 98.3, 94.5, 94.3, 94.0, 94.1, 91.8, 86.8} Mean: 94.8%\n",
1072 "\n",
1073 "96% Retention Cut Values: {0.1332, 0.1616, 0.2094, 0.2898, 0.2909, 0.3109, 0.3561, 0.3626, 0.4016, 0.3775} Mean: 0.2894\n",
1074 "96% Cut Fake Rejections: {96.3, 96.7, 96.9, 97.1, 92.5, 92.4, 92.2, 92.4, 89.6, 82.9} Mean: 92.9%\n",
1075 "\n",
1076 "97% Retention Cut Values: {0.0894, 0.1137, 0.1595, 0.2333, 0.2507, 0.2569, 0.3081, 0.3120, 0.3466, 0.2932} Mean: 0.2363\n",
1077 "97% Cut Fake Rejections: {94.6, 95.4, 95.8, 95.9, 91.4, 91.1, 91.2, 91.4, 88.3, 79.7} Mean: 91.5%\n",
1078 "\n",
1079 "98% Retention Cut Values: {0.0595, 0.0742, 0.1129, 0.1736, 0.1919, 0.1824, 0.2337, 0.2285, 0.2529, 0.1681} Mean: 0.1678\n",
1080 "98% Cut Fake Rejections: {91.9, 93.1, 93.9, 94.0, 89.5, 89.0, 89.3, 89.5, 85.8, 73.7} Mean: 89.0%\n",
1081 "\n",
1082 "99% Retention Cut Values: {0.0319, 0.0371, 0.0493, 0.0943, 0.0923, 0.0767, 0.1061, 0.1041, 0.1124, 0.0762} Mean: 0.078\n",
1083 "99% Cut Fake Rejections: {86.5, 88.4, 88.8, 89.8, 84.8, 83.9, 84.3, 84.5, 79.6, 62.8} Mean: 83.3%\n",
1084 "\n",
1085 "99.5% Retention Cut Values: {0.0149, 0.0168, 0.0242, 0.0402, 0.0389, 0.0365, 0.0527, 0.0546, 0.0615, 0.0366} Mean: 0.0377\n",
1086 "99.5% Cut Fake Rejections: {78.7, 81.5, 83.3, 84.1, 79.2, 78.5, 78.6, 79.3, 73.6, 49.7} Mean: 76.7%\n",
1087 "Eta bin 0.00-0.25: 9409714 fakes, 43220 true Displaced\n",
1088 "Eta bin 0.25-0.50: 9242595 fakes, 47035 true Displaced\n",
1089 "Eta bin 0.50-0.75: 7849380 fakes, 62690 true Displaced\n",
1090 "Eta bin 0.75-1.00: 4293980 fakes, 52590 true Displaced\n",
1091 "Eta bin 1.00-1.25: 4343023 fakes, 62242 true Displaced\n",
1092 "Eta bin 1.25-1.50: 2725728 fakes, 59777 true Displaced\n",
1093 "Eta bin 1.50-1.75: 1368266 fakes, 76741 true Displaced\n",
1094 "Eta bin 1.75-2.00: 1413754 fakes, 90436 true Displaced\n",
1095 "Eta bin 2.00-2.25: 448439 fakes, 73564 true Displaced\n",
1096 "Eta bin 2.25-2.50: 124212 fakes, 43525 true Displaced\n"
1097 ]
1098 },
1099 {
1100 "data": {
1101 "text/plain": [
1102 "<Figure size 2000x800 with 3 Axes>"
1103 ]
1104 },
1105 "metadata": {},
1106 "output_type": "display_data"
1107 },
1108 {
1109 "name": "stdout",
1110 "output_type": "stream",
1111 "text": [
1112 "\n",
1113 "Displaced tracks, pt: 0.0 to 5.0 GeV\n",
1114 "Number of true displaced tracks: 611820\n",
1115 "Number of fake tracks in pt bin: 41219091\n",
1116 "\n",
1117 "80% Retention Cut Values: {0.2362, 0.2377, 0.2496, 0.2579, 0.2940, 0.3057, 0.3273, 0.3224, 0.2944, 0.2890} Mean: 0.2814\n",
1118 "80% Cut Fake Rejections: {83.1, 81.5, 78.7, 76.4, 75.6, 68.3, 54.0, 49.9, 34.7, 18.3} Mean: 62.1%\n",
1119 "\n",
1120 "90% Retention Cut Values: {0.1809, 0.1855, 0.1998, 0.2098, 0.2173, 0.2267, 0.2391, 0.2507, 0.2356, 0.2317} Mean: 0.2177\n",
1121 "90% Cut Fake Rejections: {79.7, 78.0, 75.4, 73.2, 70.1, 61.5, 46.5, 43.7, 29.3, 12.4} Mean: 57.0%\n",
1122 "\n",
1123 "93% Retention Cut Values: {0.1527, 0.1622, 0.1814, 0.1933, 0.1952, 0.1889, 0.2041, 0.2280, 0.2216, 0.2166} Mean: 0.1944\n",
1124 "93% Cut Fake Rejections: {77.6, 76.1, 74.0, 72.1, 68.4, 58.2, 43.5, 41.8, 28.0, 10.8} Mean: 55.0%\n",
1125 "\n",
1126 "96% Retention Cut Values: {0.1133, 0.1274, 0.1514, 0.1662, 0.1657, 0.1576, 0.1601, 0.2030, 0.2068, 0.2028} Mean: 0.1654\n",
1127 "96% Cut Fake Rejections: {73.9, 72.8, 71.5, 70.0, 65.9, 55.1, 39.2, 39.6, 26.6, 9.6} Mean: 52.4%\n",
1128 "\n",
1129 "97% Retention Cut Values: {0.0956, 0.1131, 0.1362, 0.1509, 0.1549, 0.1459, 0.1479, 0.1924, 0.2012, 0.1980} Mean: 0.1536\n",
1130 "97% Cut Fake Rejections: {71.7, 71.2, 70.0, 68.7, 65.0, 53.9, 37.8, 38.7, 26.1, 9.1} Mean: 51.2%\n",
1131 "\n",
1132 "98% Retention Cut Values: {0.0686, 0.0874, 0.1165, 0.1298, 0.1400, 0.1311, 0.1327, 0.1773, 0.1954, 0.1922} Mean: 0.1371\n",
1133 "98% Cut Fake Rejections: {67.4, 67.7, 68.0, 66.8, 63.6, 52.2, 36.0, 37.4, 25.6, 8.6} Mean: 49.3%\n",
1134 "\n",
1135 "99% Retention Cut Values: {0.0334, 0.0504, 0.0748, 0.0994, 0.1128, 0.1123, 0.1118, 0.1525, 0.1867, 0.1847} Mean: 0.1119\n",
1136 "99% Cut Fake Rejections: {57.9, 60.3, 62.3, 63.5, 60.8, 49.8, 33.2, 35.1, 24.8, 7.9} Mean: 45.6%\n",
1137 "\n",
1138 "99.5% Retention Cut Values: {0.0155, 0.0237, 0.0414, 0.0560, 0.0808, 0.0901, 0.0960, 0.1335, 0.1789, 0.1791} Mean: 0.0895\n",
1139 "99.5% Cut Fake Rejections: {47.6, 50.3, 54.7, 56.5, 56.7, 46.6, 30.8, 33.2, 24.1, 7.5} Mean: 40.8%\n",
1140 "Eta bin 0.00-0.25: 2012526 fakes, 6249 true Prompt\n",
1141 "Eta bin 0.25-0.50: 1972121 fakes, 6496 true Prompt\n",
1142 "Eta bin 0.50-0.75: 1704510 fakes, 6894 true Prompt\n",
1143 "Eta bin 0.75-1.00: 930629 fakes, 5318 true Prompt\n",
1144 "Eta bin 1.00-1.25: 861320 fakes, 9397 true Prompt\n",
1145 "Eta bin 1.25-1.50: 523329 fakes, 14695 true Prompt\n",
1146 "Eta bin 1.50-1.75: 246635 fakes, 24265 true Prompt\n",
1147 "Eta bin 1.75-2.00: 250585 fakes, 15787 true Prompt\n",
1148 "Eta bin 2.00-2.25: 86204 fakes, 6652 true Prompt\n",
1149 "Eta bin 2.25-2.50: 22080 fakes, 3385 true Prompt\n"
1150 ]
1151 },
1152 {
1153 "data": {
1154 "text/plain": [
1155 "<Figure size 2000x800 with 3 Axes>"
1156 ]
1157 },
1158 "metadata": {},
1159 "output_type": "display_data"
1160 },
1161 {
1162 "name": "stdout",
1163 "output_type": "stream",
1164 "text": [
1165 "\n",
1166 "Prompt tracks, pt: 5.0 to inf GeV\n",
1167 "Number of true prompt tracks: 99138\n",
1168 "Number of fake tracks in pt bin: 8609939\n",
1169 "\n",
1170 "80% Retention Cut Values: {0.1353, 0.1569, 0.2083, 0.2407, 0.2910, 0.3552, 0.4221, 0.4197, 0.3174, 0.2681} Mean: 0.2815\n",
1171 "80% Cut Fake Rejections: {98.8, 99.0, 99.3, 99.4, 96.1, 94.8, 93.5, 93.7, 88.5, 74.4} Mean: 93.7%\n",
1172 "\n",
1173 "90% Retention Cut Values: {0.0302, 0.0415, 0.0994, 0.1791, 0.1960, 0.2467, 0.3227, 0.3242, 0.2367, 0.2187} Mean: 0.1895\n",
1174 "90% Cut Fake Rejections: {90.2, 94.5, 98.1, 99.0, 93.4, 91.4, 89.8, 89.7, 80.1, 63.8} Mean: 89.0%\n",
1175 "\n",
1176 "93% Retention Cut Values: {0.0188, 0.0261, 0.0502, 0.1487, 0.1606, 0.2022, 0.2772, 0.2797, 0.2182, 0.2063} Mean: 0.1588\n",
1177 "93% Cut Fake Rejections: {82.0, 89.1, 94.9, 98.7, 91.8, 89.5, 87.9, 87.3, 77.5, 60.9} Mean: 85.9%\n",
1178 "\n",
1179 "96% Retention Cut Values: {0.0095, 0.0138, 0.0243, 0.0620, 0.1096, 0.1450, 0.2184, 0.2276, 0.1949, 0.1769} Mean: 0.1182\n",
1180 "96% Cut Fake Rejections: {67.7, 78.4, 86.6, 95.6, 88.5, 86.0, 84.6, 83.8, 74.2, 54.1} Mean: 79.9%\n",
1181 "\n",
1182 "97% Retention Cut Values: {0.0079, 0.0105, 0.0185, 0.0361, 0.0883, 0.1242, 0.1943, 0.2106, 0.1600, 0.1354} Mean: 0.0986\n",
1183 "97% Cut Fake Rejections: {63.7, 73.1, 82.1, 91.3, 86.5, 84.3, 83.0, 82.5, 69.0, 42.8} Mean: 75.8%\n",
1184 "\n",
1185 "98% Retention Cut Values: {0.0057, 0.0077, 0.0119, 0.0212, 0.0684, 0.0902, 0.1624, 0.1814, 0.1158, 0.1037} Mean: 0.0769\n",
1186 "98% Cut Fake Rejections: {56.9, 66.9, 74.3, 84.8, 84.0, 80.5, 80.6, 79.9, 61.0, 32.8} Mean: 70.2%\n",
1187 "\n",
1188 "99% Retention Cut Values: {0.0027, 0.0045, 0.0074, 0.0119, 0.0325, 0.0393, 0.1192, 0.1170, 0.0683, 0.0819} Mean: 0.0485\n",
1189 "99% Cut Fake Rejections: {43.1, 56.7, 65.6, 76.3, 75.4, 69.5, 76.2, 72.6, 49.8, 25.6} Mean: 61.1%\n",
1190 "\n",
1191 "99.5% Retention Cut Values: {0.0015, 0.0027, 0.0038, 0.0054, 0.0153, 0.0219, 0.0820, 0.0657, 0.0373, 0.0669} Mean: 0.0302\n",
1192 "99.5% Cut Fake Rejections: {34.3, 48.3, 54.4, 63.1, 64.6, 60.9, 70.6, 63.0, 39.3, 20.4} Mean: 51.9%\n",
1193 "Eta bin 0.00-0.25: 2012526 fakes, 2764 true Displaced\n",
1194 "Eta bin 0.25-0.50: 1972121 fakes, 2581 true Displaced\n",
1195 "Eta bin 0.50-0.75: 1704510 fakes, 2477 true Displaced\n",
1196 "Eta bin 0.75-1.00: 930629 fakes, 2122 true Displaced\n",
1197 "Eta bin 1.00-1.25: 861320 fakes, 2780 true Displaced\n",
1198 "Eta bin 1.25-1.50: 523329 fakes, 3481 true Displaced\n",
1199 "Eta bin 1.50-1.75: 246635 fakes, 4701 true Displaced\n",
1200 "Eta bin 1.75-2.00: 250585 fakes, 3009 true Displaced\n",
1201 "Eta bin 2.00-2.25: 86204 fakes, 1579 true Displaced\n",
1202 "Eta bin 2.25-2.50: 22080 fakes, 881 true Displaced\n"
1203 ]
1204 },
1205 {
1206 "data": {
1207 "text/plain": [
1208 "<Figure size 2000x800 with 3 Axes>"
1209 ]
1210 },
1211 "metadata": {},
1212 "output_type": "display_data"
1213 },
1214 {
1215 "name": "stdout",
1216 "output_type": "stream",
1217 "text": [
1218 "\n",
1219 "Displaced tracks, pt: 5.0 to inf GeV\n",
1220 "Number of true displaced tracks: 26375\n",
1221 "Number of fake tracks in pt bin: 8609939\n",
1222 "\n",
1223 "80% Retention Cut Values: {0.5019, 0.4954, 0.5197, 0.5256, 0.4161, 0.3509, 0.3778, 0.3667, 0.4388, 0.4915} Mean: 0.4484\n",
1224 "80% Cut Fake Rejections: {98.6, 98.6, 98.7, 98.7, 97.0, 93.3, 87.6, 86.9, 89.0, 88.0} Mean: 93.6%\n",
1225 "\n",
1226 "90% Retention Cut Values: {0.3265, 0.3967, 0.4493, 0.4656, 0.3459, 0.2910, 0.3486, 0.3356, 0.3693, 0.4448} Mean: 0.3773\n",
1227 "90% Cut Fake Rejections: {97.6, 98.1, 98.2, 98.3, 95.8, 90.8, 85.4, 84.0, 83.3, 82.4} Mean: 91.4%\n",
1228 "\n",
1229 "93% Retention Cut Values: {0.1804, 0.2502, 0.4033, 0.4375, 0.2958, 0.2556, 0.3353, 0.3242, 0.3405, 0.4194} Mean: 0.3242\n",
1230 "93% Cut Fake Rejections: {96.1, 97.0, 97.9, 98.2, 94.6, 89.1, 84.5, 83.1, 80.3, 78.8} Mean: 90.0%\n",
1231 "\n",
1232 "96% Retention Cut Values: {0.0464, 0.0366, 0.2521, 0.3317, 0.2324, 0.2021, 0.3165, 0.2992, 0.3096, 0.3568} Mean: 0.2383\n",
1233 "96% Cut Fake Rejections: {90.2, 88.7, 96.8, 97.5, 92.9, 86.3, 83.3, 81.2, 77.1, 69.3} Mean: 86.3%\n",
1234 "\n",
1235 "97% Retention Cut Values: {0.0270, 0.0225, 0.1590, 0.1765, 0.2111, 0.1570, 0.3039, 0.2869, 0.2999, 0.2103} Mean: 0.1854\n",
1236 "97% Cut Fake Rejections: {85.2, 83.8, 95.5, 96.0, 92.1, 83.3, 82.6, 80.4, 76.2, 43.0} Mean: 81.8%\n",
1237 "\n",
1238 "98% Retention Cut Values: {0.0168, 0.0160, 0.0687, 0.0392, 0.1494, 0.1121, 0.2684, 0.2642, 0.2935, 0.1370} Mean: 0.1365\n",
1239 "98% Cut Fake Rejections: {79.1, 79.4, 92.1, 88.6, 89.4, 79.3, 80.6, 78.7, 75.5, 26.5} Mean: 76.9%\n",
1240 "\n",
1241 "99% Retention Cut Values: {0.0091, 0.0075, 0.0350, 0.0213, 0.0435, 0.0676, 0.1957, 0.1649, 0.1080, 0.1046} Mean: 0.0757\n",
1242 "99% Cut Fake Rejections: {69.6, 67.9, 87.1, 82.3, 77.4, 73.2, 76.2, 70.3, 50.1, 19.1} Mean: 67.3%\n",
1243 "\n",
1244 "99.5% Retention Cut Values: {0.0059, 0.0021, 0.0151, 0.0116, 0.0222, 0.0309, 0.1230, 0.1205, 0.0779, 0.0930} Mean: 0.0502\n",
1245 "99.5% Cut Fake Rejections: {62.0, 47.0, 76.9, 74.2, 69.0, 62.6, 69.7, 64.8, 43.9, 16.5} Mean: 58.7%\n"
1246 ]
1247 }
1248 ],
1249 "source": [
1250 "import numpy as np\n",
1251 "from matplotlib import pyplot as plt\n",
1252 "from matplotlib.colors import LogNorm\n",
1253 "import torch\n",
1254 "\n",
"# Model scores with autograd disabled, inputs placed on `device`\n",
"input_features_tensor = input_features_tensor.to(device)\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    outputs = model(input_features_tensor)\n",
"    predictions = outputs.cpu().numpy()  # [n_samples, 3]\n",
"\n",
"# pt of each T3 from its circle radius (constant 2.99792458e-3 * 3.8)\n",
"t3_pt = np.concatenate(branches['t3_radius']) * (2.99792458e-3 * 3.8)\n",
1266 "\n",
"def plot_for_pt_bin(pt_min, pt_max, percentiles, eta_bin_edges, t3_pt, predictions, t3_sim_vxy, eta_list,\n",
"                    t3_pMatched=None, match_threshold=0.95, fake_threshold=0.75, displaced_vxy_cut=0.1):\n",
"    \"\"\"\n",
"    Calculate and plot cut values for specified percentiles in a given pt bin,\n",
"    separately for prompt and displaced tracks.\n",
"\n",
"    T3s with pt_min < t3_pt <= pt_max are selected. A T3 is matched if\n",
"    t3_pMatched > match_threshold and fake if t3_pMatched < fake_threshold;\n",
"    matched T3s with t3_sim_vxy > displaced_vxy_cut are displaced, all other\n",
"    matched T3s (including NaN vxy) are prompt. For every |eta| bin and every\n",
"    retention percentile, the score cut keeping that percentage of true T3s is\n",
"    computed, along with the percentage of fakes scoring below it (fake rejection).\n",
"    One figure is drawn and one summary is printed per track type.\n",
"\n",
"    Args:\n",
"        pt_min, pt_max: pt bin limits (lower exclusive, upper inclusive).\n",
"        percentiles: retention targets in percent, e.g. [90, 99].\n",
"        eta_bin_edges: increasing |eta| bin edges. Bins are half-open except the\n",
"            last one, which includes its upper edge (same as the 2-D histogram).\n",
"        t3_pt: per-T3 pt, aligned with all other per-T3 arrays.\n",
"        predictions: 2-D per-T3 scores; column 1 is used for prompt and\n",
"            column 2 for displaced tracks.\n",
"        t3_sim_vxy: per-T3 sim vertex vxy values.\n",
"        eta_list: sequence whose first entry is the per-T3 eta array.\n",
"        t3_pMatched: per-T3 matching fraction. Defaults to concatenating the\n",
"            notebook-level branches['t3_pMatched'] (the original behavior).\n",
"        match_threshold, fake_threshold, displaced_vxy_cut: selection thresholds;\n",
"            the defaults are the values previously hard-coded here.\n",
"    \"\"\"\n",
"    if t3_pMatched is None:\n",
"        t3_pMatched = np.concatenate(branches['t3_pMatched'])\n",
"\n",
"    # Restrict every per-T3 array to the requested pt bin\n",
"    pt_mask = (t3_pt > pt_min) & (t3_pt <= pt_max)\n",
"    abs_eta = np.abs(eta_list[0][pt_mask])\n",
"    pred_filtered = predictions[pt_mask]\n",
"    pmatched_in_bin = t3_pMatched[pt_mask]\n",
"\n",
"    # Track categories (NaN vxy compares False, so such matched T3s count as prompt)\n",
"    matched = pmatched_in_bin > match_threshold\n",
"    fake_tracks = pmatched_in_bin < fake_threshold\n",
"    is_displaced = t3_sim_vxy[pt_mask] > displaced_vxy_cut\n",
"    true_displaced = is_displaced & matched\n",
"    true_prompt = ~is_displaced & matched\n",
"\n",
"    eta_bin_edges_arr = np.asarray(eta_bin_edges)\n",
"    n_eta_bins = len(eta_bin_edges_arr) - 1\n",
"    bin_centers = (eta_bin_edges_arr[:-1] + eta_bin_edges_arr[1:]) / 2\n",
"\n",
"    # Separate plots for prompt and displaced tracks\n",
"    for track_type, true_mask, pred_idx, title_suffix in [\n",
"        (\"Prompt\", true_prompt, 1, \"Prompt Real Tracks\"),\n",
"        (\"Displaced\", true_displaced, 2, \"Displaced Real Tracks\")\n",
"    ]:\n",
"        cut_values = {p: [] for p in percentiles}\n",
"        fake_rejections = {p: [] for p in percentiles}\n",
"\n",
"        # Score of the class associated with this track type\n",
"        probs = pred_filtered[:, pred_idx]\n",
"\n",
"        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))\n",
"\n",
"        # Left panel: score vs |eta| for the true tracks of this type\n",
"        h = ax1.hist2d(abs_eta[true_mask],\n",
"                       probs[true_mask],\n",
"                       bins=[eta_bin_edges, 50],\n",
"                       norm=LogNorm())\n",
"        plt.colorbar(h[3], ax=ax1, label='Counts')\n",
"\n",
"        for i in range(n_eta_bins):\n",
"            eta_min, eta_max = eta_bin_edges[i], eta_bin_edges[i+1]\n",
"\n",
"            # Half-open bins, but the last bin also keeps |eta| == eta_max so the\n",
"            # tables cover the same tracks as the histogram\n",
"            if i < n_eta_bins - 1:\n",
"                below_max = abs_eta < eta_max\n",
"            else:\n",
"                below_max = abs_eta <= eta_max\n",
"            eta_mask = (abs_eta >= eta_min) & below_max\n",
"\n",
"            true_type_mask = eta_mask & true_mask\n",
"            fake_mask = eta_mask & fake_tracks\n",
"            n_true = np.sum(true_type_mask)\n",
"            n_fake = np.sum(fake_mask)\n",
"\n",
"            print(f\"Eta bin {eta_min:.2f}-{eta_max:.2f}: {n_fake} fakes, {n_true} true {track_type}\")\n",
"\n",
"            if n_true == 0:\n",
"                for percentile in percentiles:\n",
"                    cut_values[percentile].append(np.nan)\n",
"                    fake_rejections[percentile].append(np.nan)\n",
"                continue\n",
"\n",
"            true_probs = probs[true_type_mask]\n",
"            fake_probs = probs[fake_mask]\n",
"            for percentile in percentiles:\n",
"                # Score threshold that keeps `percentile`% of the true tracks in this bin\n",
"                cut_value = np.percentile(true_probs, 100 - percentile)\n",
"                cut_values[percentile].append(cut_value)\n",
"                # Fake rejection: percentage of fakes scoring below the threshold\n",
"                if n_fake > 0:\n",
"                    fake_rejections[percentile].append(100 * np.mean(fake_probs < cut_value))\n",
"                else:\n",
"                    fake_rejections[percentile].append(np.nan)\n",
"\n",
"        # Overlay cut values (left) and fake rejections (right), one color per percentile\n",
"        colors = plt.cm.rainbow(np.linspace(0, 1, len(percentiles)))\n",
"        for percentile, color in zip(percentiles, colors):\n",
"            values = np.array(cut_values[percentile])\n",
"            mask = ~np.isnan(values)\n",
"            if np.any(mask):\n",
"                ax1.plot(bin_centers[mask], values[mask], '-', color=color, marker='o',\n",
"                         label=f'{percentile}% Retention Cut')\n",
"                rej_values = np.array(fake_rejections[percentile])\n",
"                ax2.plot(bin_centers[mask], rej_values[mask], '-', color=color, marker='o',\n",
"                         label=f'{percentile}% Cut')\n",
"\n",
"        ax1.set_xlabel(\"Absolute Eta\")\n",
"        ax1.set_ylabel(f\"DNN {track_type} Probability\")\n",
"        ax1.set_title(f\"DNN Score vs Eta ({title_suffix})\\npt: {pt_min:.1f} to {pt_max:.1f} GeV\")\n",
"        ax1.legend()\n",
"        ax1.grid(True, alpha=0.3)\n",
"\n",
"        ax2.set_xlabel(\"Absolute Eta\")\n",
"        ax2.set_ylabel(\"Fake Rejection (%)\")\n",
"        ax2.set_title(f\"Fake Rejection vs Eta\\npt: {pt_min:.1f} to {pt_max:.1f} GeV\")\n",
"        ax2.legend()\n",
"        ax2.grid(True, alpha=0.3)\n",
"        ax2.set_ylim(0, 100)\n",
"\n",
"        plt.tight_layout()\n",
"        plt.show()\n",
"        # Release the figure: this function draws one per track type and pt bin\n",
"        plt.close(fig)\n",
"\n",
"        # Summary tables\n",
"        print(f\"\\n{track_type} tracks, pt: {pt_min:.1f} to {pt_max:.1f} GeV\")\n",
"        print(f\"Number of true {track_type.lower()} tracks: {np.sum(true_mask)}\")\n",
"        print(f\"Number of fake tracks in pt bin: {np.sum(fake_tracks)}\")\n",
"\n",
"        for percentile in percentiles:\n",
"            print(f\"\\n{percentile}% Retention Cut Values:\",\n",
"                  '{' + ', '.join(f\"{x:.4f}\" if not np.isnan(x) else 'nan' for x in cut_values[percentile]) + '}',\n",
"                  f\"Mean: {np.round(np.nanmean(cut_values[percentile]), 4)}\")\n",
"            print(f\"{percentile}% Cut Fake Rejections:\",\n",
"                  '{' + ', '.join(f\"{x:.1f}\" if not np.isnan(x) else 'nan' for x in fake_rejections[percentile]) + '}',\n",
"                  f\"Mean: {np.round(np.nanmean(fake_rejections[percentile]), 1)}%\")\n",
1387 "\n",
"def analyze_pt_bins(pt_bins, percentiles, eta_bin_edges, t3_pt, predictions, t3_sim_vxy, eta_list):\n",
"    \"\"\"\n",
"    Analyze and plot for multiple pt bins and percentiles.\n",
"\n",
"    Calls plot_for_pt_bin once per pair of consecutive edges in pt_bins, e.g.\n",
"    pt_bins=[0, 5, np.inf] gives the bins (0, 5] and (5, inf].\n",
"    \"\"\"\n",
"    for pt_low, pt_high in zip(pt_bins[:-1], pt_bins[1:]):\n",
"        plot_for_pt_bin(pt_low, pt_high, percentiles, eta_bin_edges,\n",
"                        t3_pt, predictions, t3_sim_vxy, eta_list)\n",
1395 "\n",
"# Analysis configuration: retention targets (%), pt bin edges (GeV), |eta| bin edges\n",
"percentiles = [80, 90, 93, 96, 97, 98, 99, 99.5]\n",
"pt_bins = [0, 5, np.inf]\n",
"eta_bin_edges = np.arange(0, 2.75, 0.25)  # 0.00, 0.25, ..., 2.50\n",
"\n",
"analyze_pt_bins(pt_bins, percentiles, eta_bin_edges,\n",
"                t3_pt, predictions,\n",
"                np.concatenate(branches['t3_sim_vxy']),\n",
"                eta_list)"
1410 ]
1411 },
1412 {
1413 "cell_type": "code",
1414 "execution_count": null,
1415 "metadata": {},
1416 "outputs": [],
1417 "source": []
1418 }
1419 ],
1420 "metadata": {
1421 "kernelspec": {
1422 "display_name": "analysisenv",
1423 "language": "python",
1424 "name": "python3"
1425 },
1426 "language_info": {
1427 "codemirror_mode": {
1428 "name": "ipython",
1429 "version": 3
1430 },
1431 "file_extension": ".py",
1432 "mimetype": "text/x-python",
1433 "name": "python",
1434 "nbconvert_exporter": "python",
1435 "pygments_lexer": "ipython3",
1436 "version": "3.11.7"
1437 }
1438 },
1439 "nbformat": 4,
1440 "nbformat_minor": 2
1441 }