Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2021-02-14 23:30:56

0001 import sys
0002 import pickle
0003 import networkx as nx
0004 import numpy as np
0005 #import numba
0006 import os
0007 import uproot
0008 import uproot_methods
0009 import math
0010 
0011 import matplotlib
0012 matplotlib.use("Agg")
0013 import matplotlib.pyplot as plt
0014 
0015 import scipy
0016 import scipy.sparse
0017 from networkx.readwrite import json_graph
0018 from networkx.drawing.nx_pydot import graphviz_layout
0019 
# Target class ID -> list of generator-level PDG IDs that are folded into
# that class when building the regression/classification targets.
map_candid_to_pdgid = {
    0: [0],
    211: [211, 2212, 321, -3112, 3222, -3312, -3334],
    -211: [-211, -2212, -321, 3112, -3222, 3312, 3334],
    130: [111, 130, 2112, -2112, 310, 3122, -3122, 3322, -3322],
    22: [22],
    11: [11],
    -11: [-11],
    13: [13],
    -13: [-13],
}

# Inverse lookup: generator-level PDG ID -> target class ID.
map_pdgid_to_candid = {
    pdgid: class_id
    for class_id, members in map_candid_to_pdgid.items()
    for pdgid in members
}
0037 
0038 #@numba.njit
def get_charge(pid):
    """Return the electric charge (in units of e) for a particle class ID.

    Args:
        pid: signed particle ID. Electrons (11), muons (13) and charged
            pions (211) are recognised as charged.

    Returns:
        -1.0 or +1.0 for the charged IDs above (sign taken from ``pid``),
        and 0.0 for everything else, including 130, 22, 1, 2 and the
        "no particle" label 0.
    """
    abs_pid = abs(pid)
    # 13: mu-, 11: e-  -> a positive ID carries negative charge
    if abs_pid == 13 or abs_pid == 11:
        return -math.copysign(1.0, pid)
    # 211: pi+  -> the sign of the ID is the sign of the charge
    if abs_pid == 211:
        return math.copysign(1.0, pid)
    # Neutral or unknown: return an explicit 0.0 instead of falling off the
    # end of the function and handing None to float-typed consumers.
    return 0.0
0049 
def save_ego_graph(g, node, radius=4, undirected=False):
    """Draw the neighbourhood of ``node`` as a tree and return the figure.

    Args:
        g: directed graph whose nodes are ``(kind, index)`` tuples carrying
            the attributes used in the labels below (typ, e, pt, eta, phi,
            children, parents).
        node: centre node of the ego graph.
        radius: maximum distance from ``node`` to include.
        undirected: forwarded to ``nx.ego_graph``.

    Returns:
        The matplotlib figure that was drawn on. Nothing is written to disk
        here; the caller is expected to save and close the figure.
    """
    sg = nx.ego_graph(g, node, radius, undirected=undirected).reverse()

    # remove BREM PFElements from plotting
    brem_nodes = [n for n in sg.nodes if n[0] == "elem" and sg.nodes[n]["typ"] in [7]]
    sg.remove_nodes_from(brem_nodes)

    # the figure width grows with the number of nodes so labels do not overlap
    fig = plt.figure(figsize=(2*len(sg.nodes)+2, 10))
    sg_pos = graphviz_layout(sg, prog='dot')

    # label edges that end on an element, except for element types 1 and 10
    edge_labels = {}
    for e in sg.edges:
        target = e[1]
        if target[0] == "elem" and sg.nodes[target]["typ"] not in [1, 10]:
            edge_labels[e] = "{:.2f} GeV".format(sg.edges[e].get("weight", 0))
        else:
            edge_labels[e] = ""

    # human-readable names for the node kinds, built once outside the loop
    kind_names = {"sc": "SimCluster", "elem": "PFElement", "tp": "TrackingParticle", "pfcand": "PFCandidate"}
    node_labels = {}
    # use a distinct loop variable so the ``node`` argument is not shadowed
    for n in sg.nodes:
        node_labels[n] = "[{label} {idx}] \ntype: {typ}\ne: {e:.4f} GeV\npt: {pt:.4f} GeV\neta: {eta:.4f}\nphi: {phi:.4f}\nc/p: {children}/{parents}".format(
            label=kind_names[n[0]], idx=n[1], **sg.nodes[n])

    nx.draw_networkx(sg, pos=sg_pos, node_shape=".", node_color="grey", edge_color="grey", node_size=0, alpha=0.5, labels={})
    nx.draw_networkx_labels(sg, pos=sg_pos, labels=node_labels)
    nx.draw_networkx_edge_labels(sg, pos=sg_pos, edge_labels=edge_labels)
    plt.tight_layout()
    plt.axis("off")

    return fig
