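"""Convert pickled input event files into TFRecord files for TensorFlow training.

Pickle files under <datapath>/raw are read with tf_model.load_one_file, grouped
into chunks of --num-files-per-tfr files, and written as one .tfrecords file per
chunk under <datapath>/tfr/<target>, with per-particle weights derived from the
class frequencies in each chunk.
"""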
import numpy as np
import glob
import multiprocessing
import os

import tensorflow as tf
from tf_model import load_one_file

def parse_args():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--target", type=str, choices=["cand", "gen"], help="Regress to PFCandidates or GenParticles", default="cand")
    parser.add_argument("--datapath", type=str, required=True, help="Input data path")
    parser.add_argument("--num-files-per-tfr", type=int, default=100, help="Number of pickle files to merge into one TFRecord file")
    args = parser.parse_args()
    return args

def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]


def _bytes_feature(value):
    """Returns a bytes_list feature from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        # BytesList won't unpack a string from an EagerTensor, so convert it first
        value = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _parse_tfr_element(element):
    """Decode one serialized tf.train.Example back into (X, y, w) tensors."""
    parse_dic = {
        'X': tf.io.FixedLenFeature([], tf.string),
        'y': tf.io.FixedLenFeature([], tf.string),
        'w': tf.io.FixedLenFeature([], tf.string),
    }
    example_message = tf.io.parse_single_example(element, parse_dic)

    arr_X = tf.io.parse_tensor(example_message['X'], out_type=tf.float32)
    arr_y = tf.io.parse_tensor(example_message['y'], out_type=tf.float32)
    arr_w = tf.io.parse_tensor(example_message['w'], out_type=tf.float32)

    # the number of particles per event is variable, only the feature dimensions are fixed
    arr_X.set_shape(tf.TensorShape((None, 15)))
    arr_y.set_shape(tf.TensorShape((None, 5)))
    arr_w.set_shape(tf.TensorShape((None, )))

    return arr_X, arr_y, arr_w

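# Sketch of how the parsed dataset could be consumed (illustrative, not used in
# this script): events have a variable number of particles, so batching across
# events needs padding, e.g.
#   ds = tf.data.TFRecordDataset(files).map(_parse_tfr_element)
#   ds = ds.padded_batch(4, padded_shapes=([None, 15], [None, 5], [None]))
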
def serialize_X_y_w(writer, X, y, w):
    """Write one event as a tf.train.Example with serialized X, y, w tensors."""
    feature = {
        'X': _bytes_feature(tf.io.serialize_tensor(X)),
        'y': _bytes_feature(tf.io.serialize_tensor(y)),
        'w': _bytes_feature(tf.io.serialize_tensor(w)),
    }
    sample = tf.train.Example(features=tf.train.Features(feature=feature))
    writer.write(sample.SerializeToString())

def serialize_chunk(args):
    """Serialize a list of input files into a single TFRecord file."""
    path, files, ichunk, target = args
    print(path, len(files), ichunk, target)
    out_filename = os.path.join(path, "chunk_{}.tfrecords".format(ichunk))
    writer = tf.io.TFRecordWriter(out_filename)
    Xs = []
    ys = []
    ws = []

    for fi in files:
        print(fi)
        X, y, ycand = load_one_file(fi)

        Xs += X
        if target == "cand":
            ys += ycand
        elif target == "gen":
            ys += y
        else:
            raise ValueError("Unknown target: {}".format(target))

    # each particle's weight is set to the number of occurrences of its class
    # (first target column) across the whole chunk
    uniq_vals, uniq_counts = np.unique(np.concatenate([y[:, 0] for y in ys]), return_counts=True)
    for i in range(len(ys)):
        w = np.ones(len(ys[i]), dtype=np.float32)
        for uv, uc in zip(uniq_vals, uniq_counts):
            w[ys[i][:, 0] == uv] = uc
        ws += [w]

    for X, y, w in zip(Xs, ys, ws):
        serialize_X_y_w(writer, X, y, w)

    writer.close()

if __name__ == "__main__":
    args = parse_args()

    datapath = args.datapath

    filelist = sorted(glob.glob("{}/raw/*.pkl".format(datapath)))
    print("found {} files".format(len(filelist)))
    assert len(filelist) > 0

    outpath = "{}/tfr/{}".format(datapath, args.target)

    if not os.path.isdir(outpath):
        os.makedirs(outpath)

    pars = []
    for ichunk, files in enumerate(chunks(filelist, args.num_files_per_tfr)):
        pars += [(outpath, files, ichunk, args.target)]
    assert len(pars) > 0

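    # serialize_chunk takes a single tuple argument, so the chunks could also be
    # processed in parallel rather than in the loop below, e.g. (a sketch, pool
    # size chosen arbitrarily):
    #   with multiprocessing.Pool(4) as pool:
    #       pool.map(serialize_chunk, pars)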
    for par in pars:
        serialize_chunk(par)

    # sanity check: read the dataset back and count events and particles
    tfr_dataset = tf.data.TFRecordDataset(glob.glob(outpath + "/*.tfrecords"))
    dataset = tfr_dataset.map(_parse_tfr_element)
    num_ev = 0
    num_particles = 0
    for X, y, w in dataset:
        num_ev += 1
        num_particles += len(X)
    assert num_ev > 0
    print("Created TFRecords dataset in {} with {} events, {} particles".format(
        datapath, num_ev, num_particles))
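
# Example invocation (script name and dataset directory are illustrative):
#   python tf_data.py --datapath ./data/mydataset --target cand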