import os
import sys
import random
os.environ["KERAS_BACKEND"] = "tensorflow"

import glob
try:
    if "CUDA_VISIBLE_DEVICES" not in os.environ:
        print("importing setGPU")
        import setGPU
except Exception:
    print("Could not import setGPU, please make sure you configure CUDA_VISIBLE_DEVICES manually")

try:
    from comet_ml import Experiment
    comet_enabled = True
except ImportError:
    print("Could not import comet_ml, online dashboard disabled")
    comet_enabled = False

import pickle
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, accuracy_score
import pandas
import time
import itertools
import io
import tensorflow as tf

#physical_devices = tf.config.list_physical_devices('GPU')
#tf.config.experimental.set_memory_growth(physical_devices[0], True)

from numpy.lib.recfunctions import append_fields

elem_labels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
class_labels = [0, 1, 2, 11, 13, 22, 130, 211]

num_max_elems = 5000

mult_classification_loss = 1e3
mult_charge_loss = 1.0
mult_energy_loss = 10.0
mult_phi_loss = 10.0
mult_eta_loss = 10.0
mult_total_loss = 1e3
def split_indices_to_bins(cmul, nbins, bin_size):
    bin_idx = tf.argmax(cmul, axis=-1)
    bins_split = tf.reshape(tf.argsort(bin_idx), (nbins, bin_size))
    return bins_split

def pairwise_dist(A, B):
    na = tf.reduce_sum(tf.square(A), -1)
    nb = tf.reduce_sum(tf.square(B), -1)

    # na as a row and nb as a column vector
    na = tf.expand_dims(na, -1)
    nb = tf.expand_dims(nb, -2)

    # return pairwise euclidean distance matrix
    D = tf.sqrt(tf.maximum(na - 2*tf.matmul(A, B, False, True) + nb, 1e-6))
    return D

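#Example (an illustrative sketch, not part of the original code): for batched
#point sets A: (2, 100, 16) and B: (2, 50, 16), pairwise_dist returns
#D: (2, 100, 50), where D[b, i, j] is the euclidean distance between A[b, i]
#and B[b, j], clipped from below for numerical stability:
#  A = tf.random.normal((2, 100, 16))
#  B = tf.random.normal((2, 50, 16))
#  D = pairwise_dist(A, B)  # shape (2, 100, 50)
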
0066 """
0067 sp_a: (nbatch, nelem, nelem) sparse distance matrices
0068 b: (nbatch, nelem, ncol) dense per-element feature matrices
0069 """
0070 def sparse_dense_matmult_batch(sp_a, b):
0071 
0072     num_batches = tf.shape(b)[0]
0073     def map_function(x):
0074         i, dense_slice = x[0], x[1]
0075         num_points = tf.shape(b)[1]
0076 
0077         sparse_slice = tf.sparse.reshape(tf.sparse.slice(
0078             sp_a, [i, 0, 0], [1, num_points, num_points]),
0079             [num_points, num_points])
0080         mult_slice = tf.sparse.sparse_dense_matmul(sparse_slice, dense_slice)
0081         return mult_slice
0082 
0083     elems = (tf.range(0, tf.cast(num_batches, tf.int64), delta=1, dtype=tf.int64), b)
0084     ret = tf.map_fn(map_function, elems, fn_output_signature=tf.float32, back_prop=True)
0085     return ret 
0086 
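#Example (an illustrative sketch, not part of the original code): batched
#sparse x dense multiplication, one batch slice at a time:
#  sp_a = tf.sparse.from_dense(tf.eye(4, batch_shape=[2]))  # (2, 4, 4) sparse
#  b = tf.random.normal((2, 4, 8))                          # (2, 4, 8) dense
#  out = sparse_dense_matmult_batch(sp_a, b)                # shape (2, 4, 8)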

def summarize_dataset(dataset):
    yclasses = []
    nev = 0.0
    ntot = 0.0
    sizes = []

    for X, y, w in dataset:
        yclasses += [y[:, 0]]
        nev += 1
        ntot += len(y)
        sizes += [len(y)]

    yclasses = np.concatenate(yclasses)
    values, counts = np.unique(yclasses, return_counts=True)
    print("nev={}".format(nev))
    print("sizes={}".format(np.percentile(sizes, [25, 50, 95, 99])))
    for v, c in zip(values, counts):
        print("label={} count={} frac={:.6f}".format(class_labels[int(v)], c, c/ntot))

#https://arxiv.org/pdf/1901.05555.pdf
beta = 0.9999 #beta -> 1 means weight by inverse frequency, beta -> 0 means no reweighting
def compute_weights_classbalanced(X, y, w):
    wn = (1.0 - beta)/(1.0 - tf.pow(beta, w))
    wn /= tf.reduce_sum(wn)
    return X, y, wn

#uniform weights
def compute_weights_uniform(X, y, w):
    wn = tf.ones_like(w)
    wn /= tf.reduce_sum(wn)
    return X, y, wn

#weight proportional to 1/sqrt(N)
def compute_weights_inverse(X, y, w):
    wn = 1.0/tf.sqrt(w)
    wn /= tf.reduce_sum(wn)
    return X, y, wn

weight_schemes = {
    "uniform": compute_weights_uniform,
    "inverse": compute_weights_inverse,
    "classbalanced": compute_weights_classbalanced,
}

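#Example (an illustrative sketch): the schemes are applied per-event via
#dataset.map, as in the __main__ block below; for "inverse", an element whose
#class appears N times in the event gets a weight proportional to 1/sqrt(N),
#normalized so the weights of the event sum to one:
#  ds = dataset.map(weight_schemes["inverse"])
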
def load_one_file(fn):
    Xs = []
    ys = []
    ys_cand = []
    dms = []

    with open(fn, "rb") as fi:
        data = pickle.load(fi, encoding='iso-8859-1')
    for event in data:
        Xelem = event["Xelem"]
        ygen = event["ygen"]
        ycand = event["ycand"]

        #remove PS (preshower) elements from the inputs, they don't seem to be very useful
        msk_ps = (Xelem["typ"] == 2) | (Xelem["typ"] == 3)

        Xelem = Xelem[~msk_ps]
        ygen = ygen[~msk_ps]
        ycand = ycand[~msk_ps]

        Xelem = append_fields(Xelem, "typ_idx", np.array([elem_labels.index(int(i)) for i in Xelem["typ"]], dtype=np.float32))
        ygen = append_fields(ygen, "typ_idx", np.array([class_labels.index(abs(int(i))) for i in ygen["typ"]], dtype=np.float32))
        ycand = append_fields(ycand, "typ_idx", np.array([class_labels.index(abs(int(i))) for i in ycand["typ"]], dtype=np.float32))

        Xelem_flat = np.stack([Xelem[k].view(np.float32).data for k in [
            'typ_idx',
            'pt', 'eta', 'phi', 'e',
            'layer', 'depth', 'charge', 'trajpoint',
            'eta_ecal', 'phi_ecal', 'eta_hcal', 'phi_hcal',
            'muon_dt_hits', 'muon_csc_hits']], axis=-1
        )
        ygen_flat = np.stack([ygen[k].view(np.float32).data for k in [
            'typ_idx',
            'eta', 'phi', 'e', 'charge',
            ]], axis=-1
        )
        ycand_flat = np.stack([ycand[k].view(np.float32).data for k in [
            'typ_idx',
            'eta', 'phi', 'e', 'charge',
            ]], axis=-1
        )

        #take care of outliers
        Xelem_flat[np.isnan(Xelem_flat)] = 0
        Xelem_flat[np.abs(Xelem_flat) > 1e4] = 0
        ygen_flat[np.isnan(ygen_flat)] = 0
        ygen_flat[np.abs(ygen_flat) > 1e4] = 0
        ycand_flat[np.isnan(ycand_flat)] = 0
        ycand_flat[np.abs(ycand_flat) > 1e4] = 0

        Xs += [Xelem_flat[:num_max_elems]]
        ys += [ygen_flat[:num_max_elems]]
        ys_cand += [ycand_flat[:num_max_elems]]

    print("created {} blocks, max size {}".format(len(Xs), max([len(X) for X in Xs])))
    return Xs, ys, ys_cand

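#Example (an illustrative sketch; the file path is hypothetical):
#  Xs, ys, ys_cand = load_one_file("data/raw/pfntuple_1.pkl")
#each list entry is one event truncated to at most num_max_elems elements,
#with Xs[i]: (nelem, 15), ys[i] and ys_cand[i]: (nelem, 5) float32 arrays.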

class InputEncoding(tf.keras.layers.Layer):
    def __init__(self, num_input_classes):
        super(InputEncoding, self).__init__()
        self.num_input_classes = num_input_classes

    def call(self, X):
        """
        X: [Nbatch, Nelem, Nfeat] array of all the input detector element feature data
        """
        #X[:, :, 0] - categorical index of the element type
        Xid = tf.cast(tf.one_hot(tf.cast(X[:, :, 0], tf.int32), self.num_input_classes), dtype=tf.float32)

        #X[:, :, 1:] - all the other non-categorical features
        Xprop = X[:, :, 1:]
        return tf.concat([Xid, Xprop], axis=-1)

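#Example (an illustrative sketch): with the 12 element types in elem_labels
#and 15 input features, the type index is one-hot encoded and concatenated
#with the remaining 14 features, giving 26 features per element (the width
#assumed by the unencoded GHConv layers in PFNet below):
#  enc = InputEncoding(len(elem_labels))
#  out = enc(tf.zeros((1, 10, 15)))  # shape (1, 10, 26)
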
#https://arxiv.org/pdf/2004.04635.pdf
#https://github.com/gcucurull/jax-ghnet/blob/master/models.py
class GHConv(tf.keras.layers.Layer):
    def __init__(self, *args, **kwargs):
        self.activation = kwargs.pop("activation")
        self.hidden_dim = args[0]

        super(GHConv, self).__init__(*args, **kwargs)

    def build(self, input_shape):
        self.W_t = self.add_weight(shape=(self.hidden_dim, self.hidden_dim), name="w_t", initializer="random_normal")
        self.b_t = self.add_weight(shape=(self.hidden_dim, ), name="b_t", initializer="random_normal")
        self.W_h = self.add_weight(shape=(self.hidden_dim, self.hidden_dim), name="w_h", initializer="random_normal")
        self.theta = self.add_weight(shape=(self.hidden_dim, self.hidden_dim), name="theta", initializer="random_normal")

    def call(self, inputs):
        x, adj = inputs

        #compute the normalization of the adjacency matrix
        in_degrees = tf.sparse.reduce_sum(adj, axis=-1)
        in_degrees = tf.reshape(in_degrees, (tf.shape(x)[0], tf.shape(x)[1]))

        #add epsilon to prevent numerical issues from 1/sqrt(x)
        norm = tf.expand_dims(tf.pow(in_degrees + 1e-6, -0.5), -1)

        f_hom = tf.linalg.matmul(x, self.theta)
        f_hom = sparse_dense_matmult_batch(adj, f_hom*norm)*norm

        f_het = tf.linalg.matmul(x, self.W_h)
        gate = tf.nn.sigmoid(tf.linalg.matmul(x, self.W_t) + self.b_t)

        out = gate*f_hom + (1-gate)*f_het
        return self.activation(out)

class GHConvDense(tf.keras.layers.Layer):
    def __init__(self, *args, **kwargs):
        self.activation = kwargs.pop("activation")
        self.hidden_dim = args[0]
        super(GHConvDense, self).__init__(*args, **kwargs)

    def build(self, input_shape):
        self.W_t = self.add_weight(shape=(self.hidden_dim, self.hidden_dim), name="w_t", initializer="random_normal")
        self.b_t = self.add_weight(shape=(self.hidden_dim, ), name="b_t", initializer="random_normal")
        self.W_h = self.add_weight(shape=(self.hidden_dim, self.hidden_dim), name="w_h", initializer="random_normal")
        self.theta = self.add_weight(shape=(self.hidden_dim, self.hidden_dim), name="theta", initializer="random_normal")

    def call(self, inputs):
        x, adj = inputs

        #compute the normalization of the adjacency matrix
        in_degrees = tf.reduce_sum(adj, axis=-1)
        in_degrees = tf.reshape(in_degrees, (tf.shape(x)[0], tf.shape(x)[1]))

        #add epsilon to prevent numerical issues from 1/sqrt(x)
        norm = tf.expand_dims(tf.pow(in_degrees + 1e-6, -0.5), -1)

        f_hom = tf.linalg.matmul(x, self.theta)
        f_hom = tf.linalg.matmul(adj, f_hom*norm)*norm

        f_het = tf.linalg.matmul(x, self.W_h)
        gate = tf.nn.sigmoid(tf.linalg.matmul(x, self.W_t) + self.b_t)

        out = gate*f_hom + (1-gate)*f_het
        return self.activation(out)

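#Example (an illustrative sketch): GHConvDense takes node features together
#with a dense adjacency matrix; the input feature dimension must match
#hidden_dim, since W_t, W_h and theta are all (hidden_dim, hidden_dim):
#  layer = GHConvDense(8, activation=tf.nn.selu)
#  x = tf.random.normal((2, 16, 8))   # (nbatch, nelem, hidden_dim)
#  adj = tf.ones((2, 16, 16))         # dense (nbatch, nelem, nelem) adjacency
#  out = layer([x, adj])              # shape (2, 16, 8)
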
class DenseDistance(tf.keras.layers.Layer):
    def __init__(self, dist_mult=0.1, **kwargs):
        super(DenseDistance, self).__init__(**kwargs)
        self.dist_mult = dist_mult

    def call(self, inputs, training=True):
        dm = pairwise_dist(inputs, inputs)
        dm = tf.exp(-self.dist_mult*dm)
        return dm

class SparseHashedNNDistance(tf.keras.layers.Layer):
    def __init__(self, max_num_bins=200, bin_size=500, num_neighbors=5, dist_mult=0.1, cosine_dist=False, **kwargs):
        super(SparseHashedNNDistance, self).__init__(**kwargs)
        self.num_neighbors = num_neighbors
        self.dist_mult = dist_mult

        self.cosine_dist = cosine_dist

        #generate the codebook for LSH hashing at model instantiation for up to this many bins
        #set this to a high-enough value at model generation to take into account the largest possible input
        self.max_num_bins = max_num_bins

        #each bin will receive this many input elements, in total we can accept max_num_bins*bin_size input elements
        #in each bin, we will do a dense top_k evaluation
        self.bin_size = bin_size

    def build(self, input_shape):
        #(n_batch, n_points, n_features)

        #generate the LSH codebook for random rotations (num_features, num_bins/2)
        self.codebook_random_rotations = self.add_weight(
            shape=(input_shape[-1], self.max_num_bins//2), initializer="random_normal", trainable=False, name="lsh_projections"
        )

    def call(self, inputs, training=True):

        #(n_batch, n_points, n_features)
        point_embedding = inputs

        n_batches = tf.shape(point_embedding)[0]
        n_points = tf.shape(point_embedding)[1]

        #cannot concat sparse tensors directly as that incorrectly destroys the gradient, see
        #https://github.com/tensorflow/tensorflow/blob/df3a3375941b9e920667acfe72fb4c33a8f45503/tensorflow/python/ops/sparse_grad.py#L33
        #therefore, for training, we implement sparse concatenation by hand
        indices_all = []
        values_all = []

        def func(args):
            ibatch, points_batch = args[0], args[1]
            dm = self.construct_sparse_dm_batch(points_batch)
            inds = tf.concat([tf.expand_dims(tf.cast(ibatch, tf.int64)*tf.ones(tf.shape(dm.indices)[0], dtype=tf.int64), -1), dm.indices], axis=-1)
            vals = dm.values
            return inds, vals

        elems = (tf.range(0, tf.cast(n_batches, tf.int64), delta=1, dtype=tf.int64), point_embedding)
        ret = tf.map_fn(func, elems, fn_output_signature=(tf.int64, tf.float32), parallel_iterations=1)
        shp = tf.shape(ret[0])

        #now create a new SparseTensor that is a concatenation of the previous ones
        dms = tf.SparseTensor(
            tf.reshape(ret[0], (shp[0]*shp[1], shp[2])),
            tf.reshape(ret[1], (shp[0]*shp[1],)),
            (n_batches, n_points, n_points)
        )

        return tf.sparse.reorder(dms)

    def subpoints_to_sparse_matrix(self, n_points, subindices, subpoints):

        #find the distance matrix between the given points using dense matrix multiplication
        if self.cosine_dist:
            #note: the original code computed the normalized vectors but then
            #multiplied the unnormalized ones; use the normalized vectors here
            normed = tf.nn.l2_normalize(subpoints, axis=-1)
            dm = tf.linalg.matmul(normed, normed, transpose_b=True)
        else:
            dm = pairwise_dist(subpoints, subpoints)
            dm = tf.exp(-self.dist_mult*dm)

        dmshape = tf.shape(dm)
        nbins = dmshape[0]
        nelems = dmshape[1]

        #run KNN in the dense distance matrix, accumulate each index pair into a sparse distance matrix
        top_k = tf.nn.top_k(dm, k=self.num_neighbors)
        top_k_vals = tf.reshape(top_k.values, (nbins*nelems, self.num_neighbors))

        indices_gathered = tf.vectorized_map(
            lambda i: tf.gather_nd(subindices, top_k.indices[:, :, i:i+1], batch_dims=1),
            tf.range(self.num_neighbors, dtype=tf.int64))

        indices_gathered = tf.transpose(indices_gathered, [1,2,0])

        #add the neighbors up to a big matrix using dense matrices, then convert to sparse (mainly for testing)
        # sp_sum = tf.zeros((n_points, n_points))
        # for i in range(self.num_neighbors):
        #     dst_ind = indices_gathered[:, :, i] #(nbins, nelems)
        #     dst_ind = tf.reshape(dst_ind, (nbins*nelems, ))
        #     src_ind = tf.reshape(tf.stack(subindices), (nbins*nelems, ))
        #     src_dst_inds = tf.transpose(tf.stack([src_ind, dst_ind]))
        #     sp_sum += tf.scatter_nd(src_dst_inds, top_k_vals[:, i], (n_points, n_points))
        # spt_this = tf.sparse.from_dense(sp_sum)

        #validate by hand while debugging that the vectorized ops are doing what we want
        # dm = np.eye(n_points)
        # for ibin in range(nbins):
        #     for ielem in range(nelems):
        #         idx0 = subindices[ibin][ielem]
        #         for ineigh in range(self.num_neighbors):
        #             idx1 = subindices[ibin][top_k.indices[ibin, ielem, ineigh]]
        #             val = top_k.values[ibin, ielem, ineigh]
        #             dm[idx0, idx1] += val
        # assert(np.all(sp_sum.numpy() == dm))

        #update the output using intermediate sparse matrices, which may result in some inconsistencies from duplicated indices
        sp_sum = tf.sparse.SparseTensor(indices=tf.zeros((0,2), dtype=tf.int64), values=tf.zeros(0, tf.float32), dense_shape=(n_points, n_points))
        for i in range(self.num_neighbors):
            dst_ind = indices_gathered[:, :, i] #(nbins, nelems)
            dst_ind = tf.reshape(dst_ind, (nbins*nelems, ))
            src_ind = tf.reshape(tf.stack(subindices), (nbins*nelems, ))
            src_dst_inds = tf.cast(tf.transpose(tf.stack([src_ind, dst_ind])), dtype=tf.int64)
            sp_sum = tf.sparse.add(
                sp_sum,
                tf.sparse.reorder(tf.sparse.SparseTensor(src_dst_inds, top_k_vals[:, i], (n_points, n_points)))
            )
        spt_this = tf.sparse.reorder(sp_sum)

        return spt_this

    def construct_sparse_dm_batch(self, points):

        #points: (n_points, n_features) input elements for graph construction
        n_points = tf.shape(points)[0]
        n_features = tf.shape(points)[1]

        #compute the number of LSH bins to divide the input points into on the fly
        #n_points must be divisible by bin_size exactly due to the use of reshape
        n_bins = tf.math.floordiv(n_points, self.bin_size)
        #tf.debugging.assert_greater(n_bins, 0)

        #put each input item into a bin defined by the softmax output across the LSH embedding
        mul = tf.linalg.matmul(points, self.codebook_random_rotations[:, :n_bins//2])
        #tf.debugging.assert_greater(tf.shape(mul)[2], 0)

        cmul = tf.concat([mul, -mul], axis=-1)

        #the argmax of cmul is an integer in [0..nbins) for each input point
        #bins_split: (n_bins, bin_size) of integer bin indices, which put each input point into a bin of size (n_points/n_bins)
        bins_split = split_indices_to_bins(cmul, n_bins, self.bin_size)

        #parts: (n_bins, bin_size, n_features), the input points divided up into bins
        parts = tf.gather(points, bins_split)

        #sparse_distance_matrix: (n_points, n_points) sparse distance matrix
        #where higher values (closer to 1) are associated with points that are closely related
        sparse_distance_matrix = self.subpoints_to_sparse_matrix(n_points, bins_split, parts)

        return sparse_distance_matrix

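#Worked example (illustrative, not part of the original code): with
#n_points=1000 and bin_size=500, the points are split into n_bins=2 LSH bins;
#mul has shape (1000, 1) from a single random projection, cmul=[mul, -mul]
#has shape (1000, 2), the argmax over cmul assigns each point to one of the
#two bins, and the dense kNN is then evaluated only within each (500, 500)
#bin instead of the full (1000, 1000) distance matrix.
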
class EncoderDecoderGNN(tf.keras.layers.Layer):
    def __init__(self, encoders, decoders, dropout, activation, conv, **kwargs):
        super(EncoderDecoderGNN, self).__init__(**kwargs)
        name = kwargs.get("name")

        #assert(encoders[-1] == decoders[0])
        self.encoders = encoders
        self.decoders = decoders

        self.encoding_layers = []
        for ilayer, nunits in enumerate(encoders):
            self.encoding_layers.append(
                tf.keras.layers.Dense(nunits, activation=activation, name="encoding_{}_{}".format(name, ilayer)))
            if dropout > 0.0:
                self.encoding_layers.append(tf.keras.layers.Dropout(dropout))

        self.conv = conv

        self.decoding_layers = []
        for ilayer, nunits in enumerate(decoders):
            self.decoding_layers.append(
                tf.keras.layers.Dense(nunits, activation=activation, name="decoding_{}_{}".format(name, ilayer)))
            if dropout > 0.0:
                self.decoding_layers.append(tf.keras.layers.Dropout(dropout))

    def call(self, inputs, distance_matrix, training=True):
        x = inputs

        for layer in self.encoding_layers:
            x = layer(x)

        for convlayer in self.conv:
            x = convlayer([x, distance_matrix])

        for layer in self.decoding_layers:
            x = layer(x)

        return x

class AddSparse(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(AddSparse, self).__init__(**kwargs)

    def call(self, matrices):
        ret = matrices[0]
        for mat in matrices[1:]:
            ret = tf.sparse.add(ret, mat)
        return ret

#Simple message passing based on a matrix multiplication
class PFNet(tf.keras.Model):
    def __init__(self,
        activation=tf.nn.selu,
        hidden_dim_id=256,
        hidden_dim_reg=256,
        distance_dim=256,
        convlayer="ghconv",
        dropout=0.1,
        bin_size=10,
        num_convs_id=1,
        num_convs_reg=1,
        num_hidden_id_enc=1,
        num_hidden_id_dec=1,
        num_hidden_reg_enc=1,
        num_hidden_reg_dec=1,
        num_neighbors=5,
        dist_mult=0.1,
        cosine_dist=False):

        super(PFNet, self).__init__()
        self.activation = activation
        self.num_dists = 1

        encoding_id = []
        decoding_id = []
        encoding_reg = []
        decoding_reg = []

        #the encoder outputs and decoder inputs have to have the hidden dim (convlayer size)
        for ihidden in range(num_hidden_id_enc):
            encoding_id.append(hidden_dim_id)

        for ihidden in range(num_hidden_id_dec):
            decoding_id.append(hidden_dim_id)

        for ihidden in range(num_hidden_reg_enc):
            encoding_reg.append(hidden_dim_reg)

        for ihidden in range(num_hidden_reg_dec):
            decoding_reg.append(hidden_dim_reg)

        self.enc = InputEncoding(len(elem_labels))
        self.layer_embedding = tf.keras.layers.Dense(distance_dim, name="embedding_attention")

        self.embedding_dropout = None
        if dropout > 0.0:
            self.embedding_dropout = tf.keras.layers.Dropout(dropout)

        self.dists = []
        for idist in range(self.num_dists):
            self.dists.append(SparseHashedNNDistance(bin_size=bin_size, num_neighbors=num_neighbors, dist_mult=dist_mult, cosine_dist=cosine_dist))
        self.addsparse = AddSparse()
        #self.dist = DenseDistance(dist_mult=dist_mult)

        convs_id = []
        convs_reg = []

        for iconv in range(num_convs_id):
            convs_id.append(GHConv(26 if len(encoding_id)==0 else hidden_dim_id, activation=activation, name="conv_id{}".format(iconv)))
        for iconv in range(num_convs_reg):
            convs_reg.append(GHConv(35 if len(encoding_reg)==0 else hidden_dim_reg, activation=activation, name="conv_reg{}".format(iconv)))

        self.gnn_id = EncoderDecoderGNN(encoding_id, decoding_id, dropout, activation, convs_id, name="gnn_id")
        self.layer_id = tf.keras.layers.Dense(len(class_labels), activation="linear", name="out_id")
        self.layer_charge = tf.keras.layers.Dense(1, activation="linear", name="out_charge")

        self.gnn_reg = EncoderDecoderGNN(encoding_reg, decoding_reg, dropout, activation, convs_reg, name="gnn_reg")
        self.layer_momentum = tf.keras.layers.Dense(3, activation="linear", name="out_momentum")

    def create_model(self, num_max_elems, training=True):
        inputs = tf.keras.Input(shape=(num_max_elems, 15,))
        return tf.keras.Model(inputs=[inputs], outputs=self.call(inputs, training), name="MLPFNet")

    def call(self, inputs, training=True):
        X = tf.cast(inputs, tf.float32)
        msk_input = tf.expand_dims(tf.cast(X[:, :, 0] != 0, tf.float32), -1)

        enc = self.enc(inputs)

        #embed inputs for graph structure prediction
        embedding_attention = self.layer_embedding(enc)
        if self.embedding_dropout:
            embedding_attention = self.embedding_dropout(embedding_attention, training)

        #create graph structure by predicting a sparse distance matrix
        dms = [dist(embedding_attention, training) for dist in self.dists]
        dm = self.addsparse(dms)

        #run graph net for multiclass id prediction
        x_id = self.gnn_id(enc, dm, training)
        to_decode = tf.concat([enc, x_id], axis=-1)
        out_id_logits = self.layer_id(to_decode)
        out_charge = self.layer_charge(to_decode)

        #run graph net for regression output prediction, taking as an additional input the ID predictions
        x_reg = self.gnn_reg(tf.concat([enc, out_id_logits, out_charge], axis=-1), dm, training)
        to_decode = tf.concat([enc, x_reg], axis=-1)
        pred_corr = self.layer_momentum(to_decode)

        #soft-mask elements for which the id prediction was 0
        probabilistic_mask_good = 1.0 - tf.keras.activations.softmax(out_id_logits)[:, :, 0]

        out_momentum_eta = X[:, :, 2] + pred_corr[:, :, 0]
        out_momentum_phi = X[:, :, 3] + pred_corr[:, :, 1]
        out_momentum_E = X[:, :, 4] + pred_corr[:, :, 2]

        out_momentum = tf.stack([
            out_momentum_eta * probabilistic_mask_good,
            out_momentum_phi * probabilistic_mask_good,
            out_momentum_E * probabilistic_mask_good,
        ], axis=-1)

        ret = tf.concat([out_id_logits, out_momentum, out_charge], axis=-1)*msk_input
        return ret

    def set_trainable_classification(self):
        self.gnn_reg.trainable = False
        self.layer_momentum.trainable = False

    def set_trainable_regression(self):
        for layer in self.layers:
            layer.trainable = False
        self.gnn_reg.trainable = True
        self.layer_momentum.trainable = True

#Just a dummy elementwise model
class PFNetDummy(tf.keras.Model):
    def __init__(self, **kwargs):
        super(PFNetDummy, self).__init__()
        self.enc = InputEncoding(len(elem_labels))

        self.layer_hidden0 = tf.keras.layers.Dense(32, activation="elu")
        self.layer_hidden1 = tf.keras.layers.Dense(64, activation="elu")
        self.layer_hidden2 = tf.keras.layers.Dense(128, activation="elu")
        self.layer_hidden3 = tf.keras.layers.Dense(256, activation="elu")

        self.layer_id = tf.keras.layers.Dense(len(class_labels), activation="linear", name="out_id")
        self.layer_charge = tf.keras.layers.Dense(1, activation="linear", name="out_charge")
        self.layer_momentum = tf.keras.layers.Dense(3, activation="linear", name="out_momentum")

    def call(self, inputs, training=True):
        X = tf.cast(inputs, tf.float32)
        msk_input = tf.expand_dims(tf.cast(X[:, :, 0] != 0, tf.float32), -1)
        enc = self.enc(inputs)

        #the original code referenced an undefined `flat` here; the dense stack
        #is applied directly to the encoded inputs, which keeps the computation
        #elementwise as the downstream per-element indexing requires (a Flatten
        #layer would have collapsed the per-element structure)
        h = self.layer_hidden0(enc)
        h = self.layer_hidden1(h)
        h = self.layer_hidden2(h)
        h = self.layer_hidden3(h)

        out_id_logits = self.layer_id(h)
        out_charge = self.layer_charge(h)
        pred_corr = self.layer_momentum(h)

        #soft-mask elements for which the id prediction was 0
        probabilistic_mask_good = 1.0 - tf.keras.activations.softmax(out_id_logits)[:, :, 0]

        out_momentum_eta = X[:, :, 2] + pred_corr[:, :, 0]
        out_momentum_phi = X[:, :, 3] + pred_corr[:, :, 1]
        out_momentum_E = X[:, :, 4] + pred_corr[:, :, 2]

        out_momentum = tf.stack([
            out_momentum_eta * probabilistic_mask_good,
            out_momentum_phi * probabilistic_mask_good,
            out_momentum_E * probabilistic_mask_good,
        ], axis=-1)

        ret = tf.concat([out_id_logits, out_momentum, out_charge], axis=-1)*msk_input
        return ret

    def set_trainable_classification(self):
        self.layer_momentum.trainable = False

    def set_trainable_regression(self):
        pass

    def create_model(self, num_max_elems, training=True):
        inputs = tf.keras.Input(shape=(num_max_elems, 15,))
        return tf.keras.Model(inputs=[inputs], outputs=self.call(inputs, training), name="MLPFNet")

def separate_prediction(y_pred):
    N = len(class_labels)
    pred_id_logits = y_pred[:, :, :N]
    pred_momentum = y_pred[:, :, N:N+3]
    pred_charge = y_pred[:, :, N+3:N+4]
    return pred_id_logits, pred_charge, pred_momentum

def separate_truth(y_true):
    true_id = tf.cast(y_true[:, :, :1], tf.int32)
    true_momentum = y_true[:, :, 1:4]
    true_charge = y_true[:, :, 4:5]
    return true_id, true_charge, true_momentum

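#Layout of the flattened tensors handled by the two helpers above, with
#N = len(class_labels) = 8 (a summary of the slicing, not original code):
#  y_pred[:, :, 0:N]     per-class ID logits
#  y_pred[:, :, N:N+3]   regressed (eta, phi, energy)
#  y_pred[:, :, N+3:N+4] charge
#  y_true[:, :, 0]       true class index, [:, :, 1:4] momentum, [:, :, 4:5] charge
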
def mse_unreduced(true, pred):
    return tf.math.pow(true-pred, 2)

def msle_unreduced(true, pred):
    return tf.math.pow(tf.math.log(tf.math.abs(true) + 1.0) - tf.math.log(tf.math.abs(pred) + 1.0), 2)

def my_loss_cls(y_true, y_pred):
    pred_id_logits, pred_charge, _ = separate_prediction(y_pred)
    true_id, true_charge, _ = separate_truth(y_true)

    true_id_onehot = tf.one_hot(tf.cast(true_id, tf.int32), depth=len(class_labels))
    #predict the particle class labels
    l1 = mult_classification_loss*tf.nn.softmax_cross_entropy_with_logits(true_id_onehot, pred_id_logits)
    l3 = mult_charge_loss*mse_unreduced(true_charge, pred_charge)[:, :, 0]

    loss = l1 + l3
    return mult_total_loss*loss

def my_loss_reg(y_true, y_pred):
    _, _, pred_momentum = separate_prediction(y_pred)
    _, true_charge, true_momentum = separate_truth(y_true)

    l2_0 = mult_eta_loss*mse_unreduced(true_momentum[:, :, 0], pred_momentum[:, :, 0])
    l2_1 = mult_phi_loss*mse_unreduced(tf.math.floormod(true_momentum[:, :, 1] - pred_momentum[:, :, 1] + np.pi, 2*np.pi) - np.pi, 0.0)
    l2_2 = mult_energy_loss*mse_unreduced(true_momentum[:, :, 2], pred_momentum[:, :, 2])

    loss = (l2_0 + l2_1 + l2_2)

    return mult_total_loss*loss

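#The floormod term above wraps the phi residual into [-pi, pi) before
#squaring, so that angles near the +/-pi boundary are not over-penalized;
#e.g. (illustrative numbers):
#  true_phi, pred_phi = 3.1, -3.1
#  tf.math.floormod(true_phi - pred_phi + np.pi, 2*np.pi) - np.pi  # ~ -0.083, not 6.2
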
def my_loss_full(y_true, y_pred):
    pred_id_logits, pred_charge, pred_momentum = separate_prediction(y_pred)
    pred_id = tf.cast(tf.argmax(pred_id_logits, axis=-1), tf.int32)
    true_id, true_charge, true_momentum = separate_truth(y_true)
    true_id_onehot = tf.one_hot(tf.cast(true_id, tf.int32), depth=len(class_labels))

    l1 = mult_classification_loss*tf.nn.softmax_cross_entropy_with_logits(true_id_onehot, pred_id_logits)

    l2_0 = mult_eta_loss*mse_unreduced(true_momentum[:, :, 0], pred_momentum[:, :, 0])
    l2_1 = mult_phi_loss*mse_unreduced(tf.math.floormod(true_momentum[:, :, 1] - pred_momentum[:, :, 1] + np.pi, 2*np.pi) - np.pi, 0.0)
    l2_2 = mult_energy_loss*mse_unreduced(true_momentum[:, :, 2], pred_momentum[:, :, 2])

    l2 = (l2_0 + l2_1 + l2_2)

    l3 = mult_charge_loss*mse_unreduced(true_charge, pred_charge)[:, :, 0]
    loss = l1 + l2 + l3

    return mult_total_loss*loss

#TODO: put these in a class
def cls_130(y_true, y_pred):
    pred_id_onehot, pred_charge, pred_momentum = separate_prediction(y_pred)
    pred_id = tf.cast(tf.argmax(pred_id_onehot, axis=-1), tf.int32)
    true_id, true_charge, true_momentum = separate_truth(y_true)

    msk_true = true_id[:, :, 0] == class_labels.index(130)
    msk_pos = pred_id == class_labels.index(130)
    num_true_pos = tf.reduce_sum(tf.cast(msk_true&msk_pos, tf.float32))
    num_true = tf.reduce_sum(tf.cast(msk_true, tf.float32))

    return num_true_pos/num_true

def cls_211(y_true, y_pred):
    pred_id_onehot, pred_charge, pred_momentum = separate_prediction(y_pred)
    pred_id = tf.cast(tf.argmax(pred_id_onehot, axis=-1), tf.int32)
    true_id, true_charge, true_momentum = separate_truth(y_true)

    msk_true = true_id[:, :, 0] == class_labels.index(211)
    msk_pos = pred_id == class_labels.index(211)
    num_true_pos = tf.reduce_sum(tf.cast(msk_true&msk_pos, tf.float32))
    num_true = tf.reduce_sum(tf.cast(msk_true, tf.float32))

    return num_true_pos/num_true

def cls_22(y_true, y_pred):
    pred_id_onehot, pred_charge, pred_momentum = separate_prediction(y_pred)
    pred_id = tf.cast(tf.argmax(pred_id_onehot, axis=-1), tf.int32)
    true_id, true_charge, true_momentum = separate_truth(y_true)

    msk_true = true_id[:, :, 0] == class_labels.index(22)
    msk_pos = pred_id == class_labels.index(22)
    num_true_pos = tf.reduce_sum(tf.cast(msk_true&msk_pos, tf.float32))
    num_true = tf.reduce_sum(tf.cast(msk_true, tf.float32))

    return num_true_pos/num_true

def cls_11(y_true, y_pred):
    pred_id_onehot, pred_charge, pred_momentum = separate_prediction(y_pred)
    pred_id = tf.cast(tf.argmax(pred_id_onehot, axis=-1), tf.int32)
    true_id, true_charge, true_momentum = separate_truth(y_true)

    msk_true = true_id[:, :, 0] == class_labels.index(11)
    msk_pos = pred_id == class_labels.index(11)
    num_true_pos = tf.reduce_sum(tf.cast(msk_true&msk_pos, tf.float32))
    num_true = tf.reduce_sum(tf.cast(msk_true, tf.float32))

    return num_true_pos/num_true

def cls_13(y_true, y_pred):
    pred_id_onehot, pred_charge, pred_momentum = separate_prediction(y_pred)
    pred_id = tf.cast(tf.argmax(pred_id_onehot, axis=-1), tf.int32)
    true_id, true_charge, true_momentum = separate_truth(y_true)

    msk_true = true_id[:, :, 0] == class_labels.index(13)
    msk_pos = pred_id == class_labels.index(13)
    num_true_pos = tf.reduce_sum(tf.cast(msk_true&msk_pos, tf.float32))
    num_true = tf.reduce_sum(tf.cast(msk_true, tf.float32))

    return num_true_pos/num_true

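#A hedged sketch of the TODO above: the five cls_* metrics differ only in the
#PDG id, so they could be produced by a single factory; make_cls_recall is an
#illustrative name, not part of the original code.
def make_cls_recall(pdgid):
    def metric(y_true, y_pred):
        pred_id_onehot, _, _ = separate_prediction(y_pred)
        pred_id = tf.cast(tf.argmax(pred_id_onehot, axis=-1), tf.int32)
        true_id, _, _ = separate_truth(y_true)

        msk_true = true_id[:, :, 0] == class_labels.index(pdgid)
        msk_pos = pred_id == class_labels.index(pdgid)
        num_true_pos = tf.reduce_sum(tf.cast(msk_true&msk_pos, tf.float32))
        num_true = tf.reduce_sum(tf.cast(msk_true, tf.float32))
        return num_true_pos/num_true
    #give each metric a distinct name so it is displayed separately by keras
    metric.__name__ = "cls_{}".format(pdgid)
    return metric
#e.g. cls_130 above is equivalent to make_cls_recall(130)
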
def num_pred(y_true, y_pred):
    pred_id_onehot, pred_charge, pred_momentum = separate_prediction(y_pred)
    pred_id = tf.cast(tf.argmax(pred_id_onehot, axis=-1), tf.int32)
    true_id, true_charge, true_momentum = separate_truth(y_true)

    ntrue = tf.reduce_sum(tf.cast(true_id[:, :, 0]!=0, tf.int32))
    npred = tf.reduce_sum(tf.cast(pred_id!=0, tf.int32))
    return tf.cast(ntrue - npred, tf.float32)

def accuracy(y_true, y_pred):
    pred_id_onehot, pred_charge, pred_momentum = separate_prediction(y_pred)
    pred_id = tf.cast(tf.argmax(pred_id_onehot, axis=-1), tf.int32)
    true_id, true_charge, true_momentum = separate_truth(y_true)

    is_true = true_id[:, :, 0]!=0
    is_same = true_id[:, :, 0] == pred_id

    acc = tf.reduce_sum(tf.cast(is_true&is_same, tf.int32)) / tf.reduce_sum(tf.cast(is_true, tf.int32))
    return tf.cast(acc, tf.float32)

def eta_resolution(y_true, y_pred):
    pred_id_onehot, pred_charge, pred_momentum = separate_prediction(y_pred)
    pred_id = tf.cast(tf.argmax(pred_id_onehot, axis=-1), tf.int32)
    true_id, true_charge, true_momentum = separate_truth(y_true)

    msk = true_id[:, :, 0]!=0
    return tf.reduce_mean(mse_unreduced(true_momentum[msk][:, 0], pred_momentum[msk][:, 0]))

def phi_resolution(y_true, y_pred):
    pred_id_onehot, pred_charge, pred_momentum = separate_prediction(y_pred)
    pred_id = tf.cast(tf.argmax(pred_id_onehot, axis=-1), tf.int32)
    true_id, true_charge, true_momentum = separate_truth(y_true)

    msk = true_id[:, :, 0]!=0
    return tf.reduce_mean(mse_unreduced(tf.math.floormod(true_momentum[msk][:, 1] - pred_momentum[msk][:, 1] + np.pi, 2*np.pi) - np.pi, 0.0))

def energy_resolution(y_true, y_pred):
    pred_id_onehot, pred_charge, pred_momentum = separate_prediction(y_pred)
    pred_id = tf.cast(tf.argmax(pred_id_onehot, axis=-1), tf.int32)
    true_id, true_charge, true_momentum = separate_truth(y_true)

    msk = true_id[:, :, 0]!=0
    return tf.reduce_mean(mse_unreduced(true_momentum[msk][:, 2], pred_momentum[msk][:, 2]))

def get_unique_run():
    previous_runs = os.listdir('experiments')
    if len(previous_runs) == 0:
        run_number = 1
    else:
        run_number = max([int(s.split('run_')[1]) for s in previous_runs]) + 1
    return run_number

def parse_args():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="PFNet", help="type of model to train", choices=["PFNet"])
    parser.add_argument("--ntrain", type=int, default=100, help="number of training events")
    parser.add_argument("--ntest", type=int, default=100, help="number of testing events")
    parser.add_argument("--nepochs", type=int, default=100, help="number of training epochs")
    parser.add_argument("--hidden-dim-id", type=int, default=256, help="hidden dimension of the multiclass branch")
    parser.add_argument("--hidden-dim-reg", type=int, default=256, help="hidden dimension of the regression branch")
    parser.add_argument("--batch-size", type=int, default=1, help="number of events in training batch")
    parser.add_argument("--num-convs-id", type=int, default=3, help="number of convolution layers for multiclass")
    parser.add_argument("--num-convs-reg", type=int, default=3, help="number of convolution layers for regression")
    parser.add_argument("--num-hidden-id-enc", type=int, default=2, help="number of encoder layers for multiclass")
    parser.add_argument("--num-hidden-id-dec", type=int, default=2, help="number of decoder layers for multiclass")
    parser.add_argument("--num-hidden-reg-enc", type=int, default=2, help="number of encoder layers for regression")
    parser.add_argument("--num-hidden-reg-dec", type=int, default=2, help="number of decoder layers for regression")
    parser.add_argument("--num-neighbors", type=int, default=5, help="number of knn neighbors")
    parser.add_argument("--distance-dim", type=int, default=256, help="distance dimension")
    parser.add_argument("--bin-size", type=int, default=100, help="number of points per LSH bin")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout rate")
    parser.add_argument("--dist-mult", type=float, default=1.0, help="Multiplier in the exponential distance falloff")
    parser.add_argument("--target", type=str, choices=["cand", "gen"], help="Regress to PFCandidates or GenParticles", default="cand")
    parser.add_argument("--weights", type=str, choices=["uniform", "inverse", "classbalanced"], help="Sample weighting scheme to use", default="inverse")
    parser.add_argument("--name", type=str, default=None, help="where to store the output")
    parser.add_argument("--convlayer", type=str, default="ghconv", choices=["ghconv"], help="Type of graph convolutional layer")
    parser.add_argument("--load", type=str, default=None, help="model to load")
    parser.add_argument("--datapath", type=str, help="Input data path", required=True)
    parser.add_argument("--lr", type=float, default=1e-5, help="learning rate")
    parser.add_argument("--lr-decay", type=float, default=0.0, help="learning rate decay")
    parser.add_argument("--train-cls", action="store_true", help="Train only the classification part")
    parser.add_argument("--train-reg", action="store_true", help="Train only the regression part")
    parser.add_argument("--cosine-dist", action="store_true", help="Use cosine distance")
    parser.add_argument("--eager", action="store_true", help="Run in eager mode for debugging")
    args = parser.parse_args()
    return args

def assign_label(pred_id_onehot_linear):
    ret2 = np.argmax(pred_id_onehot_linear, axis=-1)
    return ret2

def prepare_df(model, data, outdir, target, save_raw=False):
    print("prepare_df")

    dfs = []
    for iev, d in enumerate(data):
        if iev%50==0:
            tf.print(".", end="")
        X, y, w = d
        pred = model(X, training=False).numpy()
        pred_id_onehot, pred_charge, pred_momentum = separate_prediction(pred)
        pred_id = assign_label(pred_id_onehot).flatten()

        if save_raw:
            np.savez_compressed("ev_{}.npz".format(iev), X=X.numpy(), y=y.numpy(), w=w.numpy(), y_pred=pred)

        pred_charge = pred_charge[:, :, 0].flatten()
        pred_momentum = pred_momentum.reshape((pred_momentum.shape[0]*pred_momentum.shape[1], pred_momentum.shape[2]))

        true_id, true_charge, true_momentum = separate_truth(y)
        true_id = true_id.numpy()[:, :, 0].flatten()
        true_charge = true_charge.numpy()[:, :, 0].flatten()
        true_momentum = true_momentum.numpy().reshape((true_momentum.shape[0]*true_momentum.shape[1], true_momentum.shape[2]))

        df = pandas.DataFrame()
        df["pred_pid"] = np.array([int(class_labels[p]) for p in pred_id])
        df["pred_eta"] = np.array(pred_momentum[:, 0], dtype=np.float64)
        df["pred_phi"] = np.array(pred_momentum[:, 1], dtype=np.float64)
        df["pred_e"] = np.array(pred_momentum[:, 2], dtype=np.float64)

        df["{}_pid".format(target)] = np.array([int(class_labels[p]) for p in true_id])
        df["{}_eta".format(target)] = np.array(true_momentum[:, 0], dtype=np.float64)
        df["{}_phi".format(target)] = np.array(true_momentum[:, 1], dtype=np.float64)
        df["{}_e".format(target)] = np.array(true_momentum[:, 2], dtype=np.float64)

        df["iev"] = iev
        dfs += [df]
    df = pandas.concat(dfs, ignore_index=True)
    fn = outdir + "/df.pkl.bz2"
    df.to_pickle(fn)
    print("prepare_df done", fn)

def plot_to_image(figure):
    """Converts the matplotlib plot specified by 'figure' to a PNG image and
    returns it. The supplied figure is closed and inaccessible after this call."""
    # Save the plot to a PNG in memory.
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    # Closing the figure prevents it from being displayed directly inside
    # the notebook (the original code omitted this call, despite the
    # docstring and the comment above promising it).
    plt.close(figure)
    buf.seek(0)
    # Convert PNG buffer to TF image
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    # Add the batch dimension
    image = tf.expand_dims(image, 0)
    return image

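#Example (an illustrative sketch; the writer path and tag are hypothetical):
#  file_writer = tf.summary.create_file_writer(outdir + "/plots")
#  fig = plt.figure()
#  plt.plot([0, 1], [0, 1])
#  with file_writer.as_default():
#      tf.summary.image("example", plot_to_image(fig), step=0)
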
def load_dataset_ttbar(datapath, target):
    from tf_data import _parse_tfr_element
    path = datapath + "/tfr/{}/*.tfrecords".format(target)
    tfr_files = glob.glob(path)
    if len(tfr_files) == 0:
        raise Exception("Could not find any files in {}".format(path))
    dataset = tf.data.TFRecordDataset(tfr_files).map(_parse_tfr_element, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset

if __name__ == "__main__":
    args = parse_args()
    print(args)

    if comet_enabled:
        experiment = Experiment(project_name="particleflow_tf")

    #tf.debugging.enable_check_numerics()
    tf.config.experimental_run_functions_eagerly(args.eager)

    #batch size for loading data must be configured according to the number of distributed GPUs
    global_batch_size = args.batch_size
    try:
        num_gpus = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
        print("num_gpus=", num_gpus)
        if num_gpus > 1:
            strategy = tf.distribute.MirroredStrategy()
            global_batch_size = num_gpus * args.batch_size
        else:
            strategy = tf.distribute.OneDeviceStrategy("gpu:0")
    except Exception:
        print("fallback to CPU")
        strategy = tf.distribute.OneDeviceStrategy("cpu")

    filelist = sorted(glob.glob(args.datapath + "/raw/*.pkl"))[:args.ntrain+args.ntest]

    dataset = load_dataset_ttbar(args.datapath, args.target)

    #create padded input data
    ps = (tf.TensorShape([num_max_elems, 15]), tf.TensorShape([num_max_elems, 5]), tf.TensorShape([num_max_elems, ]))
    ds_train = dataset.take(args.ntrain).map(weight_schemes[args.weights]).padded_batch(global_batch_size, padded_shapes=ps)
    ds_test = dataset.skip(args.ntrain).take(args.ntest).map(weight_schemes[args.weights]).padded_batch(global_batch_size, padded_shapes=ps)

    #repeat needed for keras api
    ds_train_r = ds_train.repeat(args.nepochs)
    ds_test_r = ds_test.repeat(args.nepochs)

    if args.lr_decay > 0:
        lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            args.lr,
            decay_steps=10*int(args.ntrain/global_batch_size),
            decay_rate=args.lr_decay
        )
    else:
        lr_schedule = args.lr

    loss_fn = my_loss_full

    with strategy.scope():
        opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

        model = PFNet(
            hidden_dim_id=args.hidden_dim_id,
            hidden_dim_reg=args.hidden_dim_reg,
            num_convs_id=args.num_convs_id,
            num_convs_reg=args.num_convs_reg,
            num_hidden_id_enc=args.num_hidden_id_enc,
            num_hidden_id_dec=args.num_hidden_id_dec,
            num_hidden_reg_enc=args.num_hidden_reg_enc,
            num_hidden_reg_dec=args.num_hidden_reg_dec,
            distance_dim=args.distance_dim,
            convlayer=args.convlayer,
            dropout=args.dropout,
            bin_size=args.bin_size,
            num_neighbors=args.num_neighbors,
            dist_mult=args.dist_mult
        )

        if args.train_cls:
            loss_fn = my_loss_cls
            model.set_trainable_classification()
        elif args.train_reg:
            loss_fn = my_loss_reg
            model.set_trainable_regression()

        model(np.random.randn(args.batch_size, num_max_elems, 15).astype(np.float32))
        if not args.eager:
            model = model.create_model(num_max_elems)
            model.summary()

    if not os.path.isdir("experiments"):
        os.makedirs("experiments")

    if args.name is None:
        args.name = 'run_{:02}'.format(get_unique_run())

    outdir = 'experiments/' + args.name

    if os.path.isdir(outdir):
        print("Output directory exists: {}".format(outdir), file=sys.stderr)
        sys.exit(1)

    print(outdir)
    callbacks = []
    tb = tf.keras.callbacks.TensorBoard(
        log_dir=outdir, histogram_freq=0, write_graph=False, write_images=False,
        update_freq='epoch',
        #profile_batch=(10,40),
        profile_batch=0,
    )
    tb.set_model(model)
    callbacks += [tb]

    terminate_cb = tf.keras.callbacks.TerminateOnNaN()
    callbacks += [terminate_cb]

    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=outdir + "/weights.{epoch:02d}-{val_loss:.6f}.hdf5",
        save_weights_only=True,
        verbose=0
    )
    cp_callback.set_model(model)
    callbacks += [cp_callback]

    with strategy.scope():
        model.compile(optimizer=opt, loss=loss_fn,
            metrics=[accuracy, cls_130, cls_211, cls_22, energy_resolution, eta_resolution, phi_resolution],
            sample_weight_mode="temporal")

        if args.load:
            #ensure model input size is known
            for X, y, w in ds_train:
                model(X)
                break

            model.load_weights(args.load)

        if args.nepochs > 0:
            ret = model.fit(ds_train_r,
                validation_data=ds_test_r, epochs=args.nepochs,
                steps_per_epoch=args.ntrain/global_batch_size, validation_steps=args.ntest/global_batch_size,
                verbose=True,
                callbacks=callbacks
            )