0081 
def draw_event(g):
    """Draw all nodes of an event graph in the (eta, phi) plane.

    Elements are drawn as red squares, PFCandidates as green crosses and
    generator-level particles (``sc``/``tp`` nodes) as blue dots. Edges that
    start from a generator-level particle are drawn faintly.

    Args:
        g: graph whose nodes are ``(kind, index)`` tuples with ``eta`` and
            ``phi`` attributes.

    Returns:
        The matplotlib figure that was drawn on.
    """
    pos = {n: (g.nodes[n]["eta"], g.nodes[n]["phi"]) for n in g.nodes}

    fig = plt.figure(figsize=(10, 10))

    elem_nodes = [n for n in g.nodes if n[0] == "elem"]
    nx.draw_networkx(g, pos=pos, with_labels=False, node_size=5, nodelist=elem_nodes, edgelist=[], node_color="red", node_shape="s", alpha=0.5)

    cand_nodes = [n for n in g.nodes if n[0] == "pfcand"]
    nx.draw_networkx(g, pos=pos, with_labels=False, node_size=10, nodelist=cand_nodes, edgelist=[], node_color="green", node_shape="x", alpha=0.5)

    gen_nodes = [n for n in g.nodes if (n[0] == "sc" or n[0] == "tp")]
    nx.draw_networkx(g, pos=pos, with_labels=False, node_size=1, nodelist=gen_nodes, edgelist=[], node_color="blue", node_shape=".", alpha=0.5)

    #draw edges between genparticles and elements
    # a set gives O(1) membership tests; a list here is quadratic in event size
    gen_node_set = set(gen_nodes)
    edges_to_draw = [e for e in g.edges if e[0] in gen_node_set]
    nx.draw_networkx_edges(g, pos, edgelist=edges_to_draw, arrows=False, alpha=0.1)

    plt.xlim(-6, 6)
    plt.ylim(-4, 4)
    plt.tight_layout()
    plt.axis("on")
    return fig
