Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:27:33

0001 import numpy as np
0002 import glob
0003 import multiprocessing
0004 import os
0005 
0006 import tensorflow as tf
0007 from tf_model import load_one_file
0008 
def parse_args():
    """Parse the command-line options of the TFRecord conversion script.

    Returns:
        argparse.Namespace with the attributes ``target`` ("cand" or "gen"),
        ``datapath`` (required) and ``num_files_per_tfr`` (int, default 100).
    """
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--target",
        type=str,
        choices=["cand", "gen"],
        default="cand",
        help="Regress to PFCandidates or GenParticles",
    )
    cli.add_argument(
        "--datapath",
        type=str,
        required=True,
        help="Input data path",
    )
    cli.add_argument(
        "--num-files-per-tfr",
        type=int,
        default=100,
        help="Number of pickle files to merge to one TFRecord file",
    )
    return cli.parse_args()
0017 
def chunks(lst, n):
    """Lazily split ``lst`` into consecutive slices of at most ``n`` items.

    The final slice is shorter when ``len(lst)`` is not a multiple of ``n``.
    """
    total = len(lst)
    yield from (lst[start:start + n] for start in range(0, total, n))
0022 
0023 #https://stackoverflow.com/questions/47861084/how-to-store-numpy-arrays-as-tfrecord
def _bytes_feature(value):
    """Wrap a byte string (or a tensor holding one) in a ``tf.train.Feature``."""
    # The concrete tensor class is obtained from a throwaway constant.
    tensor_type = type(tf.constant(0))
    if isinstance(value, tensor_type):
        # BytesList needs the raw value, not the tensor wrapper.
        value = value.numpy()
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)
0029 
def _parse_tfr_element(element):
    """Decode one serialized example into the ``(X, y, w)`` float32 tensors.

    Each of the 'X', 'y' and 'w' features is stored as a scalar string that
    contains a serialized tensor; it is parsed back to float32 here.

    NOTE: the 'dm_row' / 'dm_col' / 'dm_data' features (a sparse matrix in
    the original, commented-out code) are not read by this function.

    Returns:
        Tuple ``(arr_X, arr_y, arr_w)`` with static shapes ``(None, 15)``,
        ``(None, 5)`` and ``(None,)`` respectively.
    """
    feature_spec = {
        key: tf.io.FixedLenFeature([], tf.string) for key in ('X', 'y', 'w')
    }
    parsed = tf.io.parse_single_example(element, feature_spec)

    arr_X = tf.io.parse_tensor(parsed['X'], out_type=tf.float32)
    arr_y = tf.io.parse_tensor(parsed['y'], out_type=tf.float32)
    arr_w = tf.io.parse_tensor(parsed['w'], out_type=tf.float32)

    # Static shapes are attached explicitly after parsing, see
    # https://github.com/tensorflow/tensorflow/issues/24520#issuecomment-577325475
    arr_X.set_shape(tf.TensorShape((None, 15)))
    arr_y.set_shape(tf.TensorShape((None, 5)))
    arr_w.set_shape(tf.TensorShape((None, )))

    return arr_X, arr_y, arr_w
0063 
def serialize_X_y_w(writer, X, y, w):
    """Write one ``(X, y, w)`` triple to ``writer`` as a ``tf.train.Example``.

    Every array is serialized with ``tf.io.serialize_tensor`` and stored as a
    bytes feature under the keys 'X', 'y' and 'w'.
    """
    named_arrays = (('X', X), ('y', y), ('w', w))
    feature = {
        key: _bytes_feature(tf.io.serialize_tensor(array))
        for key, array in named_arrays
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    writer.write(example.SerializeToString())
0075 
def _class_count_weights(ys):
    """Return one float32 weight vector per target array in ``ys``.

    Each sample's weight is the number of samples, over all arrays in ``ys``,
    that share its value in column 0 (the sample "type").
    """
    uniq_vals, uniq_counts = np.unique(
        np.concatenate([y[:, 0] for y in ys]), return_counts=True)
    weights = []
    for y in ys:
        types = y[:, 0]  # hoisted: identical for every unique value below
        w = np.ones(len(y), dtype=np.float32)
        for uv, uc in zip(uniq_vals, uniq_counts):
            w[types == uv] = uc
        weights.append(w)
    return weights


def serialize_chunk(args):
    """Merge a group of input files into one TFRecord file.

    Args:
        args: tuple ``(path, files, ichunk, target)`` where ``path`` is the
            output directory, ``files`` the input files to merge, ``ichunk``
            the chunk index used in the output file name
            (``chunk_<ichunk>.tfrecords``) and ``target`` is "cand" or "gen",
            selecting which of the two target collections returned by
            ``load_one_file`` is stored as ``y``.

    Raises:
        ValueError: if ``target`` is neither "cand" nor "gen".
    """
    path, files, ichunk, target = args
    print(path, len(files), ichunk, target)
    if target not in ("cand", "gen"):
        raise ValueError("Unknown target: {}".format(target))
    out_filename = os.path.join(path, "chunk_{}.tfrecords".format(ichunk))

    Xs = []
    ys = []
    for fi in files:
        print(fi)
        X, y, ycand = load_one_file(fi)
        Xs.extend(X)
        ys.extend(ycand if target == "cand" else y)

    #set weights for each sample to be equal to the number of samples of this type
    #in the training script, this can be used to compute either inverse or class-balanced weights
    ws = _class_count_weights(ys)

    # The writer is opened only once all data is ready and is closed by the
    # context manager even if serialization fails part-way.
    with tf.io.TFRecordWriter(out_filename) as writer:
        for X, y, w in zip(Xs, ys, ws):
            serialize_X_y_w(writer, X, y, w)
0112 
if __name__ == "__main__":
    # Script entry point: convert <datapath>/raw/*.pkl into TFRecord chunks
    # under <datapath>/tfr/<target>, then read them back as a sanity check.
    args = parse_args()

    datapath = args.datapath

    # Sorted so that the assignment of files to chunks is reproducible.
    filelist = sorted(glob.glob("{}/raw/*.pkl".format(datapath)))
    print("found {} files".format(len(filelist)))
    # Explicit checks instead of assert: asserts are stripped under python -O.
    if not filelist:
        raise FileNotFoundError(
            "no input files matching {}/raw/*.pkl".format(datapath))
    outpath = "{}/tfr/{}".format(datapath, args.target)

    os.makedirs(outpath, exist_ok=True)

    pars = [
        (outpath, files, ichunk, args.target)
        for ichunk, files in enumerate(chunks(filelist, args.num_files_per_tfr))
    ]
    if not pars:
        raise ValueError(
            "--num-files-per-tfr={} produced no chunks".format(args.num_files_per_tfr))
    # Chunks are processed one after another in this process.
    for par in pars:
        serialize_chunk(par)

    #Load and test the dataset
    tfr_dataset = tf.data.TFRecordDataset(glob.glob(outpath + "/*.tfrecords"))
    dataset = tfr_dataset.map(_parse_tfr_element)
    num_ev = 0
    num_particles = 0
    for X, _y, _w in dataset:
        num_ev += 1
        num_particles += len(X)
    if num_ev == 0:
        raise RuntimeError("no events found in TFRecords under {}".format(outpath))
    print("Created TFRecords dataset in {} with {} events, {} particles".format(
        datapath, num_ev, num_particles))