Back to home page

Project CMSSW displayed by LXR

 
 

    


Warning, /RecoTracker/LSTCore/standalone/analysis/DNN/train_T3_DNN.ipynb is written in an unsupported language. File is not indexed.

0001 {
0002   "cells": [
0003     {
0004       "cell_type": "code",
0005       "execution_count": 1,
0006       "metadata": {},
0007       "outputs": [],
0008       "source": [
0009         "import os\n",
0010         "import uproot\n",
0011         "import numpy as np\n",
0012         "\n",
0013         "def load_root_file(file_path, branches=None, print_branches=False, tree_name=\"tree\"):\n",
0014         "    \"\"\"Read branches of one tree in a ROOT file into a dict.\n",
0015         "\n",
0016         "    Args:\n",
0017         "        file_path: Path of the ROOT file, passed to ``uproot.open``.\n",
0018         "        branches: Iterable of branch names to read; ``None`` reads every\n",
0019         "            key reported by the tree.\n",
0020         "        print_branches: When true, print the tree's keys before reading.\n",
0021         "        tree_name: Key of the tree inside the file (defaults to 'tree').\n",
0022         "\n",
0023         "    Returns:\n",
0024         "        dict mapping each successfully read branch name to the result of\n",
0025         "        ``array(library='np')``, plus the key ``'event'`` holding\n",
0026         "        ``tree.num_entries``. The ``'event'`` entry is written last, so it\n",
0027         "        replaces a branch of the same name if one was requested.\n",
0028         "    \"\"\"\n",
0029         "    all_branches = {}\n",
0030         "    with uproot.open(file_path) as file:\n",
0031         "        tree = file[tree_name]\n",
0032         "        # Option to print the branch names\n",
0033         "        if print_branches:\n",
0034         "            print(\"Branches:\", tree.keys())\n",
0035         "        # Default to every key of the tree when no branch list is given\n",
0036         "        requested = tree.keys() if branches is None else branches\n",
0037         "        for branch in requested:\n",
0038         "            try:\n",
0039         "                all_branches[branch] = tree[branch].array(library=\"np\")\n",
0040         "            except uproot.KeyInFileError as e:\n",
0041         "                # Best effort: report the missing branch and keep loading the rest\n",
0042         "                print(f\"KeyInFileError: {e}\")\n",
0043         "        # Number of entries in the tree, stored under the reserved key 'event'\n",
0044         "        all_branches['event'] = tree.num_entries\n",
0045         "    return all_branches\n",
0032         "\n",
0033         "# Per-T3 branches to read (original note: core T3 properties from TripletsSoA)\n",
0034         "branches_list = [\n",
0035         "    't3_betaIn',\n",
0036         "    't3_centerX',\n",
0037         "    't3_centerY',\n",
0038         "    't3_radius',\n",
0039         "    't3_partOfPT5',\n",
0040         "    't3_partOfT5',\n",
0041         "    't3_partOfPT3',\n",
0042         "    't3_layer_binary',\n",
0043         "    't3_pMatched',\n",
0044         "    't3_matched_simIdx',\n",
0045         "    't3_sim_vxy',\n",
0046         "    't3_sim_vz'\n",
0047         "]\n",
0048         "\n",
0049         "# Hit-dependent branches: one name per (hit index 0-5, quantity suffix) pair\n",
0050         "suffixes = ['r', 'z', 'eta', 'phi', 'layer']\n",
0051         "branches_list.extend(\n",
0052         "    f't3_hit_{hit}_{suffix}' for hit in range(6) for suffix in suffixes\n",
0053         ")\n",
0054         "\n",
0055         "# Read only the listed branches from the input file\n",
0056         "file_path = \"600_t3_dnn_relval_fix.root\"\n",
0057         "branches = load_root_file(file_path, branches_list)"
0055       ]
0056     },
0057     {
0058       "cell_type": "code",
0059       "execution_count": 2,
0060       "metadata": {},
0061       "outputs": [
0062         {
0063           "name": "stdout",
0064           "output_type": "stream",
0065           "text": [
0066             "Z max: 224.14950561523438, R max: 98.93299102783203, Eta max: 2.5\n"
0067           ]
0068         }
0069       ],
0070       "source": [
0071         "# Normalisation constants for the hit coordinates.\n",
0072         "# All events are flattened first so that an event with no T3s (an empty\n",
0073         "# array) cannot make np.max raise on a zero-size input.\n",
0074         "# NOTE(review): the maxima are taken from hit 3 only, while the features\n",
0075         "# built later use hits 0, 2 and 4 -- confirm this is intended.\n",
0076         "z_max = np.max(np.concatenate(branches['t3_hit_3_z']))\n",
0077         "r_max = np.max(np.concatenate(branches['t3_hit_3_r']))\n",
0078         "eta_max = 2.5\n",
0079         "phi_max = np.pi\n",
0080         "\n",
0081         "print(f'Z max: {z_max}, R max: {r_max}, Eta max: {eta_max}')"
0077       ]
0078     },
0079     {
0080       "cell_type": "code",
0081       "execution_count": 3,
0082       "metadata": {},
0083       "outputs": [],
0084       "source": [
0085         "def delta_phi(phi1, phi2):\n",
0086         "    \"\"\"Return ``phi1 - phi2`` after a single 2*pi wrap-around correction.\n",
0087         "\n",
0088         "    Differences above pi are lowered by 2*pi, then differences below -pi are\n",
0089         "    raised by 2*pi. Each correction is applied at most once per element.\n",
0090         "    \"\"\"\n",
0091         "    two_pi = 2*np.pi\n",
0092         "    diff = phi1 - phi2\n",
0093         "    wrapped_high = np.where(diff > np.pi, diff - two_pi, diff)\n",
0094         "    return np.where(wrapped_high < -np.pi, wrapped_high + two_pi, wrapped_high)\n",
0090         "\n",
0091         "n_events = branches['event']\n",
0092         "\n",
0093         "def _flatten(branch_name):\n",
0094         "    \"\"\"Concatenate the per-event arrays of one branch, in event order.\"\"\"\n",
0095         "    per_event = branches[branch_name]\n",
0096         "    return np.concatenate([per_event[evt] for evt in range(n_events)])\n",
0097         "\n",
0098         "# Flattened hit quantities for hits 0, 2 and 4; eta and z are kept as absolute values\n",
0099         "all_eta0, all_eta2, all_eta4 = (np.abs(_flatten(f't3_hit_{hit}_eta')) for hit in (0, 2, 4))\n",
0100         "all_phi0, all_phi2, all_phi4 = (_flatten(f't3_hit_{hit}_phi') for hit in (0, 2, 4))\n",
0101         "all_z0, all_z2, all_z4 = (np.abs(_flatten(f't3_hit_{hit}_z')) for hit in (0, 2, 4))\n",
0102         "all_r0, all_r2, all_r4 = (_flatten(f't3_hit_{hit}_r') for hit in (0, 2, 4))\n",
0103         "\n",
0104         "all_radius = _flatten('t3_radius')\n",
0105         "all_betaIn = _flatten('t3_betaIn')\n",
0106         "\n",
0107         "# One row per input variable. log10 of a non-positive radius is NaN/-Inf;\n",
0108         "# rows with non-finite values are dropped in the next cell.\n",
0109         "features = np.array([\n",
0110         "    all_eta0 / eta_max,                      # |eta| of hit 0, scaled by eta_max\n",
0111         "    np.abs(all_phi0) / phi_max,              # |phi| of hit 0, scaled by phi_max\n",
0112         "    all_z0 / z_max,                          # |z| of hit 0, scaled by z_max\n",
0113         "    all_r0 / r_max,                          # r of hit 0, scaled by r_max\n",
0114         "    (all_eta2 - all_eta0),                   # |eta| change hit 0 -> hit 2 (unscaled)\n",
0115         "    delta_phi(all_phi2, all_phi0) / phi_max, # wrapped phi change hit 0 -> hit 2\n",
0116         "    (all_z2 - all_z0) / z_max,               # |z| change hit 0 -> hit 2\n",
0117         "    (all_r2 - all_r0) / r_max,               # r change hit 0 -> hit 2\n",
0118         "    (all_eta4 - all_eta2),                   # |eta| change hit 2 -> hit 4 (unscaled)\n",
0119         "    delta_phi(all_phi4, all_phi2) / phi_max, # wrapped phi change hit 2 -> hit 4\n",
0120         "    (all_z4 - all_z2) / z_max,               # |z| change hit 2 -> hit 4\n",
0121         "    (all_r4 - all_r2) / r_max,               # r change hit 2 -> hit 4\n",
0122         "    np.log10(all_radius),                    # log10 of t3_radius\n",
0123         "    all_betaIn                               # t3_betaIn, used as stored\n",
0124         "])\n",
0125         "\n",
0126         "# |eta| of hit 0 wrapped in a one-row array\n",
0127         "eta_list = np.array([all_eta0])"
0130       ]
0131     },
0132     {
0133       "cell_type": "code",
0134       "execution_count": 4,
0135       "metadata": {},
0136       "outputs": [],
0137       "source": [
0138         "import torch\n",
0139         "from torch import nn\n",
0140         "from torch.optim import Adam\n",
0141         "from torch.utils.data import DataLoader, TensorDataset, random_split\n",
0142         "import numpy as np\n",
0143         "\n",
0144         "# `features` holds one row per input variable; stacking its rows along the\n",
0145         "# last axis gives one row per T3 and one column per variable.\n",
0146         "input_features_numpy = np.stack(features, axis=-1)\n",
0147         "# True where a value is neither NaN nor +/-Inf\n",
0148         "mask = np.isfinite(input_features_numpy)\n",
0149         "# Row selector computed once and reused, so that features, labels and\n",
0150         "# sim vxy are all filtered with the same rows and stay aligned.\n",
0151         "finite_rows = np.all(mask, axis=1)\n",
0152         "filtered_input_features_numpy = input_features_numpy[finite_rows]\n",
0153         "# A T3 is labelled fake when its t3_pMatched value is below 0.75\n",
0154         "t3_isFake_filtered = (np.concatenate(branches['t3_pMatched']) < 0.75)[finite_rows]\n",
0155         "t3_sim_vxy_filtered = np.concatenate(branches['t3_sim_vxy'])[finite_rows]\n",
0156         "\n",
0157         "# Convert to PyTorch tensor\n",
0158         "input_features_tensor = torch.tensor(filtered_input_features_numpy, dtype=torch.float32)"
0153       ]
0154     },
0155     {
0156       "cell_type": "code",
0157       "execution_count": 5,
0158       "metadata": {},
0159       "outputs": [
0160         {
0161           "name": "stdout",
0162           "output_type": "stream",
0163           "text": [
0164             "Using device: cuda\n",
0165             "Initial dataset size: 55072926\n",
0166             "Class distribution before downsampling - Fake: 49829032.0, Prompt: 4472777.0, Displaced: 771119.0\n",
0167             "Class distribution after downsampling - Fake: 9965806.0, Prompt: 4472777.0, Displaced: 771119.0\n",
0168             "Epoch [1/400], Train Loss: 0.6515, Test Loss: 0.5908\n",
0169             "Epoch [2/400], Train Loss: 0.5771, Test Loss: 0.5659\n",
0170             "Epoch [3/400], Train Loss: 0.5647, Test Loss: 0.5549\n",
0171             "Epoch [4/400], Train Loss: 0.5588, Test Loss: 0.5627\n",
0172             "Epoch [5/400], Train Loss: 0.5553, Test Loss: 0.5541\n",
0173             "Epoch [6/400], Train Loss: 0.5528, Test Loss: 0.5664\n",
0174             "Epoch [7/400], Train Loss: 0.5508, Test Loss: 0.5574\n",
0175             "Epoch [8/400], Train Loss: 0.5492, Test Loss: 0.5503\n",
0176             "Epoch [9/400], Train Loss: 0.5478, Test Loss: 0.5447\n",
0177             "Epoch [10/400], Train Loss: 0.5472, Test Loss: 0.5538\n",
0178             "Epoch [11/400], Train Loss: 0.5459, Test Loss: 0.5534\n",
0179             "Epoch [12/400], Train Loss: 0.5454, Test Loss: 0.5487\n",
0180             "Epoch [13/400], Train Loss: 0.5445, Test Loss: 0.5366\n",
0181             "Epoch [14/400], Train Loss: 0.5441, Test Loss: 0.5387\n",
0182             "Epoch [15/400], Train Loss: 0.5434, Test Loss: 0.5421\n",
0183             "Epoch [16/400], Train Loss: 0.5426, Test Loss: 0.5397\n",
0184             "Epoch [17/400], Train Loss: 0.5420, Test Loss: 0.5486\n",
0185             "Epoch [18/400], Train Loss: 0.5412, Test Loss: 0.5398\n",
0186             "Epoch [19/400], Train Loss: 0.5409, Test Loss: 0.5421\n",
0187             "Epoch [20/400], Train Loss: 0.5405, Test Loss: 0.5499\n",
0188             "Epoch [21/400], Train Loss: 0.5399, Test Loss: 0.5573\n",
0189             "Epoch [22/400], Train Loss: 0.5396, Test Loss: 0.5388\n",
0190             "Epoch [23/400], Train Loss: 0.5393, Test Loss: 0.5399\n",
0191             "Epoch [24/400], Train Loss: 0.5388, Test Loss: 0.5391\n",
0192             "Epoch [25/400], Train Loss: 0.5383, Test Loss: 0.5375\n",
0193             "Epoch [26/400], Train Loss: 0.5381, Test Loss: 0.5386\n",
0194             "Epoch [27/400], Train Loss: 0.5380, Test Loss: 0.5454\n",
0195             "Epoch [28/400], Train Loss: 0.5376, Test Loss: 0.5350\n",
0196             "Epoch [29/400], Train Loss: 0.5376, Test Loss: 0.5572\n",
0197             "Epoch [30/400], Train Loss: 0.5371, Test Loss: 0.5486\n",
0198             "Epoch [31/400], Train Loss: 0.5364, Test Loss: 0.5331\n",
0199             "Epoch [32/400], Train Loss: 0.5368, Test Loss: 0.5498\n",
0200             "Epoch [33/400], Train Loss: 0.5367, Test Loss: 0.5363\n",
0201             "Epoch [34/400], Train Loss: 0.5363, Test Loss: 0.5377\n",
0202             "Epoch [35/400], Train Loss: 0.5360, Test Loss: 0.5413\n",
0203             "Epoch [36/400], Train Loss: 0.5361, Test Loss: 0.5385\n",
0204             "Epoch [37/400], Train Loss: 0.5357, Test Loss: 0.5433\n",
0205             "Epoch [38/400], Train Loss: 0.5353, Test Loss: 0.5404\n",
0206             "Epoch [39/400], Train Loss: 0.5353, Test Loss: 0.5328\n",
0207             "Epoch [40/400], Train Loss: 0.5349, Test Loss: 0.5363\n",
0208             "Epoch [41/400], Train Loss: 0.5348, Test Loss: 0.5371\n",
0209             "Epoch [42/400], Train Loss: 0.5347, Test Loss: 0.5349\n",
0210             "Epoch [43/400], Train Loss: 0.5344, Test Loss: 0.5356\n",
0211             "Epoch [44/400], Train Loss: 0.5342, Test Loss: 0.5355\n",
0212             "Epoch [45/400], Train Loss: 0.5338, Test Loss: 0.5346\n",
0213             "Epoch [46/400], Train Loss: 0.5337, Test Loss: 0.5345\n",
0214             "Epoch [47/400], Train Loss: 0.5336, Test Loss: 0.5323\n",
0215             "Epoch [48/400], Train Loss: 0.5333, Test Loss: 0.5298\n",
0216             "Epoch [49/400], Train Loss: 0.5332, Test Loss: 0.5388\n",
0217             "Epoch [50/400], Train Loss: 0.5331, Test Loss: 0.5312\n",
0218             "Epoch [51/400], Train Loss: 0.5329, Test Loss: 0.5305\n",
0219             "Epoch [52/400], Train Loss: 0.5328, Test Loss: 0.5325\n",
0220             "Epoch [53/400], Train Loss: 0.5325, Test Loss: 0.5333\n",
0221             "Epoch [54/400], Train Loss: 0.5325, Test Loss: 0.5285\n",
0222             "Epoch [55/400], Train Loss: 0.5325, Test Loss: 0.5400\n",
0223             "Epoch [56/400], Train Loss: 0.5323, Test Loss: 0.5324\n",
0224             "Epoch [57/400], Train Loss: 0.5320, Test Loss: 0.5298\n",
0225             "Epoch [58/400], Train Loss: 0.5319, Test Loss: 0.5408\n",
0226             "Epoch [59/400], Train Loss: 0.5319, Test Loss: 0.5294\n",
0227             "Epoch [60/400], Train Loss: 0.5315, Test Loss: 0.5293\n",
0228             "Epoch [61/400], Train Loss: 0.5316, Test Loss: 0.5381\n",
0229             "Epoch [62/400], Train Loss: 0.5315, Test Loss: 0.5302\n",
0230             "Epoch [63/400], Train Loss: 0.5316, Test Loss: 0.5329\n",
0231             "Epoch [64/400], Train Loss: 0.5313, Test Loss: 0.5341\n",
0232             "Epoch [65/400], Train Loss: 0.5311, Test Loss: 0.5333\n",
0233             "Epoch [66/400], Train Loss: 0.5311, Test Loss: 0.5379\n",
0234             "Epoch [67/400], Train Loss: 0.5310, Test Loss: 0.5314\n",
0235             "Epoch [68/400], Train Loss: 0.5308, Test Loss: 0.5377\n",
0236             "Epoch [69/400], Train Loss: 0.5310, Test Loss: 0.5325\n",
0237             "Epoch [70/400], Train Loss: 0.5307, Test Loss: 0.5307\n",
0238             "Epoch [71/400], Train Loss: 0.5305, Test Loss: 0.5321\n",
0239             "Epoch [72/400], Train Loss: 0.5304, Test Loss: 0.5328\n",
0240             "Epoch [73/400], Train Loss: 0.5304, Test Loss: 0.5355\n",
0241             "Epoch [74/400], Train Loss: 0.5301, Test Loss: 0.5298\n",
0242             "Epoch [75/400], Train Loss: 0.5300, Test Loss: 0.5363\n",
0243             "Epoch [76/400], Train Loss: 0.5303, Test Loss: 0.5331\n",
0244             "Epoch [77/400], Train Loss: 0.5298, Test Loss: 0.5270\n",
0245             "Epoch [78/400], Train Loss: 0.5300, Test Loss: 0.5324\n",
0246             "Epoch [79/400], Train Loss: 0.5300, Test Loss: 0.5336\n",
0247             "Epoch [80/400], Train Loss: 0.5297, Test Loss: 0.5283\n",
0248             "Epoch [81/400], Train Loss: 0.5297, Test Loss: 0.5285\n",
0249             "Epoch [82/400], Train Loss: 0.5295, Test Loss: 0.5286\n",
0250             "Epoch [83/400], Train Loss: 0.5295, Test Loss: 0.5277\n",
0251             "Epoch [84/400], Train Loss: 0.5294, Test Loss: 0.5300\n",
0252             "Epoch [85/400], Train Loss: 0.5295, Test Loss: 0.5317\n",
0253             "Epoch [86/400], Train Loss: 0.5292, Test Loss: 0.5288\n",
0254             "Epoch [87/400], Train Loss: 0.5293, Test Loss: 0.5295\n",
0255             "Epoch [88/400], Train Loss: 0.5291, Test Loss: 0.5273\n",
0256             "Epoch [89/400], Train Loss: 0.5292, Test Loss: 0.5289\n",
0257             "Epoch [90/400], Train Loss: 0.5292, Test Loss: 0.5273\n",
0258             "Epoch [91/400], Train Loss: 0.5289, Test Loss: 0.5370\n",
0259             "Epoch [92/400], Train Loss: 0.5288, Test Loss: 0.5263\n",
0260             "Epoch [93/400], Train Loss: 0.5288, Test Loss: 0.5338\n",
0261             "Epoch [94/400], Train Loss: 0.5287, Test Loss: 0.5326\n",
0262             "Epoch [95/400], Train Loss: 0.5286, Test Loss: 0.5300\n",
0263             "Epoch [96/400], Train Loss: 0.5286, Test Loss: 0.5280\n",
0264             "Epoch [97/400], Train Loss: 0.5284, Test Loss: 0.5291\n",
0265             "Epoch [98/400], Train Loss: 0.5284, Test Loss: 0.5310\n",
0266             "Epoch [99/400], Train Loss: 0.5283, Test Loss: 0.5275\n",
0267             "Epoch [100/400], Train Loss: 0.5280, Test Loss: 0.5294\n",
0268             "Epoch [101/400], Train Loss: 0.5279, Test Loss: 0.5290\n",
0269             "Epoch [102/400], Train Loss: 0.5273, Test Loss: 0.5264\n",
0270             "Epoch [103/400], Train Loss: 0.5270, Test Loss: 0.5289\n",
0271             "Epoch [104/400], Train Loss: 0.5265, Test Loss: 0.5294\n",
0272             "Epoch [105/400], Train Loss: 0.5263, Test Loss: 0.5290\n",
0273             "Epoch [106/400], Train Loss: 0.5261, Test Loss: 0.5302\n",
0274             "Epoch [107/400], Train Loss: 0.5258, Test Loss: 0.5309\n",
0275             "Epoch [108/400], Train Loss: 0.5258, Test Loss: 0.5254\n",
0276             "Epoch [109/400], Train Loss: 0.5256, Test Loss: 0.5234\n",
0277             "Epoch [110/400], Train Loss: 0.5255, Test Loss: 0.5307\n",
0278             "Epoch [111/400], Train Loss: 0.5256, Test Loss: 0.5250\n",
0279             "Epoch [112/400], Train Loss: 0.5252, Test Loss: 0.5300\n",
0280             "Epoch [113/400], Train Loss: 0.5253, Test Loss: 0.5328\n",
0281             "Epoch [114/400], Train Loss: 0.5252, Test Loss: 0.5347\n",
0282             "Epoch [115/400], Train Loss: 0.5251, Test Loss: 0.5263\n",
0283             "Epoch [116/400], Train Loss: 0.5250, Test Loss: 0.5312\n",
0284             "Epoch [117/400], Train Loss: 0.5250, Test Loss: 0.5313\n",
0285             "Epoch [118/400], Train Loss: 0.5248, Test Loss: 0.5291\n",
0286             "Epoch [119/400], Train Loss: 0.5249, Test Loss: 0.5314\n",
0287             "Epoch [120/400], Train Loss: 0.5249, Test Loss: 0.5246\n",
0288             "Epoch [121/400], Train Loss: 0.5246, Test Loss: 0.5271\n",
0289             "Epoch [122/400], Train Loss: 0.5244, Test Loss: 0.5286\n",
0290             "Epoch [123/400], Train Loss: 0.5241, Test Loss: 0.5361\n",
0291             "Epoch [124/400], Train Loss: 0.5240, Test Loss: 0.5229\n",
0292             "Epoch [125/400], Train Loss: 0.5239, Test Loss: 0.5268\n",
0293             "Epoch [126/400], Train Loss: 0.5239, Test Loss: 0.5233\n",
0294             "Epoch [127/400], Train Loss: 0.5238, Test Loss: 0.5254\n",
0295             "Epoch [128/400], Train Loss: 0.5236, Test Loss: 0.5271\n",
0296             "Epoch [129/400], Train Loss: 0.5235, Test Loss: 0.5219\n",
0297             "Epoch [130/400], Train Loss: 0.5234, Test Loss: 0.5273\n",
0298             "Epoch [131/400], Train Loss: 0.5232, Test Loss: 0.5241\n",
0299             "Epoch [132/400], Train Loss: 0.5230, Test Loss: 0.5234\n",
0300             "Epoch [133/400], Train Loss: 0.5229, Test Loss: 0.5232\n",
0301             "Epoch [134/400], Train Loss: 0.5229, Test Loss: 0.5288\n",
0302             "Epoch [135/400], Train Loss: 0.5229, Test Loss: 0.5261\n",
0303             "Epoch [136/400], Train Loss: 0.5230, Test Loss: 0.5271\n",
0304             "Epoch [137/400], Train Loss: 0.5227, Test Loss: 0.5287\n",
0305             "Epoch [138/400], Train Loss: 0.5228, Test Loss: 0.5216\n",
0306             "Epoch [139/400], Train Loss: 0.5227, Test Loss: 0.5263\n",
0307             "Epoch [140/400], Train Loss: 0.5224, Test Loss: 0.5274\n",
0308             "Epoch [141/400], Train Loss: 0.5225, Test Loss: 0.5234\n",
0309             "Epoch [142/400], Train Loss: 0.5226, Test Loss: 0.5251\n",
0310             "Epoch [143/400], Train Loss: 0.5221, Test Loss: 0.5224\n",
0311             "Epoch [144/400], Train Loss: 0.5223, Test Loss: 0.5222\n",
0312             "Epoch [145/400], Train Loss: 0.5224, Test Loss: 0.5275\n",
0313             "Epoch [146/400], Train Loss: 0.5223, Test Loss: 0.5203\n",
0314             "Epoch [147/400], Train Loss: 0.5223, Test Loss: 0.5218\n",
0315             "Epoch [148/400], Train Loss: 0.5222, Test Loss: 0.5256\n",
0316             "Epoch [149/400], Train Loss: 0.5221, Test Loss: 0.5227\n",
0317             "Epoch [150/400], Train Loss: 0.5219, Test Loss: 0.5210\n",
0318             "Epoch [151/400], Train Loss: 0.5221, Test Loss: 0.5239\n",
0319             "Epoch [152/400], Train Loss: 0.5221, Test Loss: 0.5218\n",
0320             "Epoch [153/400], Train Loss: 0.5219, Test Loss: 0.5305\n",
0321             "Epoch [154/400], Train Loss: 0.5219, Test Loss: 0.5248\n",
0322             "Epoch [155/400], Train Loss: 0.5218, Test Loss: 0.5247\n",
0323             "Epoch [156/400], Train Loss: 0.5218, Test Loss: 0.5222\n",
0324             "Epoch [157/400], Train Loss: 0.5216, Test Loss: 0.5332\n",
0325             "Epoch [158/400], Train Loss: 0.5217, Test Loss: 0.5230\n",
0326             "Epoch [159/400], Train Loss: 0.5217, Test Loss: 0.5237\n",
0327             "Epoch [160/400], Train Loss: 0.5216, Test Loss: 0.5205\n",
0328             "Epoch [161/400], Train Loss: 0.5215, Test Loss: 0.5208\n",
0329             "Epoch [162/400], Train Loss: 0.5216, Test Loss: 0.5242\n",
0330             "Epoch [163/400], Train Loss: 0.5216, Test Loss: 0.5254\n",
0331             "Epoch [164/400], Train Loss: 0.5214, Test Loss: 0.5229\n",
0332             "Epoch [165/400], Train Loss: 0.5214, Test Loss: 0.5260\n",
0333             "Epoch [166/400], Train Loss: 0.5213, Test Loss: 0.5193\n",
0334             "Epoch [167/400], Train Loss: 0.5212, Test Loss: 0.5225\n",
0335             "Epoch [168/400], Train Loss: 0.5211, Test Loss: 0.5240\n",
0336             "Epoch [169/400], Train Loss: 0.5213, Test Loss: 0.5220\n",
0337             "Epoch [170/400], Train Loss: 0.5213, Test Loss: 0.5276\n",
0338             "Epoch [171/400], Train Loss: 0.5211, Test Loss: 0.5203\n",
0339             "Epoch [172/400], Train Loss: 0.5214, Test Loss: 0.5202\n",
0340             "Epoch [173/400], Train Loss: 0.5210, Test Loss: 0.5213\n",
0341             "Epoch [174/400], Train Loss: 0.5212, Test Loss: 0.5215\n",
0342             "Epoch [175/400], Train Loss: 0.5211, Test Loss: 0.5242\n",
0343             "Epoch [176/400], Train Loss: 0.5210, Test Loss: 0.5217\n",
0344             "Epoch [177/400], Train Loss: 0.5209, Test Loss: 0.5231\n",
0345             "Epoch [178/400], Train Loss: 0.5210, Test Loss: 0.5225\n",
0346             "Epoch [179/400], Train Loss: 0.5210, Test Loss: 0.5229\n",
0347             "Epoch [180/400], Train Loss: 0.5208, Test Loss: 0.5235\n",
0348             "Epoch [181/400], Train Loss: 0.5207, Test Loss: 0.5244\n",
0349             "Epoch [182/400], Train Loss: 0.5208, Test Loss: 0.5224\n",
0350             "Epoch [183/400], Train Loss: 0.5209, Test Loss: 0.5264\n",
0351             "Epoch [184/400], Train Loss: 0.5207, Test Loss: 0.5220\n",
0352             "Epoch [185/400], Train Loss: 0.5206, Test Loss: 0.5202\n",
0353             "Epoch [186/400], Train Loss: 0.5208, Test Loss: 0.5187\n",
0354             "Epoch [187/400], Train Loss: 0.5206, Test Loss: 0.5270\n",
0355             "Epoch [188/400], Train Loss: 0.5207, Test Loss: 0.5196\n",
0356             "Epoch [189/400], Train Loss: 0.5205, Test Loss: 0.5270\n",
0357             "Epoch [190/400], Train Loss: 0.5207, Test Loss: 0.5241\n",
0358             "Epoch [191/400], Train Loss: 0.5206, Test Loss: 0.5226\n",
0359             "Epoch [192/400], Train Loss: 0.5205, Test Loss: 0.5289\n",
0360             "Epoch [193/400], Train Loss: 0.5205, Test Loss: 0.5204\n",
0361             "Epoch [194/400], Train Loss: 0.5204, Test Loss: 0.5215\n",
0362             "Epoch [195/400], Train Loss: 0.5205, Test Loss: 0.5205\n",
0363             "Epoch [196/400], Train Loss: 0.5204, Test Loss: 0.5236\n",
0364             "Epoch [197/400], Train Loss: 0.5205, Test Loss: 0.5209\n",
0365             "Epoch [198/400], Train Loss: 0.5202, Test Loss: 0.5225\n",
0366             "Epoch [199/400], Train Loss: 0.5204, Test Loss: 0.5219\n",
0367             "Epoch [200/400], Train Loss: 0.5203, Test Loss: 0.5217\n",
0368             "Epoch [201/400], Train Loss: 0.5204, Test Loss: 0.5237\n",
0369             "Epoch [202/400], Train Loss: 0.5201, Test Loss: 0.5186\n",
0370             "Epoch [203/400], Train Loss: 0.5203, Test Loss: 0.5228\n",
0371             "Epoch [204/400], Train Loss: 0.5202, Test Loss: 0.5213\n",
0372             "Epoch [205/400], Train Loss: 0.5200, Test Loss: 0.5197\n",
0373             "Epoch [206/400], Train Loss: 0.5202, Test Loss: 0.5209\n",
0374             "Epoch [207/400], Train Loss: 0.5200, Test Loss: 0.5250\n",
0375             "Epoch [208/400], Train Loss: 0.5203, Test Loss: 0.5183\n",
0376             "Epoch [209/400], Train Loss: 0.5201, Test Loss: 0.5181\n",
0377             "Epoch [210/400], Train Loss: 0.5200, Test Loss: 0.5235\n",
0378             "Epoch [211/400], Train Loss: 0.5201, Test Loss: 0.5209\n",
0379             "Epoch [212/400], Train Loss: 0.5200, Test Loss: 0.5203\n",
0380             "Epoch [213/400], Train Loss: 0.5202, Test Loss: 0.5235\n",
0381             "Epoch [214/400], Train Loss: 0.5201, Test Loss: 0.5184\n",
0382             "Epoch [215/400], Train Loss: 0.5199, Test Loss: 0.5275\n",
0383             "Epoch [216/400], Train Loss: 0.5199, Test Loss: 0.5200\n",
0384             "Epoch [217/400], Train Loss: 0.5199, Test Loss: 0.5216\n",
0385             "Epoch [218/400], Train Loss: 0.5199, Test Loss: 0.5230\n",
0386             "Epoch [219/400], Train Loss: 0.5200, Test Loss: 0.5193\n",
0387             "Epoch [220/400], Train Loss: 0.5199, Test Loss: 0.5217\n",
0388             "Epoch [221/400], Train Loss: 0.5200, Test Loss: 0.5234\n",
0389             "Epoch [222/400], Train Loss: 0.5197, Test Loss: 0.5226\n",
0390             "Epoch [223/400], Train Loss: 0.5198, Test Loss: 0.5242\n",
0391             "Epoch [224/400], Train Loss: 0.5198, Test Loss: 0.5226\n",
0392             "Epoch [225/400], Train Loss: 0.5199, Test Loss: 0.5172\n",
0393             "Epoch [226/400], Train Loss: 0.5197, Test Loss: 0.5206\n",
0394             "Epoch [227/400], Train Loss: 0.5197, Test Loss: 0.5211\n",
0395             "Epoch [228/400], Train Loss: 0.5197, Test Loss: 0.5199\n",
0396             "Epoch [229/400], Train Loss: 0.5197, Test Loss: 0.5194\n",
0397             "Epoch [230/400], Train Loss: 0.5197, Test Loss: 0.5212\n",
0398             "Epoch [231/400], Train Loss: 0.5197, Test Loss: 0.5235\n",
0399             "Epoch [232/400], Train Loss: 0.5199, Test Loss: 0.5180\n",
0400             "Epoch [233/400], Train Loss: 0.5197, Test Loss: 0.5186\n",
0401             "Epoch [234/400], Train Loss: 0.5198, Test Loss: 0.5192\n",
0402             "Epoch [235/400], Train Loss: 0.5197, Test Loss: 0.5232\n",
0403             "Epoch [236/400], Train Loss: 0.5195, Test Loss: 0.5267\n",
0404             "Epoch [237/400], Train Loss: 0.5198, Test Loss: 0.5203\n",
0405             "Epoch [238/400], Train Loss: 0.5196, Test Loss: 0.5195\n",
0406             "Epoch [239/400], Train Loss: 0.5197, Test Loss: 0.5210\n",
0407             "Epoch [240/400], Train Loss: 0.5196, Test Loss: 0.5273\n",
0408             "Epoch [241/400], Train Loss: 0.5195, Test Loss: 0.5247\n",
0409             "Epoch [242/400], Train Loss: 0.5197, Test Loss: 0.5184\n",
0410             "Epoch [243/400], Train Loss: 0.5196, Test Loss: 0.5188\n",
0411             "Epoch [244/400], Train Loss: 0.5198, Test Loss: 0.5201\n",
0412             "Epoch [245/400], Train Loss: 0.5195, Test Loss: 0.5242\n",
0413             "Epoch [246/400], Train Loss: 0.5196, Test Loss: 0.5204\n",
0414             "Epoch [247/400], Train Loss: 0.5196, Test Loss: 0.5232\n",
0415             "Epoch [248/400], Train Loss: 0.5194, Test Loss: 0.5268\n",
0416             "Epoch [249/400], Train Loss: 0.5196, Test Loss: 0.5205\n",
0417             "Epoch [250/400], Train Loss: 0.5195, Test Loss: 0.5255\n",
0418             "Epoch [251/400], Train Loss: 0.5195, Test Loss: 0.5211\n",
0419             "Epoch [252/400], Train Loss: 0.5195, Test Loss: 0.5200\n",
0420             "Epoch [253/400], Train Loss: 0.5196, Test Loss: 0.5217\n",
0421             "Epoch [254/400], Train Loss: 0.5196, Test Loss: 0.5208\n",
0422             "Epoch [255/400], Train Loss: 0.5194, Test Loss: 0.5209\n",
0423             "Epoch [256/400], Train Loss: 0.5195, Test Loss: 0.5252\n",
0424             "Epoch [257/400], Train Loss: 0.5194, Test Loss: 0.5209\n",
0425             "Epoch [258/400], Train Loss: 0.5195, Test Loss: 0.5247\n",
0426             "Epoch [259/400], Train Loss: 0.5196, Test Loss: 0.5201\n",
0427             "Epoch [260/400], Train Loss: 0.5192, Test Loss: 0.5191\n",
0428             "Epoch [261/400], Train Loss: 0.5194, Test Loss: 0.5202\n",
0429             "Epoch [262/400], Train Loss: 0.5194, Test Loss: 0.5209\n",
0430             "Epoch [263/400], Train Loss: 0.5193, Test Loss: 0.5258\n",
0431             "Epoch [264/400], Train Loss: 0.5194, Test Loss: 0.5226\n",
0432             "Epoch [265/400], Train Loss: 0.5195, Test Loss: 0.5263\n",
0433             "Epoch [266/400], Train Loss: 0.5195, Test Loss: 0.5223\n",
0434             "Epoch [267/400], Train Loss: 0.5193, Test Loss: 0.5236\n",
0435             "Epoch [268/400], Train Loss: 0.5192, Test Loss: 0.5274\n",
0436             "Epoch [269/400], Train Loss: 0.5194, Test Loss: 0.5193\n",
0437             "Epoch [270/400], Train Loss: 0.5191, Test Loss: 0.5189\n",
0438             "Epoch [271/400], Train Loss: 0.5193, Test Loss: 0.5257\n",
0439             "Epoch [272/400], Train Loss: 0.5193, Test Loss: 0.5191\n",
0440             "Epoch [273/400], Train Loss: 0.5194, Test Loss: 0.5192\n",
0441             "Epoch [274/400], Train Loss: 0.5193, Test Loss: 0.5215\n",
0442             "Epoch [275/400], Train Loss: 0.5192, Test Loss: 0.5199\n",
0443             "Epoch [276/400], Train Loss: 0.5193, Test Loss: 0.5231\n",
0444             "Epoch [277/400], Train Loss: 0.5192, Test Loss: 0.5210\n",
0445             "Epoch [278/400], Train Loss: 0.5192, Test Loss: 0.5203\n",
0446             "Epoch [279/400], Train Loss: 0.5194, Test Loss: 0.5225\n",
0447             "Epoch [280/400], Train Loss: 0.5193, Test Loss: 0.5242\n",
0448             "Epoch [281/400], Train Loss: 0.5192, Test Loss: 0.5270\n",
0449             "Epoch [282/400], Train Loss: 0.5192, Test Loss: 0.5226\n",
0450             "Epoch [283/400], Train Loss: 0.5192, Test Loss: 0.5221\n",
0451             "Epoch [284/400], Train Loss: 0.5193, Test Loss: 0.5171\n",
0452             "Epoch [285/400], Train Loss: 0.5191, Test Loss: 0.5211\n",
0453             "Epoch [286/400], Train Loss: 0.5191, Test Loss: 0.5178\n",
0454             "Epoch [287/400], Train Loss: 0.5190, Test Loss: 0.5173\n",
0455             "Epoch [288/400], Train Loss: 0.5192, Test Loss: 0.5277\n",
0456             "Epoch [289/400], Train Loss: 0.5190, Test Loss: 0.5196\n",
0457             "Epoch [290/400], Train Loss: 0.5192, Test Loss: 0.5200\n",
0458             "Epoch [291/400], Train Loss: 0.5190, Test Loss: 0.5186\n",
0459             "Epoch [292/400], Train Loss: 0.5192, Test Loss: 0.5211\n",
0460             "Epoch [293/400], Train Loss: 0.5192, Test Loss: 0.5249\n",
0461             "Epoch [294/400], Train Loss: 0.5191, Test Loss: 0.5196\n",
0462             "Epoch [295/400], Train Loss: 0.5191, Test Loss: 0.5215\n",
0463             "Epoch [296/400], Train Loss: 0.5192, Test Loss: 0.5223\n",
0464             "Epoch [297/400], Train Loss: 0.5192, Test Loss: 0.5233\n",
0465             "Epoch [298/400], Train Loss: 0.5191, Test Loss: 0.5223\n",
0466             "Epoch [299/400], Train Loss: 0.5189, Test Loss: 0.5212\n",
0467             "Epoch [300/400], Train Loss: 0.5189, Test Loss: 0.5199\n",
0468             "Epoch [301/400], Train Loss: 0.5190, Test Loss: 0.5197\n",
0469             "Epoch [302/400], Train Loss: 0.5190, Test Loss: 0.5289\n",
0470             "Epoch [303/400], Train Loss: 0.5189, Test Loss: 0.5220\n",
0471             "Epoch [304/400], Train Loss: 0.5190, Test Loss: 0.5296\n",
0472             "Epoch [305/400], Train Loss: 0.5189, Test Loss: 0.5185\n",
0473             "Epoch [306/400], Train Loss: 0.5190, Test Loss: 0.5191\n",
0474             "Epoch [307/400], Train Loss: 0.5189, Test Loss: 0.5196\n",
0475             "Epoch [308/400], Train Loss: 0.5190, Test Loss: 0.5204\n",
0476             "Epoch [309/400], Train Loss: 0.5188, Test Loss: 0.5200\n",
0477             "Epoch [310/400], Train Loss: 0.5190, Test Loss: 0.5254\n",
0478             "Epoch [311/400], Train Loss: 0.5190, Test Loss: 0.5213\n",
0479             "Epoch [312/400], Train Loss: 0.5189, Test Loss: 0.5206\n",
0480             "Epoch [313/400], Train Loss: 0.5188, Test Loss: 0.5239\n",
0481             "Epoch [314/400], Train Loss: 0.5189, Test Loss: 0.5198\n",
0482             "Epoch [315/400], Train Loss: 0.5190, Test Loss: 0.5172\n",
0483             "Epoch [316/400], Train Loss: 0.5187, Test Loss: 0.5199\n",
0484             "Epoch [317/400], Train Loss: 0.5189, Test Loss: 0.5190\n",
0485             "Epoch [318/400], Train Loss: 0.5188, Test Loss: 0.5195\n",
0486             "Epoch [319/400], Train Loss: 0.5187, Test Loss: 0.5191\n",
0487             "Epoch [320/400], Train Loss: 0.5188, Test Loss: 0.5213\n",
0488             "Epoch [321/400], Train Loss: 0.5190, Test Loss: 0.5191\n",
0489             "Epoch [322/400], Train Loss: 0.5188, Test Loss: 0.5250\n",
0490             "Epoch [323/400], Train Loss: 0.5187, Test Loss: 0.5187\n",
0491             "Epoch [324/400], Train Loss: 0.5188, Test Loss: 0.5219\n",
0492             "Epoch [325/400], Train Loss: 0.5186, Test Loss: 0.5200\n",
0493             "Epoch [326/400], Train Loss: 0.5187, Test Loss: 0.5229\n",
0494             "Epoch [327/400], Train Loss: 0.5187, Test Loss: 0.5191\n",
0495             "Epoch [328/400], Train Loss: 0.5187, Test Loss: 0.5193\n",
0496             "Epoch [329/400], Train Loss: 0.5188, Test Loss: 0.5193\n",
0497             "Epoch [330/400], Train Loss: 0.5187, Test Loss: 0.5209\n",
0498             "Epoch [331/400], Train Loss: 0.5187, Test Loss: 0.5222\n",
0499             "Epoch [332/400], Train Loss: 0.5187, Test Loss: 0.5207\n",
0500             "Epoch [333/400], Train Loss: 0.5187, Test Loss: 0.5180\n",
0501             "Epoch [334/400], Train Loss: 0.5186, Test Loss: 0.5229\n",
0502             "Epoch [335/400], Train Loss: 0.5186, Test Loss: 0.5184\n",
0503             "Epoch [336/400], Train Loss: 0.5187, Test Loss: 0.5188\n",
0504             "Epoch [337/400], Train Loss: 0.5187, Test Loss: 0.5204\n",
0505             "Epoch [338/400], Train Loss: 0.5185, Test Loss: 0.5292\n",
0506             "Epoch [339/400], Train Loss: 0.5186, Test Loss: 0.5214\n",
0507             "Epoch [340/400], Train Loss: 0.5187, Test Loss: 0.5210\n",
0508             "Epoch [341/400], Train Loss: 0.5187, Test Loss: 0.5220\n",
0509             "Epoch [342/400], Train Loss: 0.5186, Test Loss: 0.5202\n",
0510             "Epoch [343/400], Train Loss: 0.5185, Test Loss: 0.5311\n",
0511             "Epoch [344/400], Train Loss: 0.5186, Test Loss: 0.5209\n",
0512             "Epoch [345/400], Train Loss: 0.5187, Test Loss: 0.5205\n",
0513             "Epoch [346/400], Train Loss: 0.5185, Test Loss: 0.5178\n",
0514             "Epoch [347/400], Train Loss: 0.5185, Test Loss: 0.5186\n",
0515             "Epoch [348/400], Train Loss: 0.5186, Test Loss: 0.5228\n",
0516             "Epoch [349/400], Train Loss: 0.5184, Test Loss: 0.5222\n",
0517             "Epoch [350/400], Train Loss: 0.5186, Test Loss: 0.5246\n",
0518             "Epoch [351/400], Train Loss: 0.5185, Test Loss: 0.5195\n",
0519             "Epoch [352/400], Train Loss: 0.5185, Test Loss: 0.5189\n",
0520             "Epoch [353/400], Train Loss: 0.5184, Test Loss: 0.5212\n",
0521             "Epoch [354/400], Train Loss: 0.5186, Test Loss: 0.5267\n",
0522             "Epoch [355/400], Train Loss: 0.5184, Test Loss: 0.5227\n",
0523             "Epoch [356/400], Train Loss: 0.5183, Test Loss: 0.5213\n",
0524             "Epoch [357/400], Train Loss: 0.5183, Test Loss: 0.5214\n",
0525             "Epoch [358/400], Train Loss: 0.5183, Test Loss: 0.5228\n",
0526             "Epoch [359/400], Train Loss: 0.5185, Test Loss: 0.5195\n",
0527             "Epoch [360/400], Train Loss: 0.5183, Test Loss: 0.5234\n",
0528             "Epoch [361/400], Train Loss: 0.5185, Test Loss: 0.5198\n",
0529             "Epoch [362/400], Train Loss: 0.5183, Test Loss: 0.5193\n",
0530             "Epoch [363/400], Train Loss: 0.5185, Test Loss: 0.5208\n",
0531             "Epoch [364/400], Train Loss: 0.5185, Test Loss: 0.5217\n",
0532             "Epoch [365/400], Train Loss: 0.5184, Test Loss: 0.5258\n",
0533             "Epoch [366/400], Train Loss: 0.5185, Test Loss: 0.5199\n",
0534             "Epoch [367/400], Train Loss: 0.5183, Test Loss: 0.5197\n",
0535             "Epoch [368/400], Train Loss: 0.5182, Test Loss: 0.5189\n",
0536             "Epoch [369/400], Train Loss: 0.5184, Test Loss: 0.5191\n",
0537             "Epoch [370/400], Train Loss: 0.5183, Test Loss: 0.5198\n",
0538             "Epoch [371/400], Train Loss: 0.5182, Test Loss: 0.5230\n",
0539             "Epoch [372/400], Train Loss: 0.5183, Test Loss: 0.5178\n",
0540             "Epoch [373/400], Train Loss: 0.5183, Test Loss: 0.5184\n",
0541             "Epoch [374/400], Train Loss: 0.5182, Test Loss: 0.5187\n",
0542             "Epoch [375/400], Train Loss: 0.5183, Test Loss: 0.5224\n",
0543             "Epoch [376/400], Train Loss: 0.5184, Test Loss: 0.5203\n",
0544             "Epoch [377/400], Train Loss: 0.5183, Test Loss: 0.5197\n",
0545             "Epoch [378/400], Train Loss: 0.5182, Test Loss: 0.5196\n",
0546             "Epoch [379/400], Train Loss: 0.5181, Test Loss: 0.5200\n",
0547             "Epoch [380/400], Train Loss: 0.5182, Test Loss: 0.5264\n",
0548             "Epoch [381/400], Train Loss: 0.5181, Test Loss: 0.5202\n",
0549             "Epoch [382/400], Train Loss: 0.5182, Test Loss: 0.5219\n",
0550             "Epoch [383/400], Train Loss: 0.5183, Test Loss: 0.5219\n",
0551             "Epoch [384/400], Train Loss: 0.5183, Test Loss: 0.5196\n",
0552             "Epoch [385/400], Train Loss: 0.5181, Test Loss: 0.5192\n",
0553             "Epoch [386/400], Train Loss: 0.5180, Test Loss: 0.5206\n",
0554             "Epoch [387/400], Train Loss: 0.5181, Test Loss: 0.5264\n",
0555             "Epoch [388/400], Train Loss: 0.5181, Test Loss: 0.5191\n",
0556             "Epoch [389/400], Train Loss: 0.5183, Test Loss: 0.5240\n",
0557             "Epoch [390/400], Train Loss: 0.5182, Test Loss: 0.5196\n",
0558             "Epoch [391/400], Train Loss: 0.5182, Test Loss: 0.5176\n",
0559             "Epoch [392/400], Train Loss: 0.5181, Test Loss: 0.5219\n",
0560             "Epoch [393/400], Train Loss: 0.5181, Test Loss: 0.5198\n",
0561             "Epoch [394/400], Train Loss: 0.5180, Test Loss: 0.5251\n",
0562             "Epoch [395/400], Train Loss: 0.5181, Test Loss: 0.5182\n",
0563             "Epoch [396/400], Train Loss: 0.5182, Test Loss: 0.5212\n",
0564             "Epoch [397/400], Train Loss: 0.5180, Test Loss: 0.5217\n",
0565             "Epoch [398/400], Train Loss: 0.5181, Test Loss: 0.5213\n",
0566             "Epoch [399/400], Train Loss: 0.5182, Test Loss: 0.5180\n",
0567             "Epoch [400/400], Train Loss: 0.5180, Test Loss: 0.5205\n"
0568           ]
0569         }
0570       ],
0571       "source": [
0572         "import torch\n",
0573         "import torch.nn as nn\n",
0574         "from torch.utils.data import TensorDataset, random_split, DataLoader\n",
0575         "from torch.optim import Adam\n",
0576         "\n",
0577         "# Set device\n",
0578         "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
0579         "print(f\"Using device: {device}\")\n",
0580         "\n",
0581         "# Create multi-class labels\n",
0582         "def create_multiclass_labels(t3_isFake, t3_sim_vxy, displacement_threshold=0.1):\n",
0583         "    num_samples = len(t3_isFake)\n",
0584         "    labels = torch.zeros((num_samples, 3))\n",
0585         "    \n",
0586         "    # Fake tracks (class 0)\n",
0587         "    fake_mask = t3_isFake\n",
0588         "    labels[fake_mask, 0] = 1\n",
0589         "    \n",
0590         "    # Real tracks\n",
0591         "    real_mask = ~fake_mask\n",
0592         "    \n",
0593         "    # Split real tracks into prompt (class 1) and displaced (class 2)\n",
0594         "    prompt_mask = (t3_sim_vxy <= displacement_threshold) & real_mask\n",
0595         "    displaced_mask = (t3_sim_vxy > displacement_threshold) & real_mask\n",
0596         "    \n",
0597         "    labels[prompt_mask, 1] = 1\n",
0598         "    labels[displaced_mask, 2] = 1\n",
0599         "    \n",
0600         "    return labels\n",
0601         "\n",
0602         "# Create labels tensor\n",
0603         "labels_tensor = create_multiclass_labels(\n",
0604         "    t3_isFake_filtered,\n",
0605         "    t3_sim_vxy_filtered\n",
0606         ")\n",
0607         "\n",
0608         "# Neural network for multi-class classification\n",
0609         "class MultiClassNeuralNetwork(nn.Module):\n",
0610         "    def __init__(self):\n",
0611         "        super(MultiClassNeuralNetwork, self).__init__()\n",
0612         "        self.layer1 = nn.Linear(input_features_numpy.shape[1], 32)\n",
0613         "        self.layer2 = nn.Linear(32, 32)\n",
0614         "        self.output_layer = nn.Linear(32, 3)\n",
0615         "        \n",
0616         "    def forward(self, x):\n",
0617         "        x = self.layer1(x)\n",
0618         "        x = nn.ReLU()(x)\n",
0619         "        x = self.layer2(x)\n",
0620         "        x = nn.ReLU()(x)\n",
0621         "        x = self.output_layer(x)\n",
0622         "        return nn.functional.softmax(x, dim=1)\n",
0623         "\n",
0624         "# Weighted loss function for multi-class\n",
0625         "class WeightedCrossEntropyLoss(nn.Module):\n",
0626         "    def __init__(self):\n",
0627         "        super(WeightedCrossEntropyLoss, self).__init__()\n",
0628         "        \n",
0629         "    def forward(self, outputs, targets, weights):\n",
0630         "        eps = 1e-7\n",
0631         "        log_probs = torch.log(outputs + eps)\n",
0632         "        losses = -weights * torch.sum(targets * log_probs, dim=1)\n",
0633         "        return losses.mean()\n",
0634         "\n",
0635         "# Calculate class weights (each sample gets a weight to equalize class contributions)\n",
0636         "def calculate_class_weights(labels):\n",
0637         "    class_counts = torch.sum(labels, dim=0)\n",
0638         "    total_samples = len(labels)\n",
0639         "    class_weights = total_samples / (3 * class_counts)  # Normalize across 3 classes\n",
0640         "    \n",
0641         "    sample_weights = torch.zeros(len(labels))\n",
0642         "    for i in range(3):\n",
0643         "        sample_weights[labels[:, i] == 1] = class_weights[i]\n",
0644         "    \n",
0645         "    return sample_weights\n",
0646         "\n",
0647         "# Print initial dataset size\n",
0648         "print(f\"Initial dataset size: {len(labels_tensor)}\")\n",
0649         "\n",
0650         "# Remove rows with NaN and update everything accordingly\n",
0651         "nan_mask = torch.isnan(input_features_tensor).any(dim=1)\n",
0652         "filtered_inputs = input_features_tensor[~nan_mask]\n",
0653         "filtered_labels = labels_tensor[~nan_mask]\n",
0654         "\n",
0655         "# Print class distribution before downsampling\n",
0656         "class_counts_before = torch.sum(filtered_labels, dim=0)\n",
0657         "print(f\"Class distribution before downsampling - Fake: {class_counts_before[0]}, Prompt: {class_counts_before[1]}, Displaced: {class_counts_before[2]}\")\n",
0658         "\n",
0659         "# Option to downsample each class\n",
0660         "downsample_classes = True  # Set to False to disable downsampling\n",
0661         "if downsample_classes:\n",
0662         "    # Define downsampling ratios for each class:\n",
0663         "    # For example, downsample fakes (class 0) to 20% and keep prompt (class 1) and displaced (class 2) at 100%\n",
0664         "    downsample_ratios = {0: 0.2, 1: 1.0, 2: 1.0}\n",
0665         "    indices_list = []\n",
0666         "    for cls in range(3):\n",
0667         "        # Find indices for the current class\n",
0668         "        cls_mask = (filtered_labels[:, cls] == 1)\n",
0669         "        cls_indices = torch.nonzero(cls_mask).squeeze()\n",
0670         "        ratio = downsample_ratios.get(cls, 1.0)\n",
0671         "        num_cls = cls_indices.numel()\n",
0672         "        num_to_sample = int(num_cls * ratio)\n",
0673         "        # Ensure at least one sample is kept if available\n",
0674         "        if num_to_sample < 1 and num_cls > 0:\n",
0675         "            num_to_sample = 1\n",
0676         "        # Shuffle and select the desired number of samples\n",
0677         "        cls_indices_shuffled = cls_indices[torch.randperm(num_cls)]\n",
0678         "        sampled_cls_indices = cls_indices_shuffled[:num_to_sample]\n",
0679         "        indices_list.append(sampled_cls_indices)\n",
0680         "    \n",
0681         "    # Combine the indices from all classes\n",
0682         "    selected_indices = torch.cat(indices_list)\n",
0683         "    filtered_inputs = filtered_inputs[selected_indices]\n",
0684         "    filtered_labels = filtered_labels[selected_indices]\n",
0685         "\n",
0686         "# Print class distribution after downsampling\n",
0687         "class_counts_after = torch.sum(filtered_labels, dim=0)\n",
0688         "print(f\"Class distribution after downsampling - Fake: {class_counts_after[0]}, Prompt: {class_counts_after[1]}, Displaced: {class_counts_after[2]}\")\n",
0689         "\n",
0690         "# Recalculate sample weights after downsampling (equal weighting per class based on new counts)\n",
0691         "sample_weights = calculate_class_weights(filtered_labels)\n",
0692         "filtered_weights = sample_weights\n",
0693         "\n",
0694         "# Create dataset with weights\n",
0695         "dataset = TensorDataset(filtered_inputs, filtered_labels, filtered_weights)\n",
0696         "\n",
0697         "# Random 80/20 train/test split; no generator is passed, so it follows torch's global RNG state\n",
0698         "train_size = int(0.8 * len(dataset))\n",
0699         "test_size = len(dataset) - train_size\n",
0700         "train_dataset, test_dataset = random_split(dataset, [train_size, test_size])\n",
0701         "\n",
0702         "# Create data loaders\n",
0703         "train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True, num_workers=10, pin_memory=True)\n",
0704         "test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False, num_workers=10, pin_memory=True)\n",
0705         "\n",
0706         "# Initialize model and optimizer\n",
0707         "model = MultiClassNeuralNetwork().to(device)\n",
0708         "loss_function = WeightedCrossEntropyLoss()\n",
0709         "optimizer = Adam(model.parameters(), lr=0.0025)\n",
0710         "\n",
0711         "def evaluate_loss(loader):\n",
0712         "    model.eval()\n",
0713         "    total_loss = 0\n",
0714         "    num_batches = 0\n",
0715         "    with torch.no_grad():\n",
0716         "        for inputs, targets, weights in loader:\n",
0717         "            inputs, targets, weights = inputs.to(device), targets.to(device), weights.to(device)\n",
0718         "            outputs = model(inputs)\n",
0719         "            loss = loss_function(outputs, targets, weights)\n",
0720         "            total_loss += loss.item()\n",
0721         "            num_batches += 1\n",
0722         "    return total_loss / num_batches\n",
0723         "\n",
0724         "# Training loop: one pass over train_loader per epoch, then a test-set loss evaluation\n",
0725         "num_epochs = 400\n",
0726         "train_loss_log = []\n",
0727         "test_loss_log = []\n",
0728         "\n",
0729         "for epoch in range(num_epochs):\n",
0730         "    model.train()\n",
0731         "    epoch_loss = 0\n",
0732         "    num_batches = 0\n",
0733         "    \n",
0734         "    for inputs, targets, weights in train_loader:\n",
0735         "        inputs, targets, weights = inputs.to(device), targets.to(device), weights.to(device)\n",
0736         "        \n",
0737         "        # Forward pass\n",
0738         "        outputs = model(inputs)\n",
0739         "        loss = loss_function(outputs, targets, weights)\n",
0740         "        epoch_loss += loss.item()\n",
0741         "        num_batches += 1\n",
0742         "        \n",
0743         "        # Backward and optimize\n",
0744         "        optimizer.zero_grad()\n",
0745         "        loss.backward()\n",
0746         "        optimizer.step()\n",
0747         "    \n",
0748         "    # Average the per-batch mean losses (a smaller final batch counts as much as a full one)\n",
0749         "    train_loss = epoch_loss / num_batches\n",
0750         "    test_loss = evaluate_loss(test_loader)\n",
0751         "    \n",
0752         "    train_loss_log.append(train_loss)\n",
0753         "    test_loss_log.append(test_loss)\n",
0754         "    \n",
0755         "    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')"
0756       ]
0757     },
0758     {
0759       "cell_type": "code",
0760       "execution_count": 6,
0761       "metadata": {},
0762       "outputs": [
0763         {
0764           "name": "stdout",
0765           "output_type": "stream",
0766           "text": [
0767             "Baseline accuracy: 0.8611\n",
0768             "\n",
0769             "Feature importances:\n",
0770             "Feature 0 importance: 0.0541\n",
0771             "Feature 2 importance: 0.0480\n",
0772             "Feature 5 importance: 0.0434\n",
0773             "Feature 7 importance: 0.0242\n",
0774             "Feature 6 importance: 0.0223\n",
0775             "Feature 3 importance: 0.0206\n",
0776             "Feature 11 importance: 0.0167\n",
0777             "Feature 10 importance: 0.0148\n",
0778             "Feature 13 importance: 0.0140\n",
0779             "Feature 12 importance: 0.0128\n",
0780             "Feature 9 importance: 0.0114\n",
0781             "Feature 8 importance: 0.0046\n",
0782             "Feature 4 importance: 0.0016\n",
0783             "Feature 1 importance: 0.0000\n"
0784           ]
0785         }
0786       ],
0787       "source": [
0788         "import torch\n",
0789         "import numpy as np\n",
0790         "from sklearn.metrics import accuracy_score\n",
0791         "\n",
0792         "# Convert tensors to numpy for simplicity if you want to manipulate them outside of PyTorch\n",
0793         "input_features_np = input_features_tensor.numpy()\n",
0794         "labels_np = torch.argmax(labels_tensor, dim=1).numpy()  # Convert one-hot to class indices\n",
0795         "\n",
0796         "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
0797         "\n",
0798         "def model_accuracy(features, labels, model):\n",
0799         "    \"\"\"\n",
0800         "    Compute accuracy for a multi-class classification model\n",
0801         "    that outputs probabilities of size [batch_size, num_classes].\n",
0802         "    \"\"\"\n",
0803         "    model.eval()  # Set the model to evaluation mode\n",
0804         "    \n",
0805         "    # Move the features and labels to the correct device\n",
0806         "    inputs = features.to(device)\n",
0807         "    labels = labels.to(device)\n",
0808         "    \n",
0809         "    with torch.no_grad():\n",
0810         "        outputs = model(inputs)  # shape: [batch_size, num_classes]\n",
0811         "        # For multi-class, the predicted class is argmax of the probabilities\n",
0812         "        predicted = torch.argmax(outputs, dim=1)\n",
0813         "        # Convert one-hot encoded labels to class indices if needed\n",
0814         "        if len(labels.shape) > 1:\n",
0815         "            labels = torch.argmax(labels, dim=1)\n",
0816         "        # Compute mean accuracy\n",
0817         "        accuracy = (predicted == labels).float().mean().item()\n",
0818         "    \n",
0819         "    return accuracy\n",
0820         "\n",
0821         "# Baseline accuracy on the full input_features_tensor / labels_tensor, not on the NaN-filtered, downsampled training data\n",
0822         "baseline_accuracy = model_accuracy(input_features_tensor, labels_tensor, model)\n",
0823         "print(f\"Baseline accuracy: {baseline_accuracy:.4f}\")\n",
0824         "\n",
0825         "# Initialize array to store feature importances\n",
0826         "feature_importances = np.zeros(input_features_tensor.shape[1])\n",
0827         "\n",
0828         "# Iterate over each feature for permutation importance\n",
0829         "for i in range(input_features_tensor.shape[1]):\n",
0830         "    # Create a copy of the original features\n",
0831         "    permuted_features = input_features_tensor.clone()\n",
0832         "    \n",
0833         "    # Permute feature i across all examples\n",
0834         "    # Only column i is reordered by a random row permutation; every other column stays aligned with its row\n",
0835         "    permuted_features[:, i] = permuted_features[torch.randperm(permuted_features.size(0)), i]\n",
0836         "    \n",
0837         "    # Compute accuracy after permutation\n",
0838         "    permuted_accuracy = model_accuracy(permuted_features, labels_tensor, model)\n",
0839         "    \n",
0840         "    # The drop in accuracy is used as a measure of feature importance\n",
0841         "    feature_importances[i] = baseline_accuracy - permuted_accuracy\n",
0842         "\n",
0843         "# Sort features by descending importance\n",
0844         "important_features_indices = np.argsort(feature_importances)[::-1]\n",
0845         "important_features_scores = np.sort(feature_importances)[::-1]\n",
0846         "\n",
0847         "# Print out results\n",
0848         "print(\"\\nFeature importances:\")\n",
0849         "for idx, score in zip(important_features_indices, important_features_scores):\n",
0850         "    print(f\"Feature {idx} importance: {score:.4f}\")"
0851       ]
0852     },
0853     {
0854       "cell_type": "code",
0855       "execution_count": 8,
0856       "metadata": {},
0857       "outputs": [
0858         {
0859           "name": "stdout",
0860           "output_type": "stream",
0861           "text": [
0862             "ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_layer1[32] = {\n",
0863             "-0.9152892f, 3.2650192f, -0.4164221f, -0.1210157f, -2.4165483f, -1.0984275f, -2.1654966f, -0.8991888f, -0.0503724f, 7.1305695f, -5.2781415f, 3.2997849f, 1.0025330f, -0.5117974f, 0.2957068f, -0.1811045f, -2.7853479f, 1.8040915f, -2.8807588f, -4.6462102f, 1.2869841f, -0.0526987f, 0.4946094f, 2.6554070f, -0.1360572f, 0.2122774f, 4.7361507f, -1.4605266f, 0.1759245f, -0.7966636f, -0.0401897f, -0.2652957f };\n",
0864             "\n",
0865             "ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_layer1[14][32] = {\n",
0866             "{ 0.2570587f, 1.5017653f, -1.8436140f, 1.6314303f, -0.1464428f, 1.2261974f, 2.8629315f, -0.0778951f, -0.0007868f, -2.4665442f, 3.7231014f, -0.4062112f, 5.0222125f, -0.4256854f, -0.8145034f, -0.0993065f, 1.1874412f, 3.7737985f, -2.0898068f, 5.0041976f, -0.4184950f, 0.0133298f, -1.1757115f, 0.8953519f, -0.2589224f, 3.4567924f, -1.0867721f, -0.0325336f, -0.1398652f, 5.9361205f, -0.2938714f, 0.0110872f },\n",
0867             "{ 0.0062326f, 0.0294117f, -0.1038531f, -0.1871421f, 0.0092176f, 0.0194613f, -0.0970159f, 0.0044040f, 0.0040717f, -0.0464808f, -0.0161564f, 0.0660082f, 0.0107912f, 0.0041196f, 0.0076985f, -0.0447547f, -0.0234053f, -0.0128952f, -0.0374210f, 0.0277440f, 0.0126392f, 0.0053390f, -0.0176450f, 0.0422363f, -0.1089868f, 0.0229251f, -0.0632515f, -0.0000745f, -0.0120581f, 0.0553841f, 0.1958316f, -0.2002713f },\n",
0868             "{ -0.3374411f, 1.0331864f, -2.3923049f, -1.7857485f, 5.2973804f, -2.4907997f, -5.4545326f, -0.1601160f, -0.0028237f, 0.3510691f, 1.9307067f, 0.1516920f, -9.8718386f, -0.0893254f, -0.3546923f, 0.1746188f, -2.5511057f, -6.9032016f, 1.1837323f, -1.5620790f, 0.3867095f, -0.0118511f, 0.0633005f, -0.4638650f, -0.1623496f, -6.3706584f, -1.3387657f, -0.3824906f, 0.0013318f, -0.2500904f, 0.0434040f, 0.2572088f },\n",
0869             "{ 0.4586262f, -5.7731452f, 2.1999521f, 1.8049341f, -0.2857220f, 0.9761068f, 7.4085426f, 0.6136439f, -0.0216335f, -2.1852305f, -0.8797144f, -0.0105609f, 2.9077356f, 0.6365089f, 0.4854219f, 0.1170259f, 1.9888018f, 1.5127803f, -1.5444536f, 8.0876036f, 0.2033372f, -0.0132441f, -0.8586000f, -1.9558586f, -0.0361535f, 4.7021170f, 1.0431648f, 2.0264521f, -0.2665041f, 6.2334776f, -0.0008584f, -0.0026126f },\n",
0870             "{ -1.5857489f, 2.3446348f, -14.6416082f, 2.6467063f, 3.0982802f, 14.1958466f, 7.0268164f, -0.4456313f, 0.0005447f, -1.3794490f, -10.3501320f, 1.1612288f, 9.6774111f, -0.1875248f, 21.0353413f, -0.1961844f, 8.0234823f, -4.1653094f, -7.1429234f, 15.8104372f, 2.7506628f, -0.0142524f, 22.4932308f, 0.7682357f, 0.0385473f, 13.5405970f, -4.9976201f, -22.4438667f, -0.1693429f, -18.7507858f, -0.0939464f, 0.2192265f },\n",
0871             "{ 65.2134705f, 1.4352289f, -0.6685436f, 22.9858570f, 1.6736605f, -0.1810708f, 0.5204540f, -53.2590408f, -0.0630155f, 3.6024182f, -3.8777969f, 5.0021510f, 0.0055030f, 23.8294449f, -1.3818942f, -0.2419317f, -10.2253504f, -1.8309352f, 1.7169305f, -0.3938941f, 9.1144180f, -0.0004920f, -3.1774588f, -36.4919891f, -0.1711030f, 6.5288949f, 4.5861993f, 0.8314257f, 0.0305954f, 5.1864023f, 0.2658210f, -0.1748345f },\n",
0872             "{ -0.7423377f, -5.8733273f, 5.0070434f, -6.5473375f, 3.8788877f, 6.8001447f, 3.3014266f, -0.5657389f, -0.0376282f, 6.1230965f, 4.0765481f, 0.1596697f, -7.4254904f, 0.2356068f, 3.5560462f, -0.1223621f, 0.7022666f, -9.2908258f, -4.4684458f, -1.1861130f, 0.1879538f, -0.0337607f, 0.9228330f, 1.2672698f, -0.1690857f, -11.2556086f, 4.1028724f, 0.4850706f, 0.0041372f, 3.7200036f, 0.1445599f, 0.2449260f },\n",
0873             "{ -2.3848257f, -1.7144806f, -0.2987937f, 10.1947727f, 0.7855392f, 7.1466241f, -2.2256043f, -1.2184218f, 0.1233135f, 5.3274498f, 2.9086673f, 1.5096599f, 2.8449385f, -0.2345320f, -2.2044909f, 0.1858539f, -0.7592235f, 3.1651254f, 0.5184333f, -3.7233777f, 1.5772278f, 0.0997663f, -0.0325775f, -13.2207737f, -0.0340279f, -6.4953661f, -18.9173355f, 0.9963044f, 0.1927230f, 4.6532283f, -0.0916147f, -0.0406466f },\n",
0874             "{ 5.0975156f, -0.7078310f, 26.7917671f, -1.9278797f, -2.4459264f, -12.1174421f, -5.1347051f, 2.0090365f, -0.0012259f, -11.6201696f, 24.7306499f, -8.7715597f, 4.8136749f, -0.7106145f, -46.1458054f, 0.1771528f, 22.9087524f, -5.5876012f, 6.9944615f, 29.2786064f, -7.2195830f, 0.0270186f, 14.7860146f, 1.0168871f, 0.1467975f, -19.6260185f, 3.3284237f, 22.5860500f, -0.0160821f, -35.1570702f, 0.1473373f, 0.1500054f },\n",
0875             "{ 53.6092033f, -0.7677985f, 3.3910768f, -27.1046524f, -5.1561117f, 1.6982658f, 0.9386115f, -72.5867996f, 57.8674583f, -8.2815771f, 3.2216380f, -28.9387760f, 3.2793560f, 37.3099365f, 2.2979465f, 0.1827718f, 6.8113675f, 1.0104842f, -1.5407079f, -2.5780137f, 32.4788666f, -61.8150520f, -2.6497467f, -10.6412830f, -0.2596186f, -11.8458385f, 27.5528336f, -1.3142428f, -0.2566442f, -17.8737431f, 0.0261727f, 0.0107839f },\n",
0876             "{ 0.2419503f, -3.8581870f, 6.1144238f, 2.3472424f, 1.9470867f, -3.3741038f, 0.4852004f, 0.6366135f, -0.0736884f, 2.2963598f, -0.4113644f, -2.0738223f, -3.7331092f, 0.7157578f, 1.2168316f, -0.1584605f, -2.3843584f, -6.1547771f, -0.4764193f, 4.6278925f, -1.3195806f, -0.0717061f, 2.5889101f, -3.7769084f, 0.0527631f, 3.8808305f, -0.0672685f, 0.3294500f, -0.1916338f, -2.2346771f, -0.1518883f, 0.0462940f },\n",
0877             "{ 0.7453216f, -0.5999010f, -2.6196895f, -6.5323448f, 0.0482639f, 0.0162446f, -0.1185504f, -2.6736362f, -0.0037108f, -4.2818441f, 4.1449761f, 3.3861248f, 0.1735097f, -4.7952204f, 0.8002076f, 0.0598137f, 0.2611136f, 2.4648571f, 1.3178054f, -12.6864462f, -1.4815618f, -0.0113616f, 0.0697572f, 11.0613089f, -0.1636935f, -1.3598750f, 5.4537063f, 0.4077346f, 0.2614925f, 0.4335817f, -0.1616396f, 0.0372773f },\n",
0878             "{ 0.3573107f, -0.1230106f, -0.0517133f, -0.7743639f, 0.1088887f, -0.0315369f, -0.4702122f, 0.4170579f, 0.0149554f, -2.9016135f, 0.8456838f, -1.4586754f, -0.3096066f, 0.3871936f, -0.0811070f, -0.0972313f, 1.3539683f, -0.4489806f, 2.1791372f, -0.0245313f, -0.5678235f, 0.0153700f, 0.0444115f, -1.1144291f, -0.0992134f, -0.0615626f, -1.5467011f, 0.3384680f, -0.2377923f, -3.0146224f, -0.1680345f, -0.0683730f },\n",
0879             "{ 17.0918045f, 0.3651829f, 0.4609411f, -8.8885517f, -1.3358241f, 0.3141162f, 0.5917582f, -3.1579833f, 18.4088402f, -0.7021288f, -0.1767638f, -20.7704201f, 1.0183893f, -12.4671431f, 0.0741675f, 0.2120477f, -0.4298171f, -0.3993036f, -0.4320501f, 0.1840025f, 35.6576691f, -19.6535892f, -1.1798559f, -4.8292837f, 0.1928347f, -2.1754487f, 5.7580199f, -0.4750445f, -0.0005913f, -3.2222869f, -0.0762974f, -0.0288493f },\n",
0880             "};\n",
0881             "\n",
0882             "ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_layer2[32] = {\n",
0883             "-2.9135964f, 2.8539734f, 0.1411822f, 1.3484665f, 4.0785451f, 3.1302166f, 0.6947064f, -6.6154590f, 0.5603381f, 1.1026498f, -0.0329598f, 5.5717101f, -4.4454126f, 0.8731447f, -0.0039664f, 3.3978446f, 0.2816379f, 1.0174516f, 5.7364573f, -0.2107503f, 1.5612308f, 3.0443959f, -0.5723908f, -2.2100310f, 3.2695763f, 1.2092905f, 0.3386045f, 2.9886062f, 1.6525602f, 2.2572896f, 1.5943902f, 3.8117774f };\n",
0884             "\n",
0885             "ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_layer2[32][32] = {\n",
0886             "{ -34.7392120f, -3.6741450f, 0.9912871f, 2.2665858f, -4.2091088f, -5.7695718f, -0.7361242f, 1.9049225f, -2.0010724f, 0.5481573f, -0.1977228f, -8.8578167f, 0.2730211f, 0.2286723f, -0.2417704f, 3.0146859f, -4.6289725f, -1.9597307f, -3.8090327f, -0.0440762f, -0.1491523f, 7.2579861f, -1.3858653f, 5.7306480f, 1.3374176f, -58.5746918f, -0.7968796f, -7.3152990f, 0.8867928f, -1.0549910f, -1.5279640f, -0.3361593f },\n",
0887             "{ 0.1808980f, -0.3280856f, 0.2590686f, -1.0191135f, -0.4207653f, -0.7095064f, 0.1365926f, 0.4096814f, -10.8506107f, 1.2207726f, -0.1571288f, -0.7078422f, 5.8424225f, 1.6238457f, -0.0862379f, -1.4228019f, 0.4024760f, 1.8243361f, -0.6990436f, 0.0060351f, -0.1591629f, -4.6918707f, -0.1512203f, 0.0915786f, -0.5897132f, -0.2044267f, -0.9305406f, -0.5472224f, -2.8864694f, 1.0543642f, 0.5347654f, -0.5453511f },\n",
0888             "{ 0.0880480f, 0.4925799f, -2.0632777f, 0.7970945f, 0.1767638f, 2.1716123f, -1.6498836f, 2.0251210f, 3.4944487f, 4.9597678f, 0.0147583f, 1.4637951f, -1.9777663f, -0.4427343f, -0.2699612f, 0.2240291f, 1.4522513f, 4.8733187f, 1.2140924f, -0.1049215f, -0.6172752f, -1.0545660f, -0.0994265f, 0.0489118f, 0.1526600f, -0.5299308f, 1.2835877f, 1.1346543f, 3.8537612f, 0.0424094f, 0.6743001f, 3.6908419f },\n",
0889             "{ 13.5742407f, -0.2112840f, -0.1391970f, -0.2532290f, -6.1313200f, -0.0358473f, -1.1182036f, -1.2637638f, -0.9513455f, -1.3301764f, -0.0799558f, -1.3724291f, -1.6102096f, -0.0269820f, -0.2672046f, -4.3722496f, 0.5367125f, -0.8932595f, 2.7968166f, -0.1038471f, 0.9354176f, 11.1879988f, 2.5031428f, -0.5350324f, 0.9403298f, -23.2240810f, -1.7667898f, -0.9982644f, 0.7523674f, 0.1651796f, 0.6430404f, 0.1910793f },\n",
0890             "{ 0.2658246f, 0.7314553f, -0.1396945f, 0.8241134f, -0.0880722f, 0.5978789f, 0.6997685f, -0.6765627f, 2.1706738f, -0.8361687f, -0.0501303f, 0.3658123f, -6.4972825f, -4.6433911f, -0.0610007f, -1.3655313f, 0.8686107f, -0.5963267f, 0.0234031f, -0.2062150f, -0.2529819f, -8.4623156f, -0.4437916f, -1.9638975f, -0.3870836f, -0.5850767f, -4.6490083f, 0.6830147f, -1.2122021f, -2.5658479f, 0.4557288f, -0.0879869f },\n",
0891             "{ -0.0108578f, 0.5193277f, 0.8980315f, 0.6727697f, -0.2764051f, 0.5845009f, 0.2578916f, 0.5809379f, 0.1509607f, 0.6870583f, -0.1276244f, -0.0574215f, -4.9936194f, 0.5157762f, -0.2816216f, -1.3928800f, 0.4075674f, 1.1707770f, 0.0359891f, -0.0638050f, -0.7958111f, -13.7506857f, 0.9970800f, 0.3291980f, -0.0093937f, -0.8198650f, 2.4517658f, 0.0728102f, 0.5025729f, 1.1547782f, -2.2815955f, 0.7980155f },\n",
0892             "{ -39.5117416f, -1.9099823f, -0.5987001f, -0.1347036f, 0.2869921f, -6.9140315f, 1.1130731f, 1.5080519f, 2.2398031f, -0.0088237f, -0.0412525f, -6.1471343f, 2.8918502f, -0.3072237f, -0.1006490f, -1.7464429f, -0.8411779f, -42.2267303f, 0.7191703f, -0.1617323f, -0.2711441f, -0.8011407f, 0.5901257f, 1.3762408f, 0.3728105f, 0.5847842f, -0.7151731f, -5.5479650f, 0.4756111f, 0.1937995f, -0.2577415f, 1.4683112f },\n",
0893             "{ 5.5796323f, -1.9758234f, -3.9855742f, -1.1933511f, -6.1229320f, -3.4628909f, 0.8108729f, 1.1214327f, 0.2925960f, -0.3177399f, -0.0674356f, -7.5650120f, -0.2327600f, -0.1809894f, -0.2001229f, -3.7608955f, -1.6911368f, -1.2986628f, -10.9993305f, -0.0333699f, 2.1555903f, 4.0816026f, -1.7761034f, 2.7839565f, 1.7967551f, -1.9129136f, 0.3337919f, -5.5511317f, -0.3778784f, -1.7885938f, 0.5358973f, 0.3659674f },\n",
0894             "{ -0.4947478f, -1.1247686f, -0.5239698f, -0.2116186f, 0.9495451f, -1.7275617f, 3.2058396f, 5.8543334f, 1.1329687f, 2.0362375f, -0.0113123f, -3.6117337f, -2.6473372f, -101.7068253f, 0.0933132f, -9.5351706f, 2.7397399f, -6.7515664f, 6.8541646f, -0.1772990f, -8.0234919f, -11.7267113f, -79.2402649f, 1.0556825f, -1.2873408f, -204.0993347f, 2.9009213f, -1.5086528f, -0.3790632f, -16.6444530f, -0.7742234f, -2.5006552f },\n",
0895             "{ -3.5879157f, 1.2350843f, 0.9988346f, 0.8656718f, -0.2288211f, 1.5041313f, 1.2635980f, -0.9575849f, 1.4612056f, -0.6694647f, -0.2607012f, 2.2113328f, -0.1756403f, 0.4053871f, -0.2502097f, -0.9793193f, 1.1052762f, 0.1973925f, -0.6196574f, -0.1616422f, -0.0878458f, -33.2044487f, -0.6067256f, -2.6461327f, -0.8652852f, 0.4787520f, 0.0987720f, 2.0803480f, -2.9416194f, 0.6944812f, -0.7692280f, -0.2723561f },\n",
0896             "{ 0.9679711f, -0.8438777f, 1.6422297f, 0.1470843f, -0.0370339f, -1.7256715f, -3.5553441f, 0.9303662f, -3.7400429f, -0.0415591f, -0.0421884f, -0.4996702f, 0.2560562f, -1.1471199f, -0.1010781f, 2.9399052f, -2.7095137f, 0.1029265f, -0.9041355f, -0.0961986f, -1.1163988f, 2.2019930f, 1.1136653f, 2.6281602f, 0.1038788f, 0.8405868f, -1.7726818f, -1.4164388f, 2.5241690f, -0.4696356f, -0.6558685f, -0.3892841f },\n",
0897             "{ 0.3616795f, -0.2340108f, 1.3857211f, -0.1612080f, 0.9073119f, 2.0731826f, -2.9476249f, -0.8346572f, -0.2136685f, -1.5312109f, -0.2896508f, 3.2546318f, -0.8377571f, -0.1631008f, -0.2456917f, 2.3952994f, 0.7205132f, 0.3009275f, -0.8928477f, 0.0297560f, -0.9989902f, -5.8525624f, -1.1516922f, -1.4669514f, -0.9353167f, 0.0053586f, -0.2916538f, 2.6023183f, -0.3268352f, 1.1982346f, 0.5936528f, 0.2427792f },\n",
0898             "{ -8.4515934f, 1.6914611f, 1.8740683f, -3.0301368f, 1.1796944f, 1.6180993f, 1.5193824f, -1.3537951f, 1.8400136f, -2.8482721f, -0.1556847f, 1.9412404f, 4.7879019f, 0.7858805f, 0.0379962f, -2.2685356f, 2.4616067f, -3.1357207f, -3.1058917f, -0.0751680f, -6.3824587f, -6.4688945f, -0.4577174f, -4.3322401f, -5.3993649f, -0.0399034f, 0.6397164f, 2.3432400f, -3.4454954f, 0.4422087f, 2.4481769f, -2.0406303f },\n",
0899             "{ 2.0977514f, 2.0378633f, -5.0659881f, -3.1632454f, -3.7596478f, 1.1832924f, 2.6347756f, -0.7441098f, 0.9281505f, 0.2330403f, -0.1830614f, -0.7371528f, 0.9002755f, 0.2897577f, 0.0216177f, -6.7740455f, 2.5610964f, 0.7834055f, -7.0665541f, -0.0497221f, 2.1334522f, -4.0962648f, -0.9172248f, -1.9944772f, 0.2762617f, -7.6493464f, 1.3044875f, 0.1891927f, -1.2570183f, -0.9203181f, 1.1330502f, 0.2542208f },\n",
0900             "{ -3.2519467f, 0.1621506f, 1.2959863f, -0.4015787f, -0.1872185f, -1.3774366f, -0.6622453f, -0.9865818f, -0.7389784f, -2.5770009f, -0.1067092f, -1.8156646f, 8.1435080f, 0.1961913f, -0.2026473f, -0.5137898f, -0.3847715f, -2.7479975f, -0.5669084f, -0.0805389f, 0.8162143f, -89.8910904f, -0.0983714f, -0.6288267f, -0.1096447f, 0.3820810f, -0.5419686f, -1.2580637f, -0.1827327f, -0.7081568f, -0.0991779f, -1.8233751f },\n",
0901             "{ -0.1124980f, -0.0672807f, -0.1147610f, 0.0497073f, 0.1218153f, 0.1264220f, 0.0282118f, 0.1016580f, 0.0696730f, -0.0004516f, 0.0812953f, 0.0875243f, 0.0982849f, 0.1660293f, -0.1481024f, -0.0315878f, -0.0644747f, 0.1398649f, 0.0835874f, -0.1440294f, -0.1390193f, -0.0628670f, -0.1517032f, -0.0325693f, -0.1094708f, 0.0963070f, 0.0056602f, -0.0197677f, -0.0068012f, 0.1578562f, -0.0302607f, 0.0684079f },\n",
0902             "{ 0.2561692f, 0.6044238f, -1.0067526f, 3.0207021f, -0.5215638f, -1.5455573f, -2.4320843f, -0.2874290f, -5.5609145f, -2.5270512f, -0.1884816f, -1.4440522f, 1.1501647f, 1.2767099f, -0.2626259f, 1.5462712f, -0.3342260f, -1.4259053f, -0.1591775f, -0.1777169f, -0.8070273f, -0.6262614f, -0.1421982f, 0.4950019f, 0.3899588f, 1.1158458f, -1.8252333f, -1.4090738f, -0.9128270f, 1.2212859f, -0.5060611f, -1.5151979f },\n",
0903             "{ 6.7118378f, -1.9413553f, -0.9765822f, 1.9195900f, -1.6302279f, -3.3607423f, -0.5215842f, 1.6841047f, -3.7323046f, 2.9130237f, -0.2912694f, -1.6349151f, -2.9017780f, 0.8473414f, -0.0895011f, 1.9765211f, -1.6982929f, 1.3711845f, -0.8770422f, -0.0966949f, -4.9838910f, -80.1754532f, 0.5617938f, 5.0638437f, -2.1738896f, 1.1080216f, -1.2562524f, -2.4832511f, 1.9475292f, -0.0768876f, -1.7405595f, 1.2659055f },\n",
0904             "{ -9.3639793f, -0.7770892f, 0.5863696f, -2.4971461f, 0.9785317f, 0.6006397f, -0.7508336f, 1.4561496f, 3.3278019f, -1.3552047f, 0.0016642f, -1.0065420f, -0.9822129f, -0.1398876f, -0.0867651f, -2.2540245f, 1.3095651f, -0.4880162f, -0.9081959f, 0.0172203f, -0.9673725f, -6.6905494f, 1.8820195f, -0.3343536f, -0.9252734f, -0.4198346f, 3.2226353f, 0.0417345f, -0.2720280f, -0.4798162f, 0.8319970f, 0.6051134f },\n",
0905             "{ 0.0971739f, -0.1656290f, -3.0162272f, 1.3674160f, -0.1130774f, -2.0347462f, 0.8287644f, 1.7546188f, -0.8183546f, -0.2517924f, -0.0338358f, -2.3002670f, -1.8175446f, -0.8929357f, 0.0145055f, 0.7129369f, -1.2497267f, -0.8694588f, 0.6608886f, 0.0472852f, 0.7463530f, 2.0970786f, -0.7406388f, 1.6379862f, 0.9597634f, 0.2664887f, -0.9611518f, -2.1256742f, -0.8108851f, -0.7670876f, -0.2143202f, -0.9296412f },\n",
0906             "{ -5.5468359f, 1.8933475f, 0.5377558f, -0.8609743f, 1.6522382f, 3.7070215f, -1.3910099f, -1.7996042f, -1.2547578f, -0.3161051f, -0.1433857f, 5.9167895f, 0.2788936f, 0.6513762f, -0.1890229f, 1.3976518f, 2.5647829f, 1.7091097f, -2.4891980f, 0.0704016f, -1.2354076f, -7.9673457f, 0.5024176f, -3.1194675f, -1.8552034f, 1.4241451f, -0.5721908f, 4.6941881f, -0.4191083f, 1.5897839f, 0.5376836f, -0.3906031f },\n",
0907             "{ -7.3998523f, -1.7208674f, 3.6660385f, 3.2399852f, 2.6726487f, -2.7743144f, 2.6148691f, 5.4286017f, -3.5616024f, 3.8747442f, -0.2854572f, -1.7255206f, 1.5527865f, -95.2269287f, -0.0130005f, -0.3984787f, -0.3650612f, -5.9493575f, 4.3472433f, 0.0598797f, -11.5429420f, -2.9780169f, -69.4482956f, 3.8544486f, -2.9926283f, 3.9207959f, 0.2614698f, -1.5368384f, 1.2026052f, -10.9552374f, -2.5336740f, -4.0654378f },\n",
0908             "{ 2.1074338f, -2.0316105f, -0.2519890f, 0.0232255f, -0.2889173f, -0.1693905f, 0.4285150f, 0.6449150f, 4.6293564f, 3.3936121f, 0.0660587f, 0.3134870f, -2.3245549f, -0.9685450f, 0.0889201f, 0.5934311f, -0.9143091f, 2.3384421f, -0.4089483f, -0.1643694f, 2.5919075f, 6.0844655f, 0.2091536f, 2.0565152f, 1.1226704f, -0.2695110f, 0.3927289f, -0.0457220f, 4.0436058f, -1.5475131f, -1.3438802f, 1.6676767f },\n",
0909             "{ 1.7787290f, 0.3969800f, 1.7834123f, 1.3779058f, -0.5738219f, -0.5349790f, -1.4947498f, -0.0787759f, 0.0341407f, 0.4346433f, -0.1981957f, -0.2886125f, -1.0133898f, -0.7178908f, -0.1872994f, 0.7944858f, -1.4787679f, -0.2754479f, -3.5224824f, -0.2090070f, -1.1161381f, -3.6711519f, -1.7022727f, 0.1558363f, 0.4152904f, -2.6727507f, 1.0731899f, -0.3006089f, 0.1950178f, -1.4062581f, -1.4458562f, 0.2443156f },\n",
0910             "{ -0.1570787f, 0.0073413f, -0.1335499f, -0.1712597f, 0.0029127f, 0.1628582f, -0.0816609f, -0.1307715f, -0.1621102f, -0.1200016f, -0.1394555f, 0.0157797f, -0.1572116f, 0.1745119f, -0.1128801f, -0.0566642f, 0.0099119f, -0.1222350f, 0.0299575f, -0.1031234f, -0.1048335f, 0.1707117f, -0.1490631f, -0.0835587f, -0.1712185f, -0.1278749f, 0.1462234f, -0.0081762f, -0.1106477f, -0.1645385f, 0.1268658f, 0.1686065f },\n",
0911             "{ -2.2590189f, -0.1024268f, -1.9020500f, -0.7051241f, 0.1037211f, 0.1701183f, 0.1889226f, 0.9506961f, 1.4137075f, 0.4496680f, -0.2055015f, -0.5990168f, -6.5227470f, -0.1113795f, -0.1070101f, 0.0105921f, -0.1653819f, 0.8838411f, -0.4713951f, 0.0250525f, 0.5694079f, -63.6874619f, 0.3740432f, 0.2925327f, 0.2328388f, -0.9265066f, 0.3290201f, -0.3581912f, 0.8044130f, -0.0143339f, 0.6609334f, -0.6653876f },\n",
0912             "{ 1.4302264f, 0.2180429f, 0.9684587f, 1.0369793f, 0.1597588f, -0.7066790f, -1.7150335f, 0.1960071f, -0.1694691f, 0.8381965f, -0.0181254f, -1.8366945f, -1.8840518f, -0.3109443f, -0.0058080f, 2.0794308f, -1.7936089f, -0.4478118f, -1.2889514f, -0.0300996f, -0.5915933f, -0.8868528f, 1.2223560f, 0.6542705f, 0.0814525f, -1.3704894f, -0.1875549f, -1.6079675f, -0.2744675f, 0.0382733f, -0.9821799f, -1.1006635f },\n",
0913             "{ -0.3911386f, 0.0468989f, 1.9009087f, -1.6725038f, 0.4506643f, -1.9519551f, 0.8855276f, -1.5861524f, 0.3190332f, -3.1724985f, -0.0278042f, -1.2427157f, 1.6820471f, 0.1633015f, -0.0449006f, -1.6101545f, -0.1007412f, -2.7659879f, -0.5162025f, -0.1431058f, 0.8236251f, -0.9194202f, -0.1490582f, -1.6231275f, -0.5467592f, 0.1333764f, -0.4865468f, -0.8269796f, -0.9018597f, 0.0288408f, -1.0994427f, -2.7987468f },\n",
0914             "{ 0.1278387f, -0.0134571f, 0.0448454f, -0.1556552f, -0.1247998f, -0.1196313f, -0.1611872f, -0.0630336f, -0.1410615f, 0.1682803f, 0.0263861f, 0.0619852f, 0.0423437f, 0.0982059f, 0.0784064f, 0.1412098f, 0.0331818f, -0.1537199f, -0.1165343f, -0.0441304f, 0.0197925f, -0.1256299f, 0.0694126f, -0.0137156f, 0.1587864f, 0.1131037f, -0.0722358f, 0.1287198f, -0.0683723f, -0.1212666f, 0.0847685f, 0.1469466f },\n",
0915             "{ -1.7708958f, 0.2500555f, -1.0356978f, -1.4820993f, -0.9565629f, 2.2127695f, 1.3409303f, -0.6528702f, 1.4306690f, 1.7529922f, 0.0491593f, 0.8595024f, -1.1016617f, -0.1608696f, -0.1200257f, -1.9610568f, 2.6189950f, 1.8707203f, 0.5241567f, -0.2288505f, 0.1528303f, -127.8296814f, -166.2449646f, -1.0174949f, -0.6682033f, -0.9169813f, 1.5756097f, 1.4574182f, -0.1463246f, -1.8262713f, 0.7517605f, -0.1181977f },\n",
0916             "{ -0.0324178f, -0.0418596f, -0.1287051f, -0.0232098f, -0.0512466f, -0.0905093f, -0.1104402f, -0.0095842f, 0.1413968f, -0.0081470f, -0.0251773f, 0.0667293f, 0.0344667f, 0.0116366f, -0.0908088f, -0.0980062f, 0.1874590f, -0.0381802f, 0.0684232f, 0.0252469f, -0.0681347f, 0.1034415f, 0.0576827f, -0.0557779f, 0.0868192f, -0.0851723f, -0.0868760f, 0.1192429f, 0.1751331f, -0.0323825f, -0.1238438f, -0.0623215f },\n",
0917             "{ 0.1757070f, -0.1212057f, -0.0878934f, 0.0737142f, 0.0712249f, -0.0818311f, -0.0719173f, -0.0561241f, 0.0630706f, -0.1523757f, -0.0048847f, 0.1597463f, -0.0302248f, -0.0096164f, -0.0259278f, -0.0815664f, -0.1283869f, 0.1644790f, -0.1612884f, 0.1505984f, -0.1614616f, -0.0756450f, -0.1680063f, -0.0716024f, -0.1266488f, 0.1165592f, -0.0066066f, 0.0661669f, 0.0148620f, 0.0464089f, 0.1496351f, -0.1720888f },\n",
0918             "};\n",
0919             "\n",
0920             "ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_output_layer[3] = {\n",
0921             "-0.3838706f, -0.0366794f, 0.5841699f };\n",
0922             "\n",
0923             "ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_output_layer[32][3] = {\n",
0924             "{ 0.6237589f, 0.2710748f, 0.5615537f },\n",
0925             "{ -0.1665458f, 0.3942705f, 0.2601272f },\n",
0926             "{ 0.3388835f, 0.1579971f, 0.0178280f },\n",
0927             "{ 0.5823844f, -0.0299621f, 0.1178701f },\n",
0928             "{ 0.5561634f, 0.1805784f, 0.6629463f },\n",
0929             "{ 0.1693098f, -0.8297758f, 0.1556239f },\n",
0930             "{ 0.0062806f, 0.2958559f, 0.2698825f },\n",
0931             "{ -0.3925241f, 0.1489681f, -0.0803940f },\n",
0932             "{ 0.5710047f, 0.1924859f, 0.2375189f },\n",
0933             "{ -0.0372825f, 0.0286687f, 0.2910011f },\n",
0934             "{ -0.0867018f, -0.1508995f, -0.0193411f },\n",
0935             "{ 0.4878173f, -0.9407690f, 0.3869846f },\n",
0936             "{ 0.9613981f, 0.3148000f, 0.2196945f },\n",
0937             "{ 0.5831478f, 1.2141191f, 0.7358299f },\n",
0938             "{ -0.0073579f, -0.0419888f, 0.0338354f },\n",
0939             "{ 0.2477632f, 0.9092489f, 0.7818094f },\n",
0940             "{ 0.3554717f, -0.4452990f, 0.0102171f },\n",
0941             "{ 0.3888267f, 0.7089493f, 0.3766315f },\n",
0942             "{ 0.8450955f, -0.0079020f, 0.5853269f },\n",
0943             "{ 0.0646952f, 0.0271975f, 0.0329916f },\n",
0944             "{ 0.5528679f, 0.0075829f, 0.2414524f },\n",
0945             "{ -1.3869698f, -1.1617719f, -1.1356672f },\n",
0946             "{ 0.0214099f, 0.3563140f, 0.5346315f },\n",
0947             "{ 0.3791857f, -0.2714695f, -0.0823861f },\n",
0948             "{ -0.3221727f, 0.5334318f, 0.1581419f },\n",
0949             "{ 0.6678535f, 0.6672282f, 0.4110478f },\n",
0950             "{ 0.1442596f, 0.0245941f, -0.1659890f },\n",
0951             "{ -0.9674007f, 1.4712439f, -0.8418093f },\n",
0952             "{ 0.5696401f, 0.2636259f, 0.2079044f },\n",
0953             "{ 0.0382360f, 0.2687068f, 0.4462553f },\n",
0954             "{ -0.0957586f, 0.4259349f, 0.3613387f },\n",
0955             "{ -0.0633585f, 0.4451550f, 0.2848748f },\n",
0956             "};\n",
0957             "\n"
0958           ]
0959         }
0960       ],
0961       "source": [
0962         "def print_formatted_weights_biases(weights, biases, layer_name):\n",
0963         "    \"\"\"Print a layer's biases and transposed weights as ALPAKA C++ constant float arrays.\"\"\"\n",
0964         "    print(f\"ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_{layer_name}[{len(biases)}] = {{\")\n",
0965         "    print(\", \".join(f\"{b:.7f}f\" for b in biases) + \" };\")\n",
0966         "    print()\n",
0967         "    # The table is declared [len(weights[0])][len(weights)] and filled with the rows of weights.T\n",
0968         "    n_rows, n_cols = len(weights[0]), len(weights)\n",
0969         "    print(f\"ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_{layer_name}[{n_rows}][{n_cols}] = {{\")\n",
0970         "    for transposed_row in weights.T:\n",
0971         "        cells = \", \".join(f\"{w:.7f}f\" for w in transposed_row)\n",
0972         "        print(f\"{{ {cells} }},\")\n",
0973         "    print(\"};\")\n",
0974         "    print()\n",
0975         "\n",
0976         "def print_model_weights_biases(model):\n",
0977         "    \"\"\"Print every nn.Linear layer of `model` as ALPAKA C++ arrays; puts the model in eval mode.\"\"\"\n",
0978         "    model.eval()\n",
0979         "\n",
0980         "    # Only nn.Linear modules are exported; every other module yielded by named_modules() is skipped\n",
0981         "    for name, module in model.named_modules():\n",
0982         "        if not isinstance(module, nn.Linear):\n",
0983         "            continue\n",
0984         "        # detach() reads the parameter values without autograd tracking before the copy to NumPy\n",
0985         "        weights = module.weight.detach().cpu().numpy()\n",
0986         "        # A layer built with bias=False has module.bias set to None; export zeros so the C++ arrays stay complete\n",
0987         "        biases = np.zeros(len(weights), dtype=weights.dtype) if module.bias is None else module.bias.detach().cpu().numpy()\n",
0988         "        # Dots in the module name are replaced because the name becomes part of a C++ identifier\n",
0989         "        print_formatted_weights_biases(weights, biases, name.replace('.', '_'))\n",
0990         "\n",
0991         "print_model_weights_biases(model)\n"
0992       ]
0993     },
0994     {
0995       "cell_type": "code",
0996       "execution_count": 9,
0997       "metadata": {},
0998       "outputs": [],
0999       "source": [
1000         "# Move the inputs to `device` before running the model\n",
1001         "input_features_tensor = input_features_tensor.to(device)\n",
1002         "\n",
1003         "# Inference only: eval mode, and no_grad() so no autograd graph is recorded\n",
1004         "model.eval()\n",
1005         "with torch.no_grad():\n",
1006         "    outputs = model(input_features_tensor)\n",
1007         "    predictions = outputs.squeeze().cpu().numpy()\n",
1008         "\n",
1009         "full_tracks = (np.concatenate(branches['t3_pMatched']) > 0.95)\n",
1010         "# pt = radius * (2.99792458e-3 * 3.8); presumably a speed-of-light factor times a 3.8 T field (TODO confirm units)\n",
1011         "t3_pt = np.concatenate(branches['t3_radius']) * (2.99792458e-3 * 3.8)"
1012       ]
1013     },
1014     {
1015       "cell_type": "code",
1016       "execution_count": 10,
1017       "metadata": {},
1018       "outputs": [
1019         {
1020           "name": "stdout",
1021           "output_type": "stream",
1022           "text": [
1023             "Eta bin 0.00-0.25: 9409714 fakes, 313231 true Prompt\n",
1024             "Eta bin 0.25-0.50: 9242595 fakes, 323051 true Prompt\n",
1025             "Eta bin 0.50-0.75: 7849380 fakes, 410185 true Prompt\n",
1026             "Eta bin 0.75-1.00: 4293980 fakes, 322065 true Prompt\n",
1027             "Eta bin 1.00-1.25: 4343023 fakes, 374215 true Prompt\n",
1028             "Eta bin 1.25-1.50: 2725728 fakes, 351420 true Prompt\n",
1029             "Eta bin 1.50-1.75: 1368266 fakes, 425819 true Prompt\n",
1030             "Eta bin 1.75-2.00: 1413754 fakes, 467604 true Prompt\n",
1031             "Eta bin 2.00-2.25: 448439 fakes, 419450 true Prompt\n",
1032             "Eta bin 2.25-2.50: 124212 fakes, 247704 true Prompt\n"
1033           ]
1034         },
1035         {
1036           "data": {
1037             "text/plain": [
1038               "<Figure size 2000x800 with 3 Axes>"
1039             ]
1040           },
1041           "metadata": {},
1042           "output_type": "display_data"
1043         },
1044         {
1045           "name": "stdout",
1046           "output_type": "stream",
1047           "text": [
1048             "\n",
1049             "Prompt tracks, pt: 0.0 to 5.0 GeV\n",
1050             "Number of true prompt tracks: 3654744\n",
1051             "Number of fake tracks in pt bin: 41219091\n",
1052             "\n",
1053             "80% Retention Cut Values: {0.6326, 0.6415, 0.6506, 0.6526, 0.5321, 0.5651, 0.5747, 0.5765, 0.6243, 0.6320} Mean: 0.6082\n",
1054             "80% Cut Fake Rejections: {99.1, 99.1, 99.1, 99.2, 97.4, 97.1, 96.8, 96.8, 95.6, 92.4} Mean: 97.3%\n",
1055             "\n",
1056             "90% Retention Cut Values: {0.4957, 0.5052, 0.5201, 0.5340, 0.4275, 0.4708, 0.4890, 0.4932, 0.5400, 0.5449} Mean: 0.502\n",
1057             "90% Cut Fake Rejections: {98.5, 98.6, 98.6, 98.7, 95.5, 95.4, 95.0, 95.0, 93.1, 88.9} Mean: 95.7%\n",
1058             "\n",
1059             "93% Retention Cut Values: {0.3874, 0.3978, 0.4201, 0.4502, 0.3750, 0.4117, 0.4394, 0.4466, 0.4916, 0.4899} Mean: 0.431\n",
1060             "93% Cut Fake Rejections: {98.1, 98.2, 98.2, 98.3, 94.5, 94.3, 94.0, 94.1, 91.8, 86.8} Mean: 94.8%\n",
1061             "\n",
1062             "96% Retention Cut Values: {0.1332, 0.1616, 0.2094, 0.2898, 0.2909, 0.3109, 0.3561, 0.3626, 0.4016, 0.3775} Mean: 0.2894\n",
1063             "96% Cut Fake Rejections: {96.3, 96.7, 96.9, 97.1, 92.5, 92.4, 92.2, 92.4, 89.6, 82.9} Mean: 92.9%\n",
1064             "\n",
1065             "97% Retention Cut Values: {0.0894, 0.1137, 0.1595, 0.2333, 0.2507, 0.2569, 0.3081, 0.3120, 0.3466, 0.2932} Mean: 0.2363\n",
1066             "97% Cut Fake Rejections: {94.6, 95.4, 95.8, 95.9, 91.4, 91.1, 91.2, 91.4, 88.3, 79.7} Mean: 91.5%\n",
1067             "\n",
1068             "98% Retention Cut Values: {0.0595, 0.0742, 0.1129, 0.1736, 0.1919, 0.1824, 0.2337, 0.2285, 0.2529, 0.1681} Mean: 0.1678\n",
1069             "98% Cut Fake Rejections: {91.9, 93.1, 93.9, 94.0, 89.5, 89.0, 89.3, 89.5, 85.8, 73.7} Mean: 89.0%\n",
1070             "\n",
1071             "99% Retention Cut Values: {0.0319, 0.0371, 0.0493, 0.0943, 0.0923, 0.0767, 0.1061, 0.1041, 0.1124, 0.0762} Mean: 0.078\n",
1072             "99% Cut Fake Rejections: {86.5, 88.4, 88.8, 89.8, 84.8, 83.9, 84.3, 84.5, 79.6, 62.8} Mean: 83.3%\n",
1073             "\n",
1074             "99.5% Retention Cut Values: {0.0149, 0.0168, 0.0242, 0.0402, 0.0389, 0.0365, 0.0527, 0.0546, 0.0615, 0.0366} Mean: 0.0377\n",
1075             "99.5% Cut Fake Rejections: {78.7, 81.5, 83.3, 84.1, 79.2, 78.5, 78.6, 79.3, 73.6, 49.7} Mean: 76.7%\n",
1076             "Eta bin 0.00-0.25: 9409714 fakes, 43220 true Displaced\n",
1077             "Eta bin 0.25-0.50: 9242595 fakes, 47035 true Displaced\n",
1078             "Eta bin 0.50-0.75: 7849380 fakes, 62690 true Displaced\n",
1079             "Eta bin 0.75-1.00: 4293980 fakes, 52590 true Displaced\n",
1080             "Eta bin 1.00-1.25: 4343023 fakes, 62242 true Displaced\n",
1081             "Eta bin 1.25-1.50: 2725728 fakes, 59777 true Displaced\n",
1082             "Eta bin 1.50-1.75: 1368266 fakes, 76741 true Displaced\n",
1083             "Eta bin 1.75-2.00: 1413754 fakes, 90436 true Displaced\n",
1084             "Eta bin 2.00-2.25: 448439 fakes, 73564 true Displaced\n",
1085             "Eta bin 2.25-2.50: 124212 fakes, 43525 true Displaced\n"
1086           ]
1087         },
1088         {
1089           "data": {
1090             "text/plain": [
1091               "<Figure size 2000x800 with 3 Axes>"
1092             ]
1093           },
1094           "metadata": {},
1095           "output_type": "display_data"
1096         },
1097         {
1098           "name": "stdout",
1099           "output_type": "stream",
1100           "text": [
1101             "\n",
1102             "Displaced tracks, pt: 0.0 to 5.0 GeV\n",
1103             "Number of true displaced tracks: 611820\n",
1104             "Number of fake tracks in pt bin: 41219091\n",
1105             "\n",
1106             "80% Retention Cut Values: {0.2362, 0.2377, 0.2496, 0.2579, 0.2940, 0.3057, 0.3273, 0.3224, 0.2944, 0.2890} Mean: 0.2814\n",
1107             "80% Cut Fake Rejections: {83.1, 81.5, 78.7, 76.4, 75.6, 68.3, 54.0, 49.9, 34.7, 18.3} Mean: 62.1%\n",
1108             "\n",
1109             "90% Retention Cut Values: {0.1809, 0.1855, 0.1998, 0.2098, 0.2173, 0.2267, 0.2391, 0.2507, 0.2356, 0.2317} Mean: 0.2177\n",
1110             "90% Cut Fake Rejections: {79.7, 78.0, 75.4, 73.2, 70.1, 61.5, 46.5, 43.7, 29.3, 12.4} Mean: 57.0%\n",
1111             "\n",
1112             "93% Retention Cut Values: {0.1527, 0.1622, 0.1814, 0.1933, 0.1952, 0.1889, 0.2041, 0.2280, 0.2216, 0.2166} Mean: 0.1944\n",
1113             "93% Cut Fake Rejections: {77.6, 76.1, 74.0, 72.1, 68.4, 58.2, 43.5, 41.8, 28.0, 10.8} Mean: 55.0%\n",
1114             "\n",
1115             "96% Retention Cut Values: {0.1133, 0.1274, 0.1514, 0.1662, 0.1657, 0.1576, 0.1601, 0.2030, 0.2068, 0.2028} Mean: 0.1654\n",
1116             "96% Cut Fake Rejections: {73.9, 72.8, 71.5, 70.0, 65.9, 55.1, 39.2, 39.6, 26.6, 9.6} Mean: 52.4%\n",
1117             "\n",
1118             "97% Retention Cut Values: {0.0956, 0.1131, 0.1362, 0.1509, 0.1549, 0.1459, 0.1479, 0.1924, 0.2012, 0.1980} Mean: 0.1536\n",
1119             "97% Cut Fake Rejections: {71.7, 71.2, 70.0, 68.7, 65.0, 53.9, 37.8, 38.7, 26.1, 9.1} Mean: 51.2%\n",
1120             "\n",
1121             "98% Retention Cut Values: {0.0686, 0.0874, 0.1165, 0.1298, 0.1400, 0.1311, 0.1327, 0.1773, 0.1954, 0.1922} Mean: 0.1371\n",
1122             "98% Cut Fake Rejections: {67.4, 67.7, 68.0, 66.8, 63.6, 52.2, 36.0, 37.4, 25.6, 8.6} Mean: 49.3%\n",
1123             "\n",
1124             "99% Retention Cut Values: {0.0334, 0.0504, 0.0748, 0.0994, 0.1128, 0.1123, 0.1118, 0.1525, 0.1867, 0.1847} Mean: 0.1119\n",
1125             "99% Cut Fake Rejections: {57.9, 60.3, 62.3, 63.5, 60.8, 49.8, 33.2, 35.1, 24.8, 7.9} Mean: 45.6%\n",
1126             "\n",
1127             "99.5% Retention Cut Values: {0.0155, 0.0237, 0.0414, 0.0560, 0.0808, 0.0901, 0.0960, 0.1335, 0.1789, 0.1791} Mean: 0.0895\n",
1128             "99.5% Cut Fake Rejections: {47.6, 50.3, 54.7, 56.5, 56.7, 46.6, 30.8, 33.2, 24.1, 7.5} Mean: 40.8%\n",
1129             "Eta bin 0.00-0.25: 2012526 fakes, 6249 true Prompt\n",
1130             "Eta bin 0.25-0.50: 1972121 fakes, 6496 true Prompt\n",
1131             "Eta bin 0.50-0.75: 1704510 fakes, 6894 true Prompt\n",
1132             "Eta bin 0.75-1.00: 930629 fakes, 5318 true Prompt\n",
1133             "Eta bin 1.00-1.25: 861320 fakes, 9397 true Prompt\n",
1134             "Eta bin 1.25-1.50: 523329 fakes, 14695 true Prompt\n",
1135             "Eta bin 1.50-1.75: 246635 fakes, 24265 true Prompt\n",
1136             "Eta bin 1.75-2.00: 250585 fakes, 15787 true Prompt\n",
1137             "Eta bin 2.00-2.25: 86204 fakes, 6652 true Prompt\n",
1138             "Eta bin 2.25-2.50: 22080 fakes, 3385 true Prompt\n"
1139           ]
1140         },
1141         {
1142           "data": {
1143             "text/plain": [
1144               "<Figure size 2000x800 with 3 Axes>"
1145             ]
1146           },
1147           "metadata": {},
1148           "output_type": "display_data"
1149         },
1150         {
1151           "name": "stdout",
1152           "output_type": "stream",
1153           "text": [
1154             "\n",
1155             "Prompt tracks, pt: 5.0 to inf GeV\n",
1156             "Number of true prompt tracks: 99138\n",
1157             "Number of fake tracks in pt bin: 8609939\n",
1158             "\n",
1159             "80% Retention Cut Values: {0.1353, 0.1569, 0.2083, 0.2407, 0.2910, 0.3552, 0.4221, 0.4197, 0.3174, 0.2681} Mean: 0.2815\n",
1160             "80% Cut Fake Rejections: {98.8, 99.0, 99.3, 99.4, 96.1, 94.8, 93.5, 93.7, 88.5, 74.4} Mean: 93.7%\n",
1161             "\n",
1162             "90% Retention Cut Values: {0.0302, 0.0415, 0.0994, 0.1791, 0.1960, 0.2467, 0.3227, 0.3242, 0.2367, 0.2187} Mean: 0.1895\n",
1163             "90% Cut Fake Rejections: {90.2, 94.5, 98.1, 99.0, 93.4, 91.4, 89.8, 89.7, 80.1, 63.8} Mean: 89.0%\n",
1164             "\n",
1165             "93% Retention Cut Values: {0.0188, 0.0261, 0.0502, 0.1487, 0.1606, 0.2022, 0.2772, 0.2797, 0.2182, 0.2063} Mean: 0.1588\n",
1166             "93% Cut Fake Rejections: {82.0, 89.1, 94.9, 98.7, 91.8, 89.5, 87.9, 87.3, 77.5, 60.9} Mean: 85.9%\n",
1167             "\n",
1168             "96% Retention Cut Values: {0.0095, 0.0138, 0.0243, 0.0620, 0.1096, 0.1450, 0.2184, 0.2276, 0.1949, 0.1769} Mean: 0.1182\n",
1169             "96% Cut Fake Rejections: {67.7, 78.4, 86.6, 95.6, 88.5, 86.0, 84.6, 83.8, 74.2, 54.1} Mean: 79.9%\n",
1170             "\n",
1171             "97% Retention Cut Values: {0.0079, 0.0105, 0.0185, 0.0361, 0.0883, 0.1242, 0.1943, 0.2106, 0.1600, 0.1354} Mean: 0.0986\n",
1172             "97% Cut Fake Rejections: {63.7, 73.1, 82.1, 91.3, 86.5, 84.3, 83.0, 82.5, 69.0, 42.8} Mean: 75.8%\n",
1173             "\n",
1174             "98% Retention Cut Values: {0.0057, 0.0077, 0.0119, 0.0212, 0.0684, 0.0902, 0.1624, 0.1814, 0.1158, 0.1037} Mean: 0.0769\n",
1175             "98% Cut Fake Rejections: {56.9, 66.9, 74.3, 84.8, 84.0, 80.5, 80.6, 79.9, 61.0, 32.8} Mean: 70.2%\n",
1176             "\n",
1177             "99% Retention Cut Values: {0.0027, 0.0045, 0.0074, 0.0119, 0.0325, 0.0393, 0.1192, 0.1170, 0.0683, 0.0819} Mean: 0.0485\n",
1178             "99% Cut Fake Rejections: {43.1, 56.7, 65.6, 76.3, 75.4, 69.5, 76.2, 72.6, 49.8, 25.6} Mean: 61.1%\n",
1179             "\n",
1180             "99.5% Retention Cut Values: {0.0015, 0.0027, 0.0038, 0.0054, 0.0153, 0.0219, 0.0820, 0.0657, 0.0373, 0.0669} Mean: 0.0302\n",
1181             "99.5% Cut Fake Rejections: {34.3, 48.3, 54.4, 63.1, 64.6, 60.9, 70.6, 63.0, 39.3, 20.4} Mean: 51.9%\n",
1182             "Eta bin 0.00-0.25: 2012526 fakes, 2764 true Displaced\n",
1183             "Eta bin 0.25-0.50: 1972121 fakes, 2581 true Displaced\n",
1184             "Eta bin 0.50-0.75: 1704510 fakes, 2477 true Displaced\n",
1185             "Eta bin 0.75-1.00: 930629 fakes, 2122 true Displaced\n",
1186             "Eta bin 1.00-1.25: 861320 fakes, 2780 true Displaced\n",
1187             "Eta bin 1.25-1.50: 523329 fakes, 3481 true Displaced\n",
1188             "Eta bin 1.50-1.75: 246635 fakes, 4701 true Displaced\n",
1189             "Eta bin 1.75-2.00: 250585 fakes, 3009 true Displaced\n",
1190             "Eta bin 2.00-2.25: 86204 fakes, 1579 true Displaced\n",
1191             "Eta bin 2.25-2.50: 22080 fakes, 881 true Displaced\n"
1192           ]
1193         },
1194         {
1195           "data": {
1196             "text/plain": [
1197               "<Figure size 2000x800 with 3 Axes>"
1198             ]
1199           },
1200           "metadata": {},
1201           "output_type": "display_data"
1202         },
1203         {
1204           "name": "stdout",
1205           "output_type": "stream",
1206           "text": [
1207             "\n",
1208             "Displaced tracks, pt: 5.0 to inf GeV\n",
1209             "Number of true displaced tracks: 26375\n",
1210             "Number of fake tracks in pt bin: 8609939\n",
1211             "\n",
1212             "80% Retention Cut Values: {0.5019, 0.4954, 0.5197, 0.5256, 0.4161, 0.3509, 0.3778, 0.3667, 0.4388, 0.4915} Mean: 0.4484\n",
1213             "80% Cut Fake Rejections: {98.6, 98.6, 98.7, 98.7, 97.0, 93.3, 87.6, 86.9, 89.0, 88.0} Mean: 93.6%\n",
1214             "\n",
1215             "90% Retention Cut Values: {0.3265, 0.3967, 0.4493, 0.4656, 0.3459, 0.2910, 0.3486, 0.3356, 0.3693, 0.4448} Mean: 0.3773\n",
1216             "90% Cut Fake Rejections: {97.6, 98.1, 98.2, 98.3, 95.8, 90.8, 85.4, 84.0, 83.3, 82.4} Mean: 91.4%\n",
1217             "\n",
1218             "93% Retention Cut Values: {0.1804, 0.2502, 0.4033, 0.4375, 0.2958, 0.2556, 0.3353, 0.3242, 0.3405, 0.4194} Mean: 0.3242\n",
1219             "93% Cut Fake Rejections: {96.1, 97.0, 97.9, 98.2, 94.6, 89.1, 84.5, 83.1, 80.3, 78.8} Mean: 90.0%\n",
1220             "\n",
1221             "96% Retention Cut Values: {0.0464, 0.0366, 0.2521, 0.3317, 0.2324, 0.2021, 0.3165, 0.2992, 0.3096, 0.3568} Mean: 0.2383\n",
1222             "96% Cut Fake Rejections: {90.2, 88.7, 96.8, 97.5, 92.9, 86.3, 83.3, 81.2, 77.1, 69.3} Mean: 86.3%\n",
1223             "\n",
1224             "97% Retention Cut Values: {0.0270, 0.0225, 0.1590, 0.1765, 0.2111, 0.1570, 0.3039, 0.2869, 0.2999, 0.2103} Mean: 0.1854\n",
1225             "97% Cut Fake Rejections: {85.2, 83.8, 95.5, 96.0, 92.1, 83.3, 82.6, 80.4, 76.2, 43.0} Mean: 81.8%\n",
1226             "\n",
1227             "98% Retention Cut Values: {0.0168, 0.0160, 0.0687, 0.0392, 0.1494, 0.1121, 0.2684, 0.2642, 0.2935, 0.1370} Mean: 0.1365\n",
1228             "98% Cut Fake Rejections: {79.1, 79.4, 92.1, 88.6, 89.4, 79.3, 80.6, 78.7, 75.5, 26.5} Mean: 76.9%\n",
1229             "\n",
1230             "99% Retention Cut Values: {0.0091, 0.0075, 0.0350, 0.0213, 0.0435, 0.0676, 0.1957, 0.1649, 0.1080, 0.1046} Mean: 0.0757\n",
1231             "99% Cut Fake Rejections: {69.6, 67.9, 87.1, 82.3, 77.4, 73.2, 76.2, 70.3, 50.1, 19.1} Mean: 67.3%\n",
1232             "\n",
1233             "99.5% Retention Cut Values: {0.0059, 0.0021, 0.0151, 0.0116, 0.0222, 0.0309, 0.1230, 0.1205, 0.0779, 0.0930} Mean: 0.0502\n",
1234             "99.5% Cut Fake Rejections: {62.0, 47.0, 76.9, 74.2, 69.0, 62.6, 69.7, 64.8, 43.9, 16.5} Mean: 58.7%\n"
1235           ]
1236         }
1237       ],
1238       "source": [
1239         "import numpy as np\n",
1240         "from matplotlib import pyplot as plt\n",
1241         "from matplotlib.colors import LogNorm\n",
1242         "import torch\n",
1243         "\n",
1244         "# Move the inputs to `device` before running the model\n",
1245         "input_features_tensor = input_features_tensor.to(device)\n",
1246         "\n",
1247         "# Inference only: eval mode, and no_grad() so no autograd graph is recorded\n",
1248         "model.eval()\n",
1249         "with torch.no_grad():\n",
1250         "    outputs = model(input_features_tensor)\n",
1251         "    predictions = outputs.cpu().numpy()  # Shape will be [n_samples, 3]\n",
1252         "\n",
1253         "# pt = radius * (2.99792458e-3 * 3.8); presumably a speed-of-light factor times a 3.8 T field (TODO confirm units)\n",
1254         "t3_pt = np.concatenate(branches['t3_radius']) * (2.99792458e-3 * 3.8)\n",
1255         "\n",
1256         "def plot_for_pt_bin(pt_min, pt_max, percentiles, eta_bin_edges, t3_pt, predictions, t3_sim_vxy, eta_list):\n",
1257         "    \"\"\"\n",
1258         "    Plot DNN score vs |eta| and derive per-eta-bin score cuts for one pt bin.\n",
1259         "\n",
1260         "    For T3s with pt_min < pt <= pt_max, prompt (score column 1) and displaced\n",
1261         "    (score column 2) real tracks are treated separately. In each |eta| bin the\n",
1262         "    cut keeping `percentile`% of the true tracks is computed, together with the\n",
1263         "    percentage of fakes scoring below it. Plots are shown and a summary printed.\n",
1264         "\n",
1265         "    NOTE: reads the notebook-level `branches` dict for 't3_pMatched', and only\n",
1266         "    `eta_list[0]` is used.\n",
1267         "    \"\"\"\n",
1268         "    def _quiet_nanmean(values):\n",
1269         "        # np.nanmean warns on empty / all-NaN input; return NaN quietly instead\n",
1270         "        arr = np.asarray(values, dtype=float)\n",
1271         "        return np.nanmean(arr) if np.any(~np.isnan(arr)) else np.nan\n",
1272         "\n",
1273         "    # Select everything that falls in the pt bin\n",
1274         "    pt_mask = (t3_pt > pt_min) & (t3_pt <= pt_max)\n",
1275         "    abs_eta = np.abs(eta_list[0][pt_mask])\n",
1276         "    pred_filtered = predictions[pt_mask]\n",
1277         "\n",
1278         "    # Truth categories; the matching fraction is concatenated once and reused\n",
1279         "    p_matched = np.concatenate(branches['t3_pMatched'])[pt_mask]\n",
1280         "    matched = p_matched > 0.95\n",
1281         "    fake_tracks = p_matched < 0.75\n",
1282         "    is_displaced = t3_sim_vxy[pt_mask] > 0.1\n",
1283         "    true_displaced = is_displaced & matched\n",
1284         "    true_prompt = ~is_displaced & matched\n",
1285         "\n",
1286         "    # One figure (score map + fake rejection) per track type\n",
1287         "    for track_type, true_mask, pred_idx, title_suffix in [\n",
1288         "        (\"Prompt\", true_prompt, 1, \"Prompt Real Tracks\"),\n",
1289         "        (\"Displaced\", true_displaced, 2, \"Displaced Real Tracks\")\n",
1290         "    ]:\n",
1291         "        # Per-percentile lists, one entry per eta bin (NaN where undefined)\n",
1292         "        cut_values = {p: [] for p in percentiles}\n",
1293         "        fake_rejections = {p: [] for p in percentiles}\n",
1294         "\n",
1295         "        # DNN score of this class for every T3 in the pt bin\n",
1296         "        probs = pred_filtered[:, pred_idx]\n",
1297         "\n",
1298         "        # Left panel: score map with cut curves; right panel: fake rejection\n",
1299         "        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))\n",
1300         "\n",
1301         "        # 2D histogram filled with the true tracks of this type only\n",
1302         "        h = ax1.hist2d(abs_eta[true_mask],\n",
1303         "                       probs[true_mask],\n",
1304         "                       bins=[eta_bin_edges, 50],\n",
1305         "                       norm=LogNorm())\n",
1306         "        plt.colorbar(h[3], ax=ax1, label='Counts')\n",
1307         "\n",
1308         "        # Derive the cuts in each |eta| bin [eta_min, eta_max)\n",
1309         "        edges = np.asarray(eta_bin_edges)\n",
1310         "        bin_centers = (edges[:-1] + edges[1:]) / 2\n",
1311         "        for eta_min, eta_max in zip(edges[:-1], edges[1:]):\n",
1312         "            eta_mask = (abs_eta >= eta_min) & (abs_eta < eta_max)\n",
1313         "            true_type_mask = eta_mask & true_mask  # true tracks of this type in the bin\n",
1314         "            fake_mask = eta_mask & fake_tracks  # fake tracks in the same bin\n",
1315         "            n_true, n_fake = np.sum(true_type_mask), np.sum(fake_mask)\n",
1316         "\n",
1317         "            print(f\"Eta bin {eta_min:.2f}-{eta_max:.2f}: {n_fake} fakes, {n_true} true {track_type}\")\n",
1318         "\n",
1319         "            for percentile in percentiles:\n",
1320         "                if n_true == 0:\n",
1321         "                    # No true tracks here: neither quantity is defined\n",
1322         "                    cut_values[percentile].append(np.nan)\n",
1323         "                    fake_rejections[percentile].append(np.nan)\n",
1324         "                    continue\n",
1325         "                # Score threshold keeping `percentile`% of the true tracks\n",
1326         "                cut_value = np.percentile(probs[true_type_mask], 100 - percentile)\n",
1327         "                cut_values[percentile].append(cut_value)\n",
1328         "                # Percentage of fakes below the threshold (NaN when the bin has no fakes)\n",
1329         "                fake_rej = 100 * np.mean(probs[fake_mask] < cut_value) if n_fake > 0 else np.nan\n",
1330         "                fake_rejections[percentile].append(fake_rej)\n",
1331         "\n",
1332         "        # Overlay cut curves (left) and fake-rejection curves (right)\n",
1333         "        colors = plt.cm.rainbow(np.linspace(0, 1, len(percentiles)))\n",
1334         "\n",
1335         "        for (percentile, color) in zip(percentiles, colors):\n",
1336         "            values = np.array(cut_values[percentile])\n",
1337         "            mask = ~np.isnan(values)\n",
1338         "            if np.any(mask):\n",
1339         "                # Plot cut values (bins without true tracks are skipped)\n",
1340         "                ax1.plot(bin_centers[mask], values[mask], '-', color=color, marker='o',\n",
1341         "                        label=f'{percentile}% Retention Cut')\n",
1342         "                # Plot fake rejections for the same bins\n",
1343         "                rej_values = np.array(fake_rejections[percentile])\n",
1344         "                ax2.plot(bin_centers[mask], rej_values[mask], '-', color=color, marker='o',\n",
1345         "                        label=f'{percentile}% Cut')\n",
1346         "\n",
1347         "        # Set plot labels and titles\n",
1348         "        ax1.set_xlabel(\"Absolute Eta\")\n",
1349         "        ax1.set_ylabel(f\"DNN {track_type} Probability\")\n",
1350         "        ax1.set_title(f\"DNN Score vs Eta ({title_suffix})\\npt: {pt_min:.1f} to {pt_max:.1f} GeV\")\n",
1351         "        ax1.legend()\n",
1352         "        ax1.grid(True, alpha=0.3)\n",
1353         "\n",
1354         "        ax2.set_xlabel(\"Absolute Eta\")\n",
1355         "        ax2.set_ylabel(\"Fake Rejection (%)\")\n",
1356         "        ax2.set_title(f\"Fake Rejection vs Eta\\npt: {pt_min:.1f} to {pt_max:.1f} GeV\")\n",
1357         "        ax2.legend()\n",
1358         "        ax2.grid(True, alpha=0.3)\n",
1359         "        ax2.set_ylim(0, 100)\n",
1360         "\n",
1361         "        plt.tight_layout()\n",
1362         "        plt.show()\n",
1363         "\n",
1364         "        # Print statistics (means ignore NaN bins)\n",
1365         "        print(f\"\\n{track_type} tracks, pt: {pt_min:.1f} to {pt_max:.1f} GeV\")\n",
1366         "        print(f\"Number of true {track_type.lower()} tracks: {np.sum(true_mask)}\")\n",
1367         "        print(f\"Number of fake tracks in pt bin: {np.sum(fake_tracks)}\")\n",
1368         "\n",
1369         "        for percentile in percentiles:\n",
1370         "            print(f\"\\n{percentile}% Retention Cut Values:\",\n",
1371         "                  '{' + ', '.join(f\"{x:.4f}\" if not np.isnan(x) else 'nan' for x in cut_values[percentile]) + '}',\n",
1372         "                  f\"Mean: {np.round(_quiet_nanmean(cut_values[percentile]), 4)}\")\n",
1373         "            print(f\"{percentile}% Cut Fake Rejections:\",\n",
1374         "                  '{' + ', '.join(f\"{x:.1f}\" if not np.isnan(x) else 'nan' for x in fake_rejections[percentile]) + '}',\n",
1375         "                  f\"Mean: {np.round(_quiet_nanmean(fake_rejections[percentile]), 1)}%\")\n",
1376         "\n",
1377         "def analyze_pt_bins(pt_bins, percentiles, eta_bin_edges, t3_pt, predictions, t3_sim_vxy, eta_list):\n",
1378         "    \"\"\"\n",
1379         "    Call plot_for_pt_bin once for every pair of consecutive edges in pt_bins\n",
1380         "    \"\"\"\n",
1381         "    for pt_low, pt_high in zip(pt_bins[:-1], pt_bins[1:]):\n",
1382         "        plot_for_pt_bin(pt_low, pt_high, percentiles, eta_bin_edges,\n",
1383         "                       t3_pt, predictions, t3_sim_vxy, eta_list)\n",
1384         "\n",
1385         "# Analysis configuration\n",
1386         "eta_bin_edges = np.arange(0, 2.75, 0.25)  # |eta| bin edges: 0 to 2.5 in steps of 0.25\n",
1387         "pt_bins = [0, 5, np.inf]  # pt bin edges; each consecutive pair defines one bin\n",
1388         "percentiles = [80, 90, 93, 96, 97, 98, 99, 99.5]  # true-track retention targets in percent\n",
1389         "# Run over every pt bin; t3_sim_vxy is flattened with np.concatenate before being passed in\n",
1390         "analyze_pt_bins(\n",
1391         "    pt_bins,\n",
1392         "    percentiles,\n",
1393         "    eta_bin_edges,\n",
1394         "    t3_pt,\n",
1395         "    predictions,\n",
1396         "    np.concatenate(branches['t3_sim_vxy']),\n",
1397         "    eta_list\n",
1398         ")"
1399       ]
1400     },
1401     {
1402       "cell_type": "code",
1403       "execution_count": null,
1404       "metadata": {},
1405       "outputs": [],
1406       "source": []
1407     }
1408   ],
1409   "metadata": {
1410     "kernelspec": {
1411       "display_name": "analysisenv",
1412       "language": "python",
1413       "name": "python3"
1414     },
1415     "language_info": {
1416       "codemirror_mode": {
1417         "name": "ipython",
1418         "version": 3
1419       },
1420       "file_extension": ".py",
1421       "mimetype": "text/x-python",
1422       "name": "python",
1423       "nbconvert_exporter": "python",
1424       "pygments_lexer": "ipython3",
1425       "version": "3.11.7"
1426     }
1427   },
1428   "nbformat": 4,
1429   "nbformat_minor": 2
1430 }