0107 
def cleanup_graph(g, edge_energy_threshold=0.01, edge_fraction_threshold=0.05, genparticle_energy_threshold=0.2, genparticle_pt_threshold=0.01):
    """Return a pruned copy of the event graph.

    The input graph is not modified. The steps are, in order:

    1. drop SimCluster->element edges whose weight is below
       ``edge_energy_threshold``; set the weight of every sc/tp edge that
       points to a type-10 element to exactly 1.0;
    2. drop sc/tp nodes whose energy is below
       ``genparticle_energy_threshold``;
    3. for every element, drop incoming edges whose weight is less than
       ``edge_fraction_threshold`` times the largest incoming weight
       (edges with weight exactly 1.0 are left out of this comparison);
    4. drop sc/tp nodes that are left without any edge;
    5. store the number of successors/predecessors on each node as the
       ``children``/``parents`` attributes (used for visualization).

    Args:
        g: directed graph with ``(kind, index)`` tuple nodes.
        edge_energy_threshold: minimum weight of a SimCluster edge.
        edge_fraction_threshold: minimum weight relative to the strongest
            incoming edge of the same element.
        genparticle_energy_threshold: minimum energy of an sc/tp node.
        genparticle_pt_threshold: accepted for interface compatibility;
            not used by this function.

    Returns:
        The cleaned graph (a new object).
    """
    g = g.copy()

    #remove edges that contribute little
    weak_edges = []
    for edge in g.edges:
        source_kind = edge[0][0]
        # the threshold is applied to the weight as stored, before the
        # type-10 override below
        if source_kind == "sc" and g.edges[edge]["weight"] < edge_energy_threshold:
            weak_edges.append(edge)
        if source_kind == "sc" or source_kind == "tp":
            if g.nodes[edge[1]]["typ"] == 10:
                g.edges[edge]["weight"] = 1.0

    #remove genparticles below energy threshold
    soft_genparticles = [
        n for n in g.nodes
        if (n[0] == "sc" or n[0] == "tp") and g.nodes[n]["e"] < genparticle_energy_threshold
    ]

    g.remove_edges_from(weak_edges)
    g.remove_nodes_from(soft_genparticles)

    # in the reversed graph the neighbors of an element are the particles
    # that point to it
    rg = g.reverse()

    #for each element, remove the incoming edges that contribute less than 5% of the total
    minor_edges = []
    for node in rg.nodes:
        if node[0] != "elem":
            continue

        #remove links that don't contribute above a threshold
        incoming = [((node, src), rg.edges[node, src]["weight"]) for src in rg.neighbors(node)]
        # edges pinned to weight 1.0 in the first pass are excluded here
        incoming = sorted((x for x in incoming if x[1] != 1.0), key=lambda x: x[1], reverse=True)
        if len(incoming) > 1:
            max_in = incoming[0][1]
            for e, w in incoming[1:]:
                if w / max_in < edge_fraction_threshold:
                    minor_edges.append(e)

    rg.remove_edges_from(minor_edges)
    g = rg.reverse()

    #remove genparticles not linked to any elements
    isolated = [n for n in g.nodes if (n[0] == "sc" or n[0] == "tp") and g.degree[n] == 0]
    g.remove_nodes_from(isolated)

    #compute number of children and parents, save on node for visualization
    # out_degree/in_degree count successors/predecessors directly, so no
    # further reversed copy of the graph is needed
    for node in g.nodes:
        g.nodes[node]["children"] = g.out_degree(node)
        g.nodes[node]["parents"] = g.in_degree(node)

    return g
0191 
def prepare_normalized_table(g, genparticle_energy_threshold=0.2):
    """Build per-element input and target tables from a cleaned event graph.

    Every element gets one row. Generator-level particles and PFCandidates
    are attached to elements so that each element carries at most one
    candidate target and one (summed) generator-level target.

    Args:
        g: directed graph with ``(kind, index)`` tuple nodes, with edges
            genparticle -> element and element -> pfcand.
        genparticle_energy_threshold: generator-level particles at or below
            this energy are ignored when building the generator target.

    Returns:
        Tuple ``(Xelem, ycand, ygen, dm_elem_cand, dm_elem_gen)``:
        ``Xelem`` is a float32 record array of element features; ``ycand``
        and ``ygen`` are float32 record arrays of targets with one row per
        element (all-zero rows where nothing is attached); the two ``dm_*``
        values are sparse COO adjacency matrices between elements that share
        a PFCandidate or a generator-level particle, respectively. All rows
        and columns follow the sorted element order.
    """
    rg = g.reverse()

    all_genparticles = []
    all_elements = []
    all_pfcandidates = []
    for node in rg.nodes:
        if node[0] == "elem":
            all_elements.append(node)
            # in the reversed graph, the neighbors of an element are the
            # generator-level particles pointing to it
            all_genparticles.extend(rg.neighbors(node))
        elif node[0] == "pfcand":
            all_pfcandidates.append(node)
    # sorted() rather than list(): keeps the tie-breaking of the pt ordering
    # below independent of hash ordering between runs
    all_genparticles = sorted(set(all_genparticles))
    all_elements = sorted(all_elements)

    def node_pt(n):
        return g.nodes[n]["pt"]

    def ranked_elements(gp):
        """Elements linked to ``gp`` as (weight, element), strongest first."""
        #we don't want to assign any genparticles to PS, BREM or SC - links are not reliable
        elems = [e for e in g.neighbors(gp) if g.nodes[e]["typ"] not in [2, 3, 7, 10]]
        #sort elements by energy from genparticle
        return sorted([(g.edges[gp, e]["weight"], e) for e in elems], key=lambda x: x[0], reverse=True)

    #assign genparticles in reverse pt order uniquely to best element
    elem_to_gp = {}
    unmatched_gp = []
    for gp in sorted(all_genparticles, key=node_pt, reverse=True):
        ranked = ranked_elements(gp)
        if len(ranked) == 0:
            continue

        chosen_elem = None
        for _, elem in ranked:
            if elem not in elem_to_gp:
                chosen_elem = elem
                break
        if chosen_elem is None:
            unmatched_gp.append(gp)
        else:
            elem_to_gp[chosen_elem] = [gp]

    #assign unmatched genparticles to best element, allowing for overlaps
    for gp in sorted(unmatched_gp, key=node_pt, reverse=True):
        # non-empty by construction: particles without usable elements were
        # skipped above, and every ranked element already has an entry
        _, best_elem = ranked_elements(gp)[0]
        elem_to_gp[best_elem].append(gp)

    elem_to_cand = {}
    for cand in sorted(all_pfcandidates, key=node_pt, reverse=True):
        tp = g.nodes[cand]["typ"]
        neighbors = list(rg.neighbors(cand))

        #Pions and muons will be assigned to tracks
        if abs(tp) == 211 or abs(tp) == 13:
            options = [n for n in neighbors if g.nodes[n]["typ"] == 1]
        #other particles will be assigned to the highest-energy cluster (ECAL, HCAL, HFEM, HFHAD, SC)
        else:
            options = [n for n in neighbors if g.nodes[n]["typ"] in [4, 5, 8, 9, 10]]
            options = sorted(options, key=lambda x: g.nodes[x]["e"], reverse=True)

        # take the first option that is not already claimed by a candidate
        chosen_elem = next((elem for elem in options if elem not in elem_to_cand), None)
        if chosen_elem is None:
            print("unmatched candidate {}, {}".format(cand, g.nodes[cand]))
        else:
            elem_to_cand[chosen_elem] = cand

    elem_branches = [
        "typ", "pt", "eta", "phi", "e",
        "layer", "depth", "charge", "trajpoint",
        "eta_ecal", "phi_ecal", "eta_hcal", "phi_hcal", "muon_dt_hits", "muon_csc_hits"
    ]
    target_branches = ["typ", "pt", "eta", "phi", "e", "px", "py", "pz", "charge"]

    Xelem = np.recarray((len(all_elements),), dtype=[(name, np.float32) for name in elem_branches])
    Xelem.fill(0.0)
    ygen = np.recarray((len(all_elements),), dtype=[(name, np.float32) for name in target_branches])
    ygen.fill(0.0)
    ycand = np.recarray((len(all_elements),), dtype=[(name, np.float32) for name in target_branches])
    ycand.fill(0.0)

    #find which elements should be linked together in the output when regressing to PFCandidates or GenParticles
    graph_elem_cand = nx.Graph()
    graph_elem_gen = nx.Graph()
    for elem in all_elements:
        graph_elem_cand.add_node(elem)
        graph_elem_gen.add_node(elem)

    for cand in all_pfcandidates:
        cand_elems = list(rg.neighbors(cand))
        for elem1 in cand_elems:
            for elem2 in cand_elems:
                if elem1 != elem2:
                    graph_elem_cand.add_edge(elem1, elem2)

    for gp in all_genparticles:
        gp_elems = list(g.neighbors(gp))
        for elem1 in gp_elems:
            for elem2 in gp_elems:
                if elem1 != elem2:
                    graph_elem_gen.add_edge(elem1, elem2)

    for ielem, elem in enumerate(all_elements):
        elem_type = g.nodes[elem]["typ"]
        elem_eta = g.nodes[elem]["eta"]
        genparticles = sorted(elem_to_gp.get(elem, []), key=lambda x: g.nodes[x]["e"], reverse=True)
        genparticles = [gp for gp in genparticles if g.nodes[gp]["e"] > genparticle_energy_threshold]
        candidate = elem_to_cand.get(elem, None)

        lv = uproot_methods.TLorentzVector(0, 0, 0, 0)

        # the class label comes from the highest-energy attached particle
        pid = 0
        if len(genparticles) > 0:
            pid = map_pdgid_to_candid.get(g.nodes[genparticles[0]]["typ"], 0)

        # the kinematic target is the four-vector sum of all attached particles
        for gp in genparticles:
            try:
                lv += uproot_methods.TLorentzVector.from_ptetaphie(
                    g.nodes[gp]["pt"],
                    g.nodes[gp]["eta"],
                    g.nodes[gp]["phi"],
                    g.nodes[gp]["e"]
                )
            except OverflowError:
                # retry with eta replaced by NaN when the conversion overflows
                lv += uproot_methods.TLorentzVector.from_ptetaphie(
                    g.nodes[gp]["pt"],
                    np.nan,
                    g.nodes[gp]["phi"],
                    g.nodes[gp]["e"]
                )

        if len(genparticles) > 0:
            if abs(elem_eta) > 3.0:
                #HFHAD -> always produce hadronic candidate
                if elem_type == 9:
                    pid = 1
                #HFEM -> decide based on pid
                elif elem_type == 8:
                    if abs(pid) in [11, 22]:
                        pid = 2 #produce EM candidate
                    else:
                        pid = 1 #produce hadronic

            #remap PID in case of HCAL cluster
            if elem_type == 5 and (pid == 22 or abs(pid) == 11):
                pid = 130

        #reproduce ROOT.TLorentzVector behavior (https://root.cern.ch/doc/master/TVector3_8cxx_source.html#l00320)
        try:
            eta = lv.eta
        except ZeroDivisionError:
            eta = np.sign(lv.z)*10e10

        gen_target = {
            "pt": lv.pt, "eta": eta, "phi": lv.phi, "e": lv.energy, "typ": pid, "px": lv.x, "py": lv.y, "pz": lv.z, "charge": get_charge(pid)
        }

        for name in elem_branches:
            Xelem[name][ielem] = g.nodes[elem][name]

        for name in target_branches:
            if candidate is not None:
                ycand[name][ielem] = g.nodes[candidate][name]
            ygen[name][ielem] = gen_target[name]

    # to_numpy_array replaces to_numpy_matrix, which was removed in NetworkX 3.0;
    # the resulting sparse matrices are the same
    dm_elem_cand = scipy.sparse.coo_matrix(nx.to_numpy_array(graph_elem_cand, nodelist=all_elements))
    dm_elem_gen = scipy.sparse.coo_matrix(nx.to_numpy_array(graph_elem_gen, nodelist=all_elements))
    return Xelem, ycand, ygen, dm_elem_cand, dm_elem_gen
#end of prepare_normalized_table
0373 #end of prepare_normalized_table
0374 
0375 
# (node attribute, input branch) pairs stored on every element node.
_ELEMENT_BRANCHES = [
    ("typ", b'element_type'),
    ("pt", b'element_pt'),
    ("e", b'element_energy'),
    ("eta", b'element_eta'),
    ("phi", b'element_phi'),
    ("eta_ecal", b'element_eta_ecal'),
    ("phi_ecal", b'element_phi_ecal'),
    ("eta_hcal", b'element_eta_hcal'),
    ("phi_hcal", b'element_phi_hcal'),
    ("trajpoint", b'element_trajpoint'),
    ("layer", b'element_layer'),
    ("charge", b'element_charge'),
    ("depth", b'element_depth'),
    ("deltap", b'element_deltap'),
    ("sigmadeltap", b'element_sigmadeltap'),
    ("px", b'element_px'),
    ("py", b'element_py'),
    ("pz", b'element_pz'),
    ("muon_dt_hits", b'element_muon_dt_hits'),
    ("muon_csc_hits", b'element_muon_csc_hits'),
]


def _particle_branches(prefix, id_suffix):
    """Return (node attribute, branch name) pairs for a particle collection."""
    suffixes = [
        ("typ", id_suffix), ("pt", "pt"), ("e", "energy"), ("eta", "eta"),
        ("phi", "phi"), ("px", "px"), ("py", "py"), ("pz", "pz"),
    ]
    return [(attr, "{}_{}".format(prefix, suffix).encode("ascii")) for attr, suffix in suffixes]


def _add_nodes(g, kind, ev, branches):
    """Add one ``(kind, index)`` node per object, with attributes read from ``branches``."""
    columns = [(attr, ev[branch]) for attr, branch in branches]
    # the first listed branch defines how many objects there are
    for iobj in range(len(columns[0][1])):
        g.add_node((kind, iobj), **{attr: values[iobj] for attr, values in columns})


def _build_event_graph(ev):
    """Create the directed particle/element/candidate graph for one event."""
    g = nx.DiGraph()
    _add_nodes(g, "elem", ev, _ELEMENT_BRANCHES)
    _add_nodes(g, "tp", ev, _particle_branches("trackingparticle", "pid"))
    _add_nodes(g, "sc", ev, _particle_branches("simcluster", "pid"))

    #for trackingparticles associated to elements, set a very high edge weight
    for tp, elem in zip(ev[b'trackingparticle_to_element.first'], ev[b'trackingparticle_to_element.second']):
        g.add_edge(("tp", tp), ("elem", elem), weight=99999.0)

    for sc, elem, c in zip(
            ev[b'simcluster_to_element.first'],
            ev[b'simcluster_to_element.second'],
            ev[b'simcluster_to_element_cmp']):
        g.add_edge(("sc", sc), ("elem", elem), weight=c)

    print("contracting nodes: trackingparticle to simcluster")
    # A SimCluster that refers to a TrackingParticle (index != -1) is merged
    # into it: its element edges are copied onto the TrackingParticle
    # (overwriting the weight of an edge that already exists) and the
    # SimCluster node is then removed.
    nodes_to_remove = []
    for idx_sc, idx_tp in enumerate(ev[b'simcluster_idx_trackingparticle']):
        if idx_tp != -1:
            for elem in g.neighbors(("sc", idx_sc)):
                g.add_edge(("tp", idx_tp), elem, weight=g.edges[("sc", idx_sc), elem]["weight"])
            g.nodes[("tp", idx_tp)]["idx_sc"] = idx_sc
            nodes_to_remove.append(("sc", idx_sc))
    g.remove_nodes_from(nodes_to_remove)

    _add_nodes(g, "pfcand", ev, _particle_branches("pfcandidate", "pdgid"))
    pfcandidate_pdgid = ev[b'pfcandidate_pdgid']
    for iobj in range(len(pfcandidate_pdgid)):
        g.nodes[("pfcand", iobj)]["charge"] = get_charge(pfcandidate_pdgid[iobj])

    for elem, pfcand in zip(ev[b'element_to_candidate.first'], ev[b'element_to_candidate.second']):
        g.add_edge(("elem", elem), ("pfcand", pfcand), weight=1.0)
    print("Graph created: {} nodes, {} edges".format(len(g.nodes), len(g.edges)))
    return g


def _plot_leading_candidates(g, outpath, iev, max_candidates):
    """Save tree plots for the ``max_candidates`` highest-pt PFCandidates."""
    if max_candidates <= 0:
        return
    rg = g.reverse()
    pfcands = sorted((n for n in g.nodes if n[0] == "pfcand"), key=lambda x: g.nodes[x]["pt"], reverse=True)
    for ncand, node in enumerate(pfcands[:max_candidates]):
        print(node, g.nodes[node]["pt"])
        fig = save_ego_graph(rg, node, 3, False)
        fig.savefig(outpath + "_ev_{}_cand_{}_idx_{}.pdf".format(iev, ncand, node[1]), bbox_inches="tight")
        # close, not just clear: pyplot keeps a reference to every open figure
        plt.close(fig)


def process(args):
    """Convert events from the input ROOT file into pickled graphs/tables.

    Args:
        args: parsed command line, as returned by ``parse_args``. The
            attributes read are ``input``, ``outpath``, ``event``,
            ``plot_candidates``, ``events_per_file``,
            ``save_normalized_table`` and ``save_full_graph``.

    Output files are named after the input file inside ``args.outpath``.
    With ``events_per_file == -1`` all events go into one ``<name>.pkl``;
    with a positive value, events are written in chunks ``<name>_<i>.pkl``,
    including a final shorter chunk if the event count is not a multiple of
    the chunk size.
    """
    infile = args.input
    outpath = os.path.join(args.outpath, os.path.basename(infile).split(".")[0])
    tf = uproot.open(infile)
    tt = tf["ana/pftree"]

    if args.event is None:
        events_to_process = list(range(tt.numentries))
    else:
        events_to_process = [args.event]

    all_data = []
    ifile = 0
    for iev in events_to_process:
        print("processing event {}".format(iev))

        ev = tt.arrays(flatten=True, entrystart=iev, entrystop=iev+1)

        g = cleanup_graph(_build_event_graph(ev))

        #make tree visualizations for PFCandidates
        _plot_leading_candidates(g, outpath, iev, args.plot_candidates)

        #do one-to-one associations
        Xelem, ycand, ygen, dm_elem_cand, dm_elem_gen = prepare_normalized_table(g)
        data = {}

        if args.save_normalized_table:
            data = {
                "Xelem": Xelem,
                "ycand": ycand,
                "ygen": ygen,
                "dm_elem_cand": dm_elem_cand,
                "dm_elem_gen": dm_elem_gen
            }

        if args.save_full_graph:
            data["full_graph"] = g

        all_data.append(data)

        if args.events_per_file > 0 and len(all_data) == args.events_per_file:
            chunk_path = outpath + "_{}.pkl".format(ifile)
            print(chunk_path)
            with open(chunk_path, "wb") as fi:
                pickle.dump(all_data, fi)
            ifile += 1
            all_data = []

    if args.events_per_file == -1:
        print(outpath)
        with open(outpath + ".pkl", "wb") as fi:
            pickle.dump(all_data, fi)
    elif args.events_per_file > 0 and all_data:
        # write the trailing partial chunk so the last events are not lost
        chunk_path = outpath + "_{}.pkl".format(ifile)
        print(chunk_path)
        with open(chunk_path, "wb") as fi:
            pickle.dump(all_data, fi)
0581 
def parse_args():
    """Parse the command line of this script and return the namespace.

    Only ``--input`` is required; every other option has a default.
    """
    import argparse

    # (flag, keyword arguments) for every supported option, in help order
    option_table = [
        ("--input", dict(type=str, help="Input file from PFAnalysis", required=True)),
        ("--event", dict(type=int, default=None, help="event index to process, omit to process all")),
        ("--outpath", dict(type=str, default="raw", help="output path")),
        ("--plot-candidates", dict(type=int, default=0, help="number of PFCandidates to plot as trees in pt-descending order")),
        ("--events-per-file", dict(type=int, default=-1, help="number of events per output file, -1 for all")),
        ("--save-full-graph", dict(action="store_true", help="save the full event graph")),
        ("--save-normalized-table", dict(action="store_true", help="save the uniquely identified table")),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
0594 
if __name__ == "__main__":
    # Script entry point: read the command line, then convert the requested events.
    process(parse_args())
0598