import os
import sys
import random
os.environ["KERAS_BACKEND"] = "tensorflow"

import glob
try:
    if "CUDA_VISIBLE_DEVICES" not in os.environ:
        print("importing setGPU")
        import setGPU
except Exception:
    print("Could not import setGPU, please make sure you configure CUDA_VISIBLE_DEVICES manually")

try:
    from comet_ml import Experiment
    comet_enabled = True
except ImportError as e:
    print("could not import comet, online dashboard disabled")
    comet_enabled = False

import pickle
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, accuracy_score
import pandas
import time
import itertools
import io
import tensorflow as tf

from numpy.lib.recfunctions import append_fields

# input PFElement types and target particle classes (absolute PDG ids, 0 = no particle)
elem_labels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
class_labels = [0, 1, 2, 11, 13, 22, 130, 211]

# maximum number of input elements per event
num_max_elems = 5000

# relative weights of the individual loss terms
mult_classification_loss = 1e3
mult_charge_loss = 1.0
mult_energy_loss = 10.0
mult_phi_loss = 10.0
mult_eta_loss = 10.0
mult_total_loss = 1e3

def split_indices_to_bins(cmul, nbins, bin_size):
    # assign each point to the bin with the largest hash code component, then
    # group the point indices so that each output row contains one full bin
    bin_idx = tf.argmax(cmul, axis=-1)
    bins_split = tf.reshape(tf.argsort(bin_idx), (nbins, bin_size))
    return bins_split
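
# Illustrative sketch (hypothetical helper, not called by the pipeline): six
# points hashed into 2 bins of 3; every point index appears in exactly one row.
def _demo_split_indices_to_bins():
    cmul = tf.random.normal((6, 4))
    bins = split_indices_to_bins(cmul, nbins=2, bin_size=3)
    assert bins.shape == (2, 3)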

def pairwise_dist(A, B):
    # squared norms of the rows of A and B
    na = tf.reduce_sum(tf.square(A), -1)
    nb = tf.reduce_sum(tf.square(B), -1)

    # na as a column vector, nb as a row vector
    na = tf.expand_dims(na, -1)
    nb = tf.expand_dims(nb, -2)

    # ||a-b||^2 = ||a||^2 - 2 a.b + ||b||^2, clipped from below for numerical stability
    D = tf.sqrt(tf.maximum(na - 2*tf.matmul(A, B, False, True) + nb, 1e-6))
    return D
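
# Illustrative sketch (hypothetical shapes): distances between two point clouds
# of 4 and 5 points in 3 dimensions, batched over a single event.
def _demo_pairwise_dist():
    A = tf.random.normal((1, 4, 3))
    B = tf.random.normal((1, 5, 3))
    D = pairwise_dist(A, B)
    assert D.shape == (1, 4, 5)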

def sparse_dense_matmult_batch(sp_a, b):
    """
    sp_a: (nbatch, nelem, nelem) sparse distance matrices
    b: (nbatch, nelem, ncol) dense per-element feature matrices
    """
    num_batches = tf.shape(b)[0]

    def map_function(x):
        i, dense_slice = x[0], x[1]
        num_points = tf.shape(b)[1]

        sparse_slice = tf.sparse.reshape(tf.sparse.slice(
            sp_a, [i, 0, 0], [1, num_points, num_points]),
            [num_points, num_points])
        mult_slice = tf.sparse.sparse_dense_matmul(sparse_slice, dense_slice)
        return mult_slice

    elems = (tf.range(0, tf.cast(num_batches, tf.int64), delta=1, dtype=tf.int64), b)
    ret = tf.map_fn(map_function, elems, fn_output_signature=tf.float32, back_prop=True)
    return ret
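
# Illustrative sketch (assumes eager execution): the batched sparse-dense
# product matches the equivalent dense matmul.
def _demo_sparse_dense_matmult_batch():
    dense_a = tf.random.normal((2, 3, 3))
    b = tf.random.normal((2, 3, 4))
    out = sparse_dense_matmult_batch(tf.sparse.from_dense(dense_a), b)
    assert np.allclose(out.numpy(), tf.linalg.matmul(dense_a, b).numpy(), atol=1e-5)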


def summarize_dataset(dataset):
    yclasses = []
    nev = 0.0
    ntot = 0.0
    sizes = []

    for X, y, w in dataset:
        yclasses += [y[:, 0]]
        nev += 1
        ntot += len(y)
        sizes += [len(y)]

    yclasses = np.concatenate(yclasses)
    values, counts = np.unique(yclasses, return_counts=True)
    print("nev={}".format(nev))
    print("sizes={}".format(np.percentile(sizes, [25, 50, 95, 99])))
    for v, c in zip(values, counts):
        print("label={} count={} frac={:.6f}".format(class_labels[int(v)], c, c/ntot))


# class-balanced weighting based on the effective number of samples
beta = 0.9999

def compute_weights_classbalanced(X, y, w):
    wn = (1.0 - beta)/(1.0 - tf.pow(beta, w))
    wn /= tf.reduce_sum(wn)
    return X, y, wn


def compute_weights_uniform(X, y, w):
    wn = tf.ones_like(w)
    wn /= tf.reduce_sum(wn)
    return X, y, wn


def compute_weights_inverse(X, y, w):
    wn = 1.0/tf.sqrt(w)
    wn /= tf.reduce_sum(wn)
    return X, y, wn


weight_schemes = {
    "uniform": compute_weights_uniform,
    "inverse": compute_weights_inverse,
    "classbalanced": compute_weights_classbalanced,
}
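
# Illustrative sketch: each scheme normalizes the per-element weights of one
# event to sum to one, so every event contributes equally to the loss.
def _demo_weight_schemes():
    X = tf.zeros((10, 15))
    y = tf.zeros((10, 5))
    w = tf.fill((10, ), 4.0)
    for scheme in weight_schemes.values():
        _, _, wn = scheme(X, y, w)
        assert abs(float(tf.reduce_sum(wn)) - 1.0) < 1e-5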


def load_one_file(fn):
    Xs = []
    ys = []
    ys_cand = []

    with open(fn, "rb") as fi:
        data = pickle.load(fi, encoding='iso-8859-1')

    for event in data:
        Xelem = event["Xelem"]
        ygen = event["ygen"]
        ycand = event["ycand"]

        # skip input elements of type 2 and 3
        msk_ps = (Xelem["typ"] == 2) | (Xelem["typ"] == 3)

        Xelem = Xelem[~msk_ps]
        ygen = ygen[~msk_ps]
        ycand = ycand[~msk_ps]

        # map the element type and the absolute particle pdgid to contiguous class indices
        Xelem = append_fields(Xelem, "typ_idx", np.array([elem_labels.index(int(i)) for i in Xelem["typ"]], dtype=np.float32))
        ygen = append_fields(ygen, "typ_idx", np.array([class_labels.index(abs(int(i))) for i in ygen["typ"]], dtype=np.float32))
        ycand = append_fields(ycand, "typ_idx", np.array([class_labels.index(abs(int(i))) for i in ycand["typ"]], dtype=np.float32))

        Xelem_flat = np.stack([Xelem[k].view(np.float32).data for k in [
            'typ_idx',
            'pt', 'eta', 'phi', 'e',
            'layer', 'depth', 'charge', 'trajpoint',
            'eta_ecal', 'phi_ecal', 'eta_hcal', 'phi_hcal',
            'muon_dt_hits', 'muon_csc_hits']], axis=-1
        )
        ygen_flat = np.stack([ygen[k].view(np.float32).data for k in [
            'typ_idx',
            'eta', 'phi', 'e', 'charge',
        ]], axis=-1
        )
        ycand_flat = np.stack([ycand[k].view(np.float32).data for k in [
            'typ_idx',
            'eta', 'phi', 'e', 'charge',
        ]], axis=-1
        )

        # zero out NaN and unphysically large values
        Xelem_flat[np.isnan(Xelem_flat)] = 0
        Xelem_flat[np.abs(Xelem_flat) > 1e4] = 0
        ygen_flat[np.isnan(ygen_flat)] = 0
        ygen_flat[np.abs(ygen_flat) > 1e4] = 0
        ycand_flat[np.isnan(ycand_flat)] = 0
        ycand_flat[np.abs(ycand_flat) > 1e4] = 0

        # truncate each event to at most num_max_elems elements
        Xs += [Xelem_flat[:num_max_elems]]
        ys += [ygen_flat[:num_max_elems]]
        ys_cand += [ycand_flat[:num_max_elems]]

    print("created {} blocks, max size {}".format(len(Xs), max([len(X) for X in Xs])))
    return Xs, ys, ys_cand


class InputEncoding(tf.keras.layers.Layer):
    def __init__(self, num_input_classes):
        super(InputEncoding, self).__init__()
        self.num_input_classes = num_input_classes

    def call(self, X):
        """
        X: [Nbatch, Nelem, Nfeat] array of all the input detector element feature data
        """
        # one-hot encode the integer element type in the first feature column
        Xid = tf.cast(tf.one_hot(tf.cast(X[:, :, 0], tf.int32), self.num_input_classes), dtype=tf.float32)

        # pass the remaining continuous features through unchanged
        Xprop = X[:, :, 1:]
        return tf.concat([Xid, Xprop], axis=-1)
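
# Illustrative sketch (hypothetical input): the integer type in column 0
# becomes a one-hot vector over the 12 element classes, while the 14 remaining
# features pass through unchanged.
def _demo_input_encoding():
    enc = InputEncoding(len(elem_labels))
    X = tf.concat([tf.ones((1, 7, 1)), tf.random.normal((1, 7, 14))], axis=-1)
    out = enc(X)
    assert out.shape == (1, 7, len(elem_labels) + 14)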


class GHConv(tf.keras.layers.Layer):
    def __init__(self, *args, **kwargs):
        self.activation = kwargs.pop("activation")
        self.hidden_dim = args[0]

        super(GHConv, self).__init__(*args, **kwargs)

    def build(self, input_shape):
        self.W_t = self.add_weight(shape=(self.hidden_dim, self.hidden_dim), name="w_t", initializer="random_normal")
        self.b_t = self.add_weight(shape=(self.hidden_dim, ), name="b_t", initializer="random_normal")
        self.W_h = self.add_weight(shape=(self.hidden_dim, self.hidden_dim), name="w_h", initializer="random_normal")
        self.theta = self.add_weight(shape=(self.hidden_dim, self.hidden_dim), name="theta", initializer="random_normal")

    def call(self, inputs):
        x, adj = inputs

        # per-node degree, used for symmetric normalization of the adjacency matrix
        in_degrees = tf.sparse.reduce_sum(adj, axis=-1)
        in_degrees = tf.reshape(in_degrees, (tf.shape(x)[0], tf.shape(x)[1]))

        norm = tf.expand_dims(tf.pow(in_degrees + 1e-6, -0.5), -1)

        # neighborhood aggregation (f_hom) and per-node transformation (f_het)
        f_hom = tf.linalg.matmul(x, self.theta)
        f_hom = sparse_dense_matmult_batch(adj, f_hom*norm)*norm

        f_het = tf.linalg.matmul(x, self.W_h)

        # a learned gate mixes the two contributions
        gate = tf.nn.sigmoid(tf.linalg.matmul(x, self.W_t) + self.b_t)

        out = gate*f_hom + (1-gate)*f_het
        return self.activation(out)


class GHConvDense(tf.keras.layers.Layer):
    def __init__(self, *args, **kwargs):
        self.activation = kwargs.pop("activation")
        self.hidden_dim = args[0]
        super(GHConvDense, self).__init__(*args, **kwargs)

    def build(self, input_shape):
        self.W_t = self.add_weight(shape=(self.hidden_dim, self.hidden_dim), name="w_t", initializer="random_normal")
        self.b_t = self.add_weight(shape=(self.hidden_dim, ), name="b_t", initializer="random_normal")
        self.W_h = self.add_weight(shape=(self.hidden_dim, self.hidden_dim), name="w_h", initializer="random_normal")
        self.theta = self.add_weight(shape=(self.hidden_dim, self.hidden_dim), name="theta", initializer="random_normal")

    def call(self, inputs):
        x, adj = inputs

        in_degrees = tf.reduce_sum(adj, axis=-1)
        in_degrees = tf.reshape(in_degrees, (tf.shape(x)[0], tf.shape(x)[1]))

        norm = tf.expand_dims(tf.pow(in_degrees + 1e-6, -0.5), -1)

        f_hom = tf.linalg.matmul(x, self.theta)
        f_hom = tf.linalg.matmul(adj, f_hom*norm)*norm

        f_het = tf.linalg.matmul(x, self.W_h)
        gate = tf.nn.sigmoid(tf.linalg.matmul(x, self.W_t) + self.b_t)

        out = gate*f_hom + (1-gate)*f_het
        return self.activation(out)


class DenseDistance(tf.keras.layers.Layer):
    def __init__(self, dist_mult=0.1, **kwargs):
        super(DenseDistance, self).__init__(**kwargs)
        self.dist_mult = dist_mult

    def call(self, inputs, training=True):
        dm = pairwise_dist(inputs, inputs)
        dm = tf.exp(-self.dist_mult*dm)
        return dm


class SparseHashedNNDistance(tf.keras.layers.Layer):
    def __init__(self, max_num_bins=200, bin_size=500, num_neighbors=5, dist_mult=0.1, cosine_dist=False, **kwargs):
        super(SparseHashedNNDistance, self).__init__(**kwargs)
        self.num_neighbors = num_neighbors
        self.dist_mult = dist_mult

        self.cosine_dist = cosine_dist

        # maximum number of LSH bins that the random projections can support
        self.max_num_bins = max_num_bins

        # number of points per LSH bin
        self.bin_size = bin_size

    def build(self, input_shape):
        # random projections for locality-sensitive hashing of the point embeddings
        self.codebook_random_rotations = self.add_weight(
            shape=(input_shape[-1], self.max_num_bins//2), initializer="random_normal", trainable=False, name="lsh_projections"
        )

    def call(self, inputs, training=True):
        # (n_batches, n_points, n_features) learned embedding of the input elements
        point_embedding = inputs

        n_batches = tf.shape(point_embedding)[0]
        n_points = tf.shape(point_embedding)[1]

        # construct the sparse distance matrix of each event separately and
        # combine the per-event indices and values into one batched SparseTensor
        def func(args):
            ibatch, points_batch = args[0], args[1]
            dm = self.construct_sparse_dm_batch(points_batch)
            inds = tf.concat([tf.expand_dims(tf.cast(ibatch, tf.int64)*tf.ones(tf.shape(dm.indices)[0], dtype=tf.int64), -1), dm.indices], axis=-1)
            vals = dm.values
            return inds, vals

        elems = (tf.range(0, tf.cast(n_batches, tf.int64), delta=1, dtype=tf.int64), point_embedding)
        ret = tf.map_fn(func, elems, fn_output_signature=(tf.int64, tf.float32), parallel_iterations=1)
        shp = tf.shape(ret[0])

        dms = tf.SparseTensor(
            tf.reshape(ret[0], (shp[0]*shp[1], shp[2])),
            tf.reshape(ret[1], (shp[0]*shp[1],)),
            (n_batches, n_points, n_points)
        )

        return tf.sparse.reorder(dms)

    def subpoints_to_sparse_matrix(self, n_points, subindices, subpoints):

        # compute the distance matrix of the points within each LSH bin
        if self.cosine_dist:
            normed = tf.nn.l2_normalize(subpoints, axis=-1)
            dm = tf.linalg.matmul(normed, normed, transpose_b=True)
        else:
            dm = pairwise_dist(subpoints, subpoints)
            dm = tf.exp(-self.dist_mult*dm)

        dmshape = tf.shape(dm)
        nbins = dmshape[0]
        nelems = dmshape[1]

        # keep only the k nearest neighbors of each point within its bin
        top_k = tf.nn.top_k(dm, k=self.num_neighbors)
        top_k_vals = tf.reshape(top_k.values, (nbins*nelems, self.num_neighbors))

        # translate the per-bin neighbor indices back to global point indices
        indices_gathered = tf.vectorized_map(
            lambda i: tf.gather_nd(subindices, top_k.indices[:, :, i:i+1], batch_dims=1),
            tf.range(self.num_neighbors, dtype=tf.int64))

        indices_gathered = tf.transpose(indices_gathered, [1, 2, 0])

        # scatter the neighbor values into a sparse (n_points, n_points) matrix,
        # one neighbor column at a time
        sp_sum = tf.sparse.SparseTensor(indices=tf.zeros((0, 2), dtype=tf.int64), values=tf.zeros(0, tf.float32), dense_shape=(n_points, n_points))
        for i in range(self.num_neighbors):
            dst_ind = indices_gathered[:, :, i]
            dst_ind = tf.reshape(dst_ind, (nbins*nelems, ))
            src_ind = tf.reshape(tf.stack(subindices), (nbins*nelems, ))
            src_dst_inds = tf.cast(tf.transpose(tf.stack([src_ind, dst_ind])), dtype=tf.int64)
            sp_sum = tf.sparse.add(
                sp_sum,
                tf.sparse.reorder(tf.sparse.SparseTensor(src_dst_inds, top_k_vals[:, i], (n_points, n_points)))
            )
        spt_this = tf.sparse.reorder(sp_sum)

        return spt_this

    def construct_sparse_dm_batch(self, points):
        # points: (n_points, n_features) embedded elements of a single event
        n_points = tf.shape(points)[0]
        n_features = tf.shape(points)[1]

        # number of LSH bins needed to cover all the points
        # (n_points is assumed to be divisible by bin_size)
        n_bins = tf.math.floordiv(n_points, self.bin_size)

        # project the points onto the random rotations to obtain the hash codes
        mul = tf.linalg.matmul(points, self.codebook_random_rotations[:, :n_bins//2])

        cmul = tf.concat([mul, -mul], axis=-1)

        # group the point indices into bins of size bin_size
        bins_split = split_indices_to_bins(cmul, n_bins, self.bin_size)

        # (n_bins, bin_size, n_features) points gathered per bin
        parts = tf.gather(points, bins_split)

        # compute the kNN distance matrix within each bin and scatter it back
        # to a sparse matrix over the full point set
        sparse_distance_matrix = self.subpoints_to_sparse_matrix(n_points, bins_split, parts)

        return sparse_distance_matrix
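
# Rough smoke test (assumes eager execution and n_points divisible by bin_size):
# 1000 points in a 32-dim embedding are hashed into 4 bins of 250 points, and
# the layer returns a sparse kNN distance matrix over the full point set.
def _demo_sparse_hashed_nn_distance():
    layer = SparseHashedNNDistance(bin_size=250, num_neighbors=3)
    points = tf.random.normal((1, 1000, 32))
    dm = layer(points)
    assert dm.dense_shape.numpy().tolist() == [1, 1000, 1000]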


class EncoderDecoderGNN(tf.keras.layers.Layer):
    def __init__(self, encoders, decoders, dropout, activation, conv, **kwargs):
        super(EncoderDecoderGNN, self).__init__(**kwargs)
        name = kwargs.get("name")

        self.encoders = encoders
        self.decoders = decoders

        self.encoding_layers = []
        for ilayer, nunits in enumerate(encoders):
            self.encoding_layers.append(
                tf.keras.layers.Dense(nunits, activation=activation, name="encoding_{}_{}".format(name, ilayer)))
            if dropout > 0.0:
                self.encoding_layers.append(tf.keras.layers.Dropout(dropout))

        self.conv = conv

        self.decoding_layers = []
        for ilayer, nunits in enumerate(decoders):
            self.decoding_layers.append(
                tf.keras.layers.Dense(nunits, activation=activation, name="decoding_{}_{}".format(name, ilayer)))
            if dropout > 0.0:
                self.decoding_layers.append(tf.keras.layers.Dropout(dropout))

    def call(self, inputs, distance_matrix, training=True):
        x = inputs

        for layer in self.encoding_layers:
            x = layer(x)

        for convlayer in self.conv:
            x = convlayer([x, distance_matrix])

        for layer in self.decoding_layers:
            x = layer(x)

        return x


class AddSparse(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(AddSparse, self).__init__(**kwargs)

    def call(self, matrices):
        ret = matrices[0]
        for mat in matrices[1:]:
            ret = tf.sparse.add(ret, mat)
        return ret


class PFNet(tf.keras.Model):
    def __init__(self,
        activation=tf.nn.selu,
        hidden_dim_id=256,
        hidden_dim_reg=256,
        distance_dim=256,
        convlayer="ghconv",
        dropout=0.1,
        bin_size=10,
        num_convs_id=1,
        num_convs_reg=1,
        num_hidden_id_enc=1,
        num_hidden_id_dec=1,
        num_hidden_reg_enc=1,
        num_hidden_reg_dec=1,
        num_neighbors=5,
        dist_mult=0.1,
        cosine_dist=False):

        super(PFNet, self).__init__()
        self.activation = activation
        self.num_dists = 1

        encoding_id = []
        decoding_id = []
        encoding_reg = []
        decoding_reg = []

        for ihidden in range(num_hidden_id_enc):
            encoding_id.append(hidden_dim_id)

        for ihidden in range(num_hidden_id_dec):
            decoding_id.append(hidden_dim_id)

        for ihidden in range(num_hidden_reg_enc):
            encoding_reg.append(hidden_dim_reg)

        for ihidden in range(num_hidden_reg_dec):
            decoding_reg.append(hidden_dim_reg)

        self.enc = InputEncoding(len(elem_labels))
        self.layer_embedding = tf.keras.layers.Dense(distance_dim, name="embedding_attention")

        self.embedding_dropout = None
        if dropout > 0.0:
            self.embedding_dropout = tf.keras.layers.Dropout(dropout)

        self.dists = []
        for idist in range(self.num_dists):
            self.dists.append(SparseHashedNNDistance(bin_size=bin_size, num_neighbors=num_neighbors, dist_mult=dist_mult, cosine_dist=cosine_dist))
        self.addsparse = AddSparse()

        convs_id = []
        convs_reg = []

        for iconv in range(num_convs_id):
            convs_id.append(GHConv(26 if len(encoding_id)==0 else hidden_dim_id, activation=activation, name="conv_id{}".format(iconv)))
        for iconv in range(num_convs_reg):
            convs_reg.append(GHConv(35 if len(encoding_reg)==0 else hidden_dim_reg, activation=activation, name="conv_reg{}".format(iconv)))

        self.gnn_id = EncoderDecoderGNN(encoding_id, decoding_id, dropout, activation, convs_id, name="gnn_id")
        self.layer_id = tf.keras.layers.Dense(len(class_labels), activation="linear", name="out_id")
        self.layer_charge = tf.keras.layers.Dense(1, activation="linear", name="out_charge")

        self.gnn_reg = EncoderDecoderGNN(encoding_reg, decoding_reg, dropout, activation, convs_reg, name="gnn_reg")
        self.layer_momentum = tf.keras.layers.Dense(3, activation="linear", name="out_momentum")

    def create_model(self, num_max_elems, training=True):
        inputs = tf.keras.Input(shape=(num_max_elems, 15,))
        return tf.keras.Model(inputs=[inputs], outputs=self.call(inputs, training), name="MLPFNet")

    def call(self, inputs, training=True):
        X = tf.cast(inputs, tf.float32)
        # mask for the real (non-padded) input elements
        msk_input = tf.expand_dims(tf.cast(X[:, :, 0] != 0, tf.float32), -1)

        enc = self.enc(inputs)

        # embed the encoded elements for the distance computation
        embedding_attention = self.layer_embedding(enc)
        if self.embedding_dropout:
            embedding_attention = self.embedding_dropout(embedding_attention, training)

        # construct the sparse graph structure from the embedding
        dms = [dist(embedding_attention, training) for dist in self.dists]
        dm = self.addsparse(dms)

        # classification branch: predict the particle class and charge
        x_id = self.gnn_id(enc, dm, training)
        to_decode = tf.concat([enc, x_id], axis=-1)
        out_id_logits = self.layer_id(to_decode)
        out_charge = self.layer_charge(to_decode)

        # regression branch: predict corrections to the input eta, phi and energy
        x_reg = self.gnn_reg(tf.concat([enc, out_id_logits, out_charge], axis=-1), dm, training)
        to_decode = tf.concat([enc, x_reg], axis=-1)
        pred_corr = self.layer_momentum(to_decode)

        # probability that the element produces any particle (1 - P(class 0))
        probabilistic_mask_good = 1.0 - tf.keras.activations.softmax(out_id_logits)[:, :, 0]

        out_momentum_eta = X[:, :, 2] + pred_corr[:, :, 0]
        out_momentum_phi = X[:, :, 3] + pred_corr[:, :, 1]
        out_momentum_E = X[:, :, 4] + pred_corr[:, :, 2]

        out_momentum = tf.stack([
            out_momentum_eta * probabilistic_mask_good,
            out_momentum_phi * probabilistic_mask_good,
            out_momentum_E * probabilistic_mask_good,
        ], axis=-1)

        ret = tf.concat([out_id_logits, out_momentum, out_charge], axis=-1)*msk_input
        return ret

    def set_trainable_classification(self):
        self.gnn_reg.trainable = False
        self.layer_momentum.trainable = False

    def set_trainable_regression(self):
        for layer in self.layers:
            layer.trainable = False
        self.gnn_reg.trainable = True
        self.layer_momentum.trainable = True


class PFNetDummy(tf.keras.Model):
    def __init__(self, **kwargs):
        super(PFNetDummy, self).__init__()
        self.enc = InputEncoding(len(elem_labels))

        self.layer_hidden0 = tf.keras.layers.Dense(32, activation="elu")
        self.layer_hidden1 = tf.keras.layers.Dense(64, activation="elu")
        self.layer_hidden2 = tf.keras.layers.Dense(128, activation="elu")
        self.layer_hidden3 = tf.keras.layers.Dense(256, activation="elu")

        self.layer_id = tf.keras.layers.Dense(len(class_labels), activation="linear", name="out_id")
        self.layer_charge = tf.keras.layers.Dense(1, activation="linear", name="out_charge")
        self.layer_momentum = tf.keras.layers.Dense(3, activation="linear", name="out_momentum")

    def call(self, inputs, training=True):
        X = tf.cast(inputs, tf.float32)
        msk_input = tf.expand_dims(tf.cast(X[:, :, 0] != 0, tf.float32), -1)
        enc = self.enc(inputs)

        # the hidden layers act on each element independently; flattening the
        # element dimension would break the per-element outputs below
        h = self.layer_hidden0(enc)
        h = self.layer_hidden1(h)
        h = self.layer_hidden2(h)
        h = self.layer_hidden3(h)

        out_id_logits = self.layer_id(h)
        out_charge = self.layer_charge(h)
        pred_corr = self.layer_momentum(h)

        # probability that the element produces any particle (1 - P(class 0))
        probabilistic_mask_good = 1.0 - tf.keras.activations.softmax(out_id_logits)[:, :, 0]

        out_momentum_eta = X[:, :, 2] + pred_corr[:, :, 0]
        out_momentum_phi = X[:, :, 3] + pred_corr[:, :, 1]
        out_momentum_E = X[:, :, 4] + pred_corr[:, :, 2]

        out_momentum = tf.stack([
            out_momentum_eta * probabilistic_mask_good,
            out_momentum_phi * probabilistic_mask_good,
            out_momentum_E * probabilistic_mask_good,
        ], axis=-1)

        ret = tf.concat([out_id_logits, out_momentum, out_charge], axis=-1)*msk_input
        return ret

    def set_trainable_classification(self):
        self.layer_momentum.trainable = False

    def set_trainable_regression(self):
        pass

    def create_model(self, num_max_elems, training=True):
        inputs = tf.keras.Input(shape=(num_max_elems, 15,))
        return tf.keras.Model(inputs=[inputs], outputs=self.call(inputs, training), name="MLPFNet")


def separate_prediction(y_pred):
    N = len(class_labels)
    pred_id_logits = y_pred[:, :, :N]
    pred_momentum = y_pred[:, :, N:N+3]
    pred_charge = y_pred[:, :, N+3:N+4]
    return pred_id_logits, pred_charge, pred_momentum


def separate_truth(y_true):
    true_id = tf.cast(y_true[:, :, :1], tf.int32)
    true_momentum = y_true[:, :, 1:4]
    true_charge = y_true[:, :, 4:5]
    return true_id, true_charge, true_momentum


def mse_unreduced(true, pred):
    return tf.math.pow(true-pred, 2)


def msle_unreduced(true, pred):
    return tf.math.pow(tf.math.log(tf.math.abs(true) + 1.0) - tf.math.log(tf.math.abs(pred) + 1.0), 2)


def my_loss_cls(y_true, y_pred):
    pred_id_logits, pred_charge, _ = separate_prediction(y_pred)
    true_id, true_charge, _ = separate_truth(y_true)

    true_id_onehot = tf.one_hot(tf.cast(true_id, tf.int32), depth=len(class_labels))

    l1 = mult_classification_loss*tf.nn.softmax_cross_entropy_with_logits(true_id_onehot, pred_id_logits)
    l3 = mult_charge_loss*mse_unreduced(true_charge, pred_charge)[:, :, 0]

    loss = l1 + l3
    return mult_total_loss*loss


def my_loss_reg(y_true, y_pred):
    _, _, pred_momentum = separate_prediction(y_pred)
    _, true_charge, true_momentum = separate_truth(y_true)

    l2_0 = mult_eta_loss*mse_unreduced(true_momentum[:, :, 0], pred_momentum[:, :, 0])
    # wrap the phi residual into [-pi, pi) before squaring
    l2_1 = mult_phi_loss*mse_unreduced(tf.math.floormod(true_momentum[:, :, 1] - pred_momentum[:, :, 1] + np.pi, 2*np.pi) - np.pi, 0.0)
    l2_2 = mult_energy_loss*mse_unreduced(true_momentum[:, :, 2], pred_momentum[:, :, 2])

    loss = (l2_0 + l2_1 + l2_2)

    return 1e3*loss
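
# Illustrative sketch: the phi residual is wrapped into [-pi, pi), so true and
# predicted angles on opposite sides of the +-pi boundary give a small residual.
def _demo_phi_wraparound():
    true_phi = tf.constant(np.pi - 0.05)
    pred_phi = tf.constant(-np.pi + 0.05)
    dphi = tf.math.floormod(true_phi - pred_phi + np.pi, 2*np.pi) - np.pi
    assert abs(float(dphi) + 0.1) < 1e-5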

def my_loss_full(y_true, y_pred):
    pred_id_logits, pred_charge, pred_momentum = separate_prediction(y_pred)
    pred_id = tf.cast(tf.argmax(pred_id_logits, axis=-1), tf.int32)
    true_id, true_charge, true_momentum = separate_truth(y_true)
    true_id_onehot = tf.one_hot(tf.cast(true_id, tf.int32), depth=len(class_labels))

    l1 = mult_classification_loss*tf.nn.softmax_cross_entropy_with_logits(true_id_onehot, pred_id_logits)

    l2_0 = mult_eta_loss*mse_unreduced(true_momentum[:, :, 0], pred_momentum[:, :, 0])
    l2_1 = mult_phi_loss*mse_unreduced(tf.math.floormod(true_momentum[:, :, 1] - pred_momentum[:, :, 1] + np.pi, 2*np.pi) - np.pi, 0.0)
    l2_2 = mult_energy_loss*mse_unreduced(true_momentum[:, :, 2], pred_momentum[:, :, 2])

    l2 = (l2_0 + l2_1 + l2_2)

    l3 = mult_charge_loss*mse_unreduced(true_charge, pred_charge)[:, :, 0]
    loss = l1 + l2 + l3

    return mult_total_loss*loss


def make_cls_recall(pid):
    # per-class recall metric: the fraction of true particles of the given
    # PDG id that are also predicted as that class
    idx = class_labels.index(pid)

    def cls_metric(y_true, y_pred):
        pred_id_onehot, pred_charge, pred_momentum = separate_prediction(y_pred)
        pred_id = tf.cast(tf.argmax(pred_id_onehot, axis=-1), tf.int32)
        true_id, true_charge, true_momentum = separate_truth(y_true)

        msk_true = true_id[:, :, 0] == idx
        msk_pos = pred_id == idx
        num_true_pos = tf.reduce_sum(tf.cast(msk_true & msk_pos, tf.float32))
        num_true = tf.reduce_sum(tf.cast(msk_true, tf.float32))
        return num_true_pos/num_true

    # keep the per-class names that Keras displays in the training logs
    cls_metric.__name__ = "cls_{}".format(pid)
    return cls_metric

cls_130 = make_cls_recall(130)
cls_211 = make_cls_recall(211)
cls_22 = make_cls_recall(22)
cls_11 = make_cls_recall(11)
cls_13 = make_cls_recall(13)


def num_pred(y_true, y_pred):
    pred_id_onehot, pred_charge, pred_momentum = separate_prediction(y_pred)
    pred_id = tf.cast(tf.argmax(pred_id_onehot, axis=-1), tf.int32)
    true_id, true_charge, true_momentum = separate_truth(y_true)

    ntrue = tf.reduce_sum(tf.cast(true_id[:, :, 0] != 0, tf.int32))
    npred = tf.reduce_sum(tf.cast(pred_id != 0, tf.int32))
    return tf.cast(ntrue - npred, tf.float32)


def accuracy(y_true, y_pred):
    pred_id_onehot, pred_charge, pred_momentum = separate_prediction(y_pred)
    pred_id = tf.cast(tf.argmax(pred_id_onehot, axis=-1), tf.int32)
    true_id, true_charge, true_momentum = separate_truth(y_true)

    is_true = true_id[:, :, 0] != 0
    is_same = true_id[:, :, 0] == pred_id

    acc = tf.reduce_sum(tf.cast(is_true & is_same, tf.int32)) / tf.reduce_sum(tf.cast(is_true, tf.int32))
    return tf.cast(acc, tf.float32)


def eta_resolution(y_true, y_pred):
    pred_id_onehot, pred_charge, pred_momentum = separate_prediction(y_pred)
    true_id, true_charge, true_momentum = separate_truth(y_true)

    msk = true_id[:, :, 0] != 0
    return tf.reduce_mean(mse_unreduced(true_momentum[msk][:, 0], pred_momentum[msk][:, 0]))


def phi_resolution(y_true, y_pred):
    pred_id_onehot, pred_charge, pred_momentum = separate_prediction(y_pred)
    true_id, true_charge, true_momentum = separate_truth(y_true)

    msk = true_id[:, :, 0] != 0
    return tf.reduce_mean(mse_unreduced(tf.math.floormod(true_momentum[msk][:, 1] - pred_momentum[msk][:, 1] + np.pi, 2*np.pi) - np.pi, 0.0))


def energy_resolution(y_true, y_pred):
    pred_id_onehot, pred_charge, pred_momentum = separate_prediction(y_pred)
    true_id, true_charge, true_momentum = separate_truth(y_true)

    msk = true_id[:, :, 0] != 0
    return tf.reduce_mean(mse_unreduced(true_momentum[msk][:, 2], pred_momentum[msk][:, 2]))


def get_unique_run():
    previous_runs = os.listdir('experiments')
    if len(previous_runs) == 0:
        run_number = 1
    else:
        run_number = max([int(s.split('run_')[1]) for s in previous_runs]) + 1
    return run_number


def parse_args():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="PFNet", help="type of model to train", choices=["PFNet"])
    parser.add_argument("--ntrain", type=int, default=100, help="number of training events")
    parser.add_argument("--ntest", type=int, default=100, help="number of testing events")
    parser.add_argument("--nepochs", type=int, default=100, help="number of training epochs")
    parser.add_argument("--hidden-dim-id", type=int, default=256, help="hidden dimension")
    parser.add_argument("--hidden-dim-reg", type=int, default=256, help="hidden dimension")
    parser.add_argument("--batch-size", type=int, default=1, help="number of events in training batch")
    parser.add_argument("--num-convs-id", type=int, default=3, help="number of convolution layers")
    parser.add_argument("--num-convs-reg", type=int, default=3, help="number of convolution layers")
    parser.add_argument("--num-hidden-id-enc", type=int, default=2, help="number of encoder layers for multiclass")
    parser.add_argument("--num-hidden-id-dec", type=int, default=2, help="number of decoder layers for multiclass")
    parser.add_argument("--num-hidden-reg-enc", type=int, default=2, help="number of encoder layers for regression")
    parser.add_argument("--num-hidden-reg-dec", type=int, default=2, help="number of decoder layers for regression")
    parser.add_argument("--num-neighbors", type=int, default=5, help="number of knn neighbors")
    parser.add_argument("--distance-dim", type=int, default=256, help="distance dimension")
    parser.add_argument("--bin-size", type=int, default=100, help="number of points per LSH bin")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout rate")
    parser.add_argument("--dist-mult", type=float, default=1.0, help="Exponential multiplier")
    parser.add_argument("--target", type=str, choices=["cand", "gen"], help="Regress to PFCandidates or GenParticles", default="cand")
    parser.add_argument("--weights", type=str, choices=["uniform", "inverse", "classbalanced"], help="Sample weighting scheme to use", default="inverse")
    parser.add_argument("--name", type=str, default=None, help="where to store the output")
    parser.add_argument("--convlayer", type=str, default="ghconv", choices=["ghconv"], help="Type of graph convolutional layer")
    parser.add_argument("--load", type=str, default=None, help="model to load")
    parser.add_argument("--datapath", type=str, help="Input data path", required=True)
    parser.add_argument("--lr", type=float, default=1e-5, help="learning rate")
    parser.add_argument("--lr-decay", type=float, default=0.0, help="learning rate decay")
    parser.add_argument("--train-cls", action="store_true", help="Train only the classification part")
    parser.add_argument("--train-reg", action="store_true", help="Train only the regression part")
    parser.add_argument("--cosine-dist", action="store_true", help="Use cosine distance")
    parser.add_argument("--eager", action="store_true", help="Run in eager mode for debugging")
    args = parser.parse_args()
    return args


def assign_label(pred_id_onehot_linear):
    ret2 = np.argmax(pred_id_onehot_linear, axis=-1)
    return ret2


def prepare_df(model, data, outdir, target, save_raw=False):
    print("prepare_df")

    dfs = []
    for iev, d in enumerate(data):
        if iev % 50 == 0:
            tf.print(".", end="")
        X, y, w = d
        pred = model(X, training=False).numpy()
        pred_id_onehot, pred_charge, pred_momentum = separate_prediction(pred)
        pred_id = assign_label(pred_id_onehot).flatten()

        if save_raw:
            np.savez_compressed("ev_{}.npz".format(iev), X=X.numpy(), y=y.numpy(), w=w.numpy(), y_pred=pred)

        pred_charge = pred_charge[:, :, 0].flatten()
        pred_momentum = pred_momentum.reshape((pred_momentum.shape[0]*pred_momentum.shape[1], pred_momentum.shape[2]))

        true_id, true_charge, true_momentum = separate_truth(y)
        true_id = true_id.numpy()[:, :, 0].flatten()
        true_charge = true_charge.numpy()[:, :, 0].flatten()
        true_momentum = true_momentum.numpy().reshape((true_momentum.shape[0]*true_momentum.shape[1], true_momentum.shape[2]))

        df = pandas.DataFrame()
        df["pred_pid"] = np.array([int(class_labels[p]) for p in pred_id])
        df["pred_eta"] = np.array(pred_momentum[:, 0], dtype=np.float64)
        df["pred_phi"] = np.array(pred_momentum[:, 1], dtype=np.float64)
        df["pred_e"] = np.array(pred_momentum[:, 2], dtype=np.float64)

        df["{}_pid".format(target)] = np.array([int(class_labels[p]) for p in true_id])
        df["{}_eta".format(target)] = np.array(true_momentum[:, 0], dtype=np.float64)
        df["{}_phi".format(target)] = np.array(true_momentum[:, 1], dtype=np.float64)
        df["{}_e".format(target)] = np.array(true_momentum[:, 2], dtype=np.float64)

        df["iev"] = iev
        dfs += [df]
    df = pandas.concat(dfs, ignore_index=True)
    fn = outdir + "/df.pkl.bz2"
    df.to_pickle(fn)
    print("prepare_df done", fn)


def plot_to_image(figure):
    """Converts the matplotlib plot specified by 'figure' to a PNG image and
    returns it. The supplied figure is closed and inaccessible after this call."""

    buf = io.BytesIO()
    plt.savefig(buf, format='png')

    # close the figure so it is not displayed or reused, as the docstring promises
    plt.close(figure)
    buf.seek(0)

    image = tf.image.decode_png(buf.getvalue(), channels=4)

    # add a batch dimension for TensorBoard
    image = tf.expand_dims(image, 0)
    return image


def load_dataset_ttbar(datapath, target):
    from tf_data import _parse_tfr_element
    path = datapath + "/tfr/{}/*.tfrecords".format(target)
    tfr_files = glob.glob(path)
    if len(tfr_files) == 0:
        raise Exception("Could not find any files in {}".format(path))
    dataset = tf.data.TFRecordDataset(tfr_files).map(_parse_tfr_element, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset


if __name__ == "__main__":
    args = parse_args()
    print(args)

    if comet_enabled:
        experiment = Experiment(project_name="particleflow_tf")

    tf.config.experimental_run_functions_eagerly(args.eager)

    # use mirrored multi-GPU training if more than one GPU is visible,
    # otherwise fall back to a single device
    global_batch_size = args.batch_size
    try:
        num_gpus = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
        print("num_gpus=", num_gpus)
        if num_gpus > 1:
            strategy = tf.distribute.MirroredStrategy()
            global_batch_size = num_gpus * args.batch_size
        else:
            strategy = tf.distribute.OneDeviceStrategy("gpu:0")
    except Exception as e:
        print("fallback to CPU")
        strategy = tf.distribute.OneDeviceStrategy("cpu")

    filelist = sorted(glob.glob(args.datapath + "/raw/*.pkl"))[:args.ntrain+args.ntest]

    dataset = load_dataset_ttbar(args.datapath, args.target)

    # pad each event to (features, targets, per-element weights) of fixed size
    ps = (tf.TensorShape([num_max_elems, 15]), tf.TensorShape([num_max_elems, 5]), tf.TensorShape([num_max_elems, ]))
    ds_train = dataset.take(args.ntrain).map(weight_schemes[args.weights]).padded_batch(global_batch_size, padded_shapes=ps)
    ds_test = dataset.skip(args.ntrain).take(args.ntest).map(weight_schemes[args.weights]).padded_batch(global_batch_size, padded_shapes=ps)

    ds_train_r = ds_train.repeat(args.nepochs)
    ds_test_r = ds_test.repeat(args.nepochs)

    if args.lr_decay > 0:
        lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            args.lr,
            decay_steps=10*int(args.ntrain/global_batch_size),
            decay_rate=args.lr_decay
        )
    else:
        lr_schedule = args.lr

    loss_fn = my_loss_full

    with strategy.scope():
        opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

        model = PFNet(
            hidden_dim_id=args.hidden_dim_id,
            hidden_dim_reg=args.hidden_dim_reg,
            num_convs_id=args.num_convs_id,
            num_convs_reg=args.num_convs_reg,
            num_hidden_id_enc=args.num_hidden_id_enc,
            num_hidden_id_dec=args.num_hidden_id_dec,
            num_hidden_reg_enc=args.num_hidden_reg_enc,
            num_hidden_reg_dec=args.num_hidden_reg_dec,
            distance_dim=args.distance_dim,
            convlayer=args.convlayer,
            dropout=args.dropout,
            bin_size=args.bin_size,
            num_neighbors=args.num_neighbors,
            dist_mult=args.dist_mult
        )

        if args.train_cls:
            loss_fn = my_loss_cls
            model.set_trainable_classification()
        elif args.train_reg:
            loss_fn = my_loss_reg
            model.set_trainable_regression()

        # build the model weights by running one batch of random data through it
        model(np.random.randn(args.batch_size, num_max_elems, 15).astype(np.float32))
        if not args.eager:
            model = model.create_model(num_max_elems)
            model.summary()

    if not os.path.isdir("experiments"):
        os.makedirs("experiments")

    if args.name is None:
        args.name = 'run_{:02}'.format(get_unique_run())

    outdir = 'experiments/' + args.name

    if os.path.isdir(outdir):
        print("Output directory exists: {}".format(outdir), file=sys.stderr)
        sys.exit(1)

    print(outdir)
    callbacks = []
    tb = tf.keras.callbacks.TensorBoard(
        log_dir=outdir, histogram_freq=0, write_graph=False, write_images=False,
        update_freq='epoch',
        profile_batch=0,
    )
    tb.set_model(model)
    callbacks += [tb]

    terminate_cb = tf.keras.callbacks.TerminateOnNaN()
    callbacks += [terminate_cb]

    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=outdir + "/weights.{epoch:02d}-{val_loss:.6f}.hdf5",
        save_weights_only=True,
        verbose=0
    )
    cp_callback.set_model(model)
    callbacks += [cp_callback]

    with strategy.scope():
        model.compile(optimizer=opt, loss=loss_fn,
            metrics=[accuracy, cls_130, cls_211, cls_22, energy_resolution, eta_resolution, phi_resolution],
            sample_weight_mode="temporal")

    if args.load:
        # run the model on one batch to build the weights before loading them
        for X, y, w in ds_train:
            model(X)
            break

        model.load_weights(args.load)

    if args.nepochs > 0:
        ret = model.fit(ds_train_r,
            validation_data=ds_test_r, epochs=args.nepochs,
            steps_per_epoch=args.ntrain//global_batch_size, validation_steps=args.ntest//global_batch_size,
            verbose=True,
            callbacks=callbacks
        )