Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:27:33

0001 import sys
0002 import pickle
0003 import networkx as nx
0004 import numpy as np
0005 import os
0006 import uproot
0007 import uproot_methods
0008 import math
0009 
0010 import matplotlib
0011 matplotlib.use("Agg")
0012 import matplotlib.pyplot as plt
0013 
0014 import scipy
0015 import scipy.sparse
0016 from networkx.readwrite import json_graph
0017 from networkx.drawing.nx_pydot import graphviz_layout
0018 
#PF candidate id -> PDG ids that are represented by that candidate type
map_candid_to_pdgid = {
    0: [0],
    211: [211, 2212, 321, -3112, 3222, -3312, -3334],
    -211: [-211, -2212, -321, 3112, -3222, 3312, 3334],
    130: [111, 130, 2112, -2112, 310, 3122, -3122, 3322, -3322],
    22: [22],
    11: [11],
    -11: [-11],
    13: [13],
    -13: [-13],
}

#inverse lookup: PDG id -> PF candidate id
map_pdgid_to_candid = {
    pdgid: candid
    for candid, pdgids in map_candid_to_pdgid.items()
    for pdgid in pdgids
}
0036 
def get_charge(pid):
    """Return the electric charge (+1.0, -1.0 or 0.0) for a signed PF candidate id.

    pid: signed candidate / PDG-style id.
      - 11 (e-) and 13 (mu-): a positive id has charge -1.
      - 211 (pi+): a positive id has charge +1.
      - Everything else has charge 0.0. This covers the neutral types 130, 22,
        1 and 2 and the "no particle" id 0.

    Returning 0.0 instead of falling through to None matters because the result
    is stored in float32 target columns.
    """
    abs_pid = abs(pid)
    #13: mu-, 11: e- (positive id means negative charge)
    if abs_pid == 13 or abs_pid == 11:
        return -math.copysign(1.0, pid)
    #211: pi+ (positive id means positive charge)
    if abs_pid == 211:
        return math.copysign(1.0, pid)
    #neutral types (130, 22, 1, 2), the empty target 0 and any other id
    return 0.0
0047 
def save_ego_graph(g, node, radius=4, undirected=False):
    """Draw the neighbourhood of ``node`` in ``g`` as a tree and return the figure.

    g: event graph whose nodes are ``(kind, index)`` tuples, with kind one of
        "sc", "elem", "tp" or "pfcand". Every drawn node must carry the
        attributes typ, e, pt, eta, phi, children and parents (the last two
        are set by cleanup_graph).
    node: centre node of the ego graph.
    radius: maximum hop distance from ``node`` to include.
    undirected: passed through to ``nx.ego_graph``.

    The ego graph is reversed before drawing and laid out with graphviz "dot".
    Returns the newly created matplotlib figure, which is also the current
    figure; the caller is responsible for saving and closing it.
    """
    kind_labels = {"sc": "SimCluster", "elem": "PFElement", "tp": "TrackingParticle", "pfcand": "PFCandidate"}

    sg = nx.ego_graph(g, node, radius, undirected=undirected).reverse()

    #remove BREM PFElements (typ 7) from plotting
    brem_nodes = [n for n in sg.nodes if (n[0]=="elem" and sg.nodes[n]["typ"] in [7,])]
    sg.remove_nodes_from(brem_nodes)

    fig = plt.figure(figsize=(2*len(sg.nodes)+2, 10))
    sg_pos = graphviz_layout(sg, prog='dot')

    #label edges that end on an element with their weight, except for
    #elements of typ 1 and 10, whose edges stay unlabelled
    edge_labels = {}
    for e in sg.edges:
        dst = e[1]
        if dst[0] == "elem" and sg.nodes[dst]["typ"] not in [1,10]:
            edge_labels[e] = "{:.2f} GeV".format(sg.edges[e].get("weight", 0))
        else:
            edge_labels[e] = ""

    node_labels = {}
    for n in sg.nodes:
        node_labels[n] = "[{label} {idx}] \ntype: {typ}\ne: {e:.4f} GeV\npt: {pt:.4f} GeV\neta: {eta:.4f}\nphi: {phi:.4f}\nc/p: {children}/{parents}".format(
            label=kind_labels[n[0]], idx=n[1], **sg.nodes[n])

    nx.draw_networkx(sg, pos=sg_pos, node_shape=".", node_color="grey", edge_color="grey", node_size=0, alpha=0.5, labels={})
    nx.draw_networkx_labels(sg, pos=sg_pos, labels=node_labels)
    nx.draw_networkx_edge_labels(sg, pos=sg_pos, edge_labels=edge_labels)
    plt.tight_layout()
    plt.axis("off")

    return fig
0079 
def draw_event(g):
    """Plot every node of the event graph ``g`` at its (eta, phi) position and return the figure.

    Markers:
      - PFElements: red squares.
      - PFCandidates: green crosses.
      - Genparticles (SimClusters and TrackingParticles): blue dots.

    Edges whose source is a genparticle are drawn faintly. Every node must
    carry "eta" and "phi" attributes.
    """
    pos = {node: (g.nodes[node]["eta"], g.nodes[node]["phi"]) for node in g.nodes}

    fig = plt.figure(figsize=(10,10))

    elements = [n for n in g.nodes if n[0]=="elem"]
    nx.draw_networkx(g, pos=pos, with_labels=False, node_size=5, nodelist=elements, edgelist=[], node_color="red", node_shape="s", alpha=0.5)

    pfcandidates = [n for n in g.nodes if n[0]=="pfcand"]
    nx.draw_networkx(g, pos=pos, with_labels=False, node_size=10, nodelist=pfcandidates, edgelist=[], node_color="green", node_shape="x", alpha=0.5)

    genparticles = [n for n in g.nodes if (n[0]=="sc" or n[0]=="tp")]
    nx.draw_networkx(g, pos=pos, with_labels=False, node_size=1, nodelist=genparticles, edgelist=[], node_color="blue", node_shape=".", alpha=0.5)

    #draw edges between genparticles and elements; a set keeps the
    #membership test O(1) per edge instead of scanning the node list
    genparticle_set = set(genparticles)
    edges_to_draw = [e for e in g.edges if e[0] in genparticle_set]
    nx.draw_networkx_edges(g, pos, edgelist=edges_to_draw, arrows=False, alpha=0.1)

    plt.xlim(-6,6)
    plt.ylim(-4,4)
    plt.tight_layout()
    plt.axis("on")
    return fig
0105 
def cleanup_graph(g, edge_energy_threshold=0.01, edge_fraction_threshold=0.05, genparticle_energy_threshold=0.2, genparticle_pt_threshold=0.01):
    """Return a pruned copy of the event graph ``g``; ``g`` itself is not modified.

    Steps, in order:
      1. Drop SimCluster->element edges with weight below ``edge_energy_threshold``.
         Set the weight of every SimCluster/TrackingParticle edge into a typ-10
         element to 1.0.
      2. Drop SimCluster/TrackingParticle nodes with energy ("e") below
         ``genparticle_energy_threshold``.
      3. For every element, drop incoming edges whose weight is below
         ``edge_fraction_threshold`` times the largest incoming weight. Edges
         with weight exactly 1.0 take no part in this comparison.
      4. Drop SimCluster/TrackingParticle nodes left with no edges.
      5. Store the number of successors ("children") and predecessors
         ("parents") on every node, for visualization.

    genparticle_pt_threshold: not used; kept for interface compatibility.
    Assumes ``g`` is a networkx DiGraph (not a multigraph) whose nodes are
    ``(kind, index)`` tuples.
    """
    g = g.copy()

    #1. prune weak SimCluster links; links into typ-10 elements get a fixed weight
    edges_to_remove = []
    for edge in g.edges:
        src_kind = edge[0][0]
        if src_kind == "sc" and g.edges[edge]["weight"] < edge_energy_threshold:
            edges_to_remove.append(edge)
        if src_kind in ("sc", "tp") and g.nodes[edge[1]]["typ"] == 10:
            g.edges[edge]["weight"] = 1.0

    #2. prune low-energy genparticles
    nodes_to_remove = [
        n for n in g.nodes
        if n[0] in ("sc", "tp") and g.nodes[n]["e"] < genparticle_energy_threshold
    ]
    g.remove_edges_from(edges_to_remove)
    g.remove_nodes_from(nodes_to_remove)

    #3. for each element, remove the incoming edges that contribute less than
    #edge_fraction_threshold of the strongest one. This is done on the reversed
    #graph and reversed back, as in the original implementation: the round trip
    #can change adjacency (iteration) order, which the tie-breaking in
    #prepare_normalized_table may depend on.
    rg = g.reverse()
    edges_to_remove = []
    for node in rg.nodes:
        if node[0] != "elem":
            continue
        ew = [((node, src), rg.edges[node, src]["weight"]) for src in rg.neighbors(node)]
        #weight 1.0 is what step 1 assigns to typ-10 links; leave those out of the comparison
        ew = sorted((x for x in ew if x[1] != 1.0), key=lambda x: x[1], reverse=True)
        if len(ew) > 1:
            max_in = ew[0][1]
            edges_to_remove += [e for e, w in ew[1:] if w / max_in < edge_fraction_threshold]
    rg.remove_edges_from(edges_to_remove)
    g = rg.reverse()

    #4. remove genparticles not linked to any elements
    g.remove_nodes_from([n for n in g.nodes if n[0] in ("sc", "tp") and g.degree[n] == 0])

    #5. number of children (successors) and parents (predecessors), saved on the node for visualization
    for node in g.nodes:
        g.nodes[node]["children"] = g.out_degree[node]
        g.nodes[node]["parents"] = g.in_degree[node]

    return g
0189 
def _gp_elements_by_weight(g, gp):
    """Return the elements linked from genparticle ``gp``, strongest edge weight first.

    Elements of typ 2, 3, 7 and 10 are excluded (PS, BREM and SC elements): their
    genparticle links are not used for the assignment. The sort is stable, so
    equal weights keep the ``g.neighbors`` order.
    """
    elems = [e for e in g.neighbors(gp) if g.nodes[e]["typ"] not in (2, 3, 7, 10)]
    return sorted(elems, key=lambda e: g.edges[gp, e]["weight"], reverse=True)


def _elem_link_matrix(all_elements, element_groups):
    """Return a sparse COO adjacency matrix over ``all_elements`` (rows/columns in that order).

    Two distinct elements are linked (value 1.0) when they appear together in at
    least one group of ``element_groups``.
    """
    graph = nx.Graph()
    graph.add_nodes_from(all_elements)
    for group in element_groups:
        for elem1 in group:
            for elem2 in group:
                if elem1 != elem2:
                    graph.add_edge(elem1, elem2)
    #to_numpy_array replaces to_numpy_matrix, which networkx 3.0 removed
    return scipy.sparse.coo_matrix(nx.to_numpy_array(graph, nodelist=all_elements))


def prepare_normalized_table(g, genparticle_energy_threshold=0.2):
    """Flatten a cleaned event graph into per-element input and target tables.

    g: event graph after cleanup_graph, with edges genparticle -> element -> pfcand.
    genparticle_energy_threshold: genparticles assigned to an element with energy
        at or below this value are ignored when building the gen target.

    Assignment rules:
      - Genparticles, in descending pt order, each claim the strongest-linked
        element not yet claimed. Those left over join their strongest-linked
        element anyway.
      - PFCandidates, in descending pt order, claim one free element:
        - pions and muons take the first linked track (typ 1);
        - all others take the highest-energy linked cluster (typ 4, 5, 8, 9, 10).

    Returns (Xelem, ycand, ygen, dm_elem_cand, dm_elem_gen):
      Xelem: float32 recarray with one row per element, in sorted node order.
      ycand: float32 recarray of PFCandidate targets per element; zero rows
        where no candidate was assigned.
      ygen: float32 recarray of genparticle targets per element; zero
        kinematics and typ 0 where no genparticle was assigned.
      dm_elem_cand: sparse COO adjacency between elements that share a
        PFCandidate.
      dm_elem_gen: sparse COO adjacency between elements that share a
        genparticle.
    """
    rg = g.reverse()

    all_genparticles = []
    all_elements = []
    all_pfcandidates = []
    for node in rg.nodes:
        if node[0] == "elem":
            all_elements.append(node)
            all_genparticles.extend(rg.neighbors(node))
        elif node[0] == "pfcand":
            all_pfcandidates.append(node)
    all_genparticles = list(set(all_genparticles))
    all_elements = sorted(all_elements)

    #assign genparticles in reverse pt order uniquely to the best element
    elem_to_gp = {}
    unmatched_gp = []
    for gp in sorted(all_genparticles, key=lambda x: g.nodes[x]["pt"], reverse=True):
        elems_sorted = _gp_elements_by_weight(g, gp)
        if not elems_sorted:
            continue
        chosen_elem = next((e for e in elems_sorted if e not in elem_to_gp), None)
        if chosen_elem is None:
            unmatched_gp.append(gp)
        else:
            elem_to_gp[chosen_elem] = [gp]

    #assign unmatched genparticles to their best element, allowing overlaps; every
    #candidate element of these genparticles is already a key of elem_to_gp
    for gp in sorted(unmatched_gp, key=lambda x: g.nodes[x]["pt"], reverse=True):
        elem_to_gp[_gp_elements_by_weight(g, gp)[0]].append(gp)

    elem_to_cand = {}
    for cand in sorted(all_pfcandidates, key=lambda x: g.nodes[x]["pt"], reverse=True):
        tp = g.nodes[cand]["typ"]
        neighbors = list(rg.neighbors(cand))

        if abs(tp) == 211 or abs(tp) == 13:
            #pions and muons are assigned to tracks (typ 1), in neighbor order
            options = [n for n in neighbors if g.nodes[n]["typ"] == 1]
        else:
            #other particles go to the highest-energy cluster (ECAL, HCAL, HFEM, HFHAD, SC)
            options = sorted(
                (n for n in neighbors if g.nodes[n]["typ"] in [4,5,8,9,10]),
                key=lambda x: g.nodes[x]["e"], reverse=True)

        chosen_elem = next((e for e in options if e not in elem_to_cand), None)
        if chosen_elem is None:
            print("unmatched candidate {}, {}".format(cand, g.nodes[cand]))
        else:
            elem_to_cand[chosen_elem] = cand

    elem_branches = [
        "typ", "pt", "eta", "phi", "e",
        "layer", "depth", "charge", "trajpoint",
        "eta_ecal", "phi_ecal", "eta_hcal", "phi_hcal", "muon_dt_hits", "muon_csc_hits"
    ]
    target_branches = ["typ", "pt", "eta", "phi", "e", "px", "py", "pz", "charge"]

    Xelem = np.recarray((len(all_elements),), dtype=[(name, np.float32) for name in elem_branches])
    Xelem.fill(0.0)
    ygen = np.recarray((len(all_elements),), dtype=[(name, np.float32) for name in target_branches])
    ygen.fill(0.0)
    ycand = np.recarray((len(all_elements),), dtype=[(name, np.float32) for name in target_branches])
    ycand.fill(0.0)

    for ielem, elem in enumerate(all_elements):
        elem_type = g.nodes[elem]["typ"]
        elem_eta = g.nodes[elem]["eta"]
        genparticles = sorted(elem_to_gp.get(elem, []), key=lambda x: g.nodes[x]["e"], reverse=True)
        genparticles = [gp for gp in genparticles if g.nodes[gp]["e"] > genparticle_energy_threshold]
        candidate = elem_to_cand.get(elem, None)

        #sum the four-momenta of all genparticles kept for this element
        lv = uproot_methods.TLorentzVector(0, 0, 0, 0)
        for gp in genparticles:
            gp_attrs = g.nodes[gp]
            try:
                lv += uproot_methods.TLorentzVector.from_ptetaphie(
                    gp_attrs["pt"], gp_attrs["eta"], gp_attrs["phi"], gp_attrs["e"])
            except OverflowError:
                #conversion overflowed (presumably an extreme eta): retry with eta=NaN
                lv += uproot_methods.TLorentzVector.from_ptetaphie(
                    gp_attrs["pt"], np.nan, gp_attrs["phi"], gp_attrs["e"])

        #target type: candidate id of the leading (highest-energy) genparticle
        pid = 0
        if genparticles:
            pid = map_pdgid_to_candid.get(g.nodes[genparticles[0]]["typ"], 0)
            if abs(elem_eta) > 3.0:
                #HFHAD -> always produce hadronic candidate
                if elem_type == 9:
                    pid = 1
                #HFEM -> EM candidate (2) for e/gamma, hadronic (1) otherwise
                elif elem_type == 8:
                    pid = 2 if abs(pid) in [11, 22] else 1
            #remap PID in case of HCAL cluster
            if elem_type == 5 and (pid == 22 or abs(pid) == 11):
                pid = 130

        #reproduce ROOT.TLorentzVector behavior (https://root.cern.ch/doc/master/TVector3_8cxx_source.html#l00320)
        try:
            eta = lv.eta
        except ZeroDivisionError:
            eta = np.sign(lv.z)*10e10

        gen_target = {
            "pt": lv.pt, "eta": eta, "phi": lv.phi, "e": lv.energy, "typ": pid, "px": lv.x, "py": lv.y, "pz": lv.z, "charge": get_charge(pid)
        }

        for name in elem_branches:
            Xelem[name][ielem] = g.nodes[elem][name]

        for name in target_branches:
            if candidate is not None:
                ycand[name][ielem] = g.nodes[candidate][name]
            ygen[name][ielem] = gen_target[name]

    #elements that should be linked together when regressing to PFCandidates or GenParticles
    dm_elem_cand = _elem_link_matrix(all_elements, (list(rg.neighbors(cand)) for cand in all_pfcandidates))
    dm_elem_gen = _elem_link_matrix(all_elements, (list(g.neighbors(gp)) for gp in all_genparticles))
    return Xelem, ycand, ygen, dm_elem_cand, dm_elem_gen
0371 #end of prepare_normalized_table
0372 
0373 
def _add_nodes(g, ev, kind, branches):
    """Add nodes ``(kind, i)`` to ``g``, one per entry of the first listed branch.

    branches: list of (node attribute name, tree branch name) pairs. The
    attribute values are taken element-wise from ``ev``, in the listed order.
    """
    columns = [(attr, ev[branch]) for attr, branch in branches]
    for iobj in range(len(columns[0][1])):
        g.add_node((kind, iobj), **{attr: column[iobj] for attr, column in columns})


def _build_event_graph(ev):
    """Build the directed event graph for one event.

    ev: mapping from branch name (bytes) to that event's arrays, as read with
        ``tt.arrays(flatten=True, ...)``.

    Nodes are ("elem", i), ("tp", i), ("sc", i) and ("pfcand", i). Edges:
      - trackingparticle -> element, weight 99999.0;
      - simcluster -> element, weighted by simcluster_to_element_cmp;
      - element -> pfcand, weight 1.0.

    Every simcluster with a trackingparticle index other than -1 has its
    element edges copied onto that trackingparticle, and is then removed.
    """
    element_branches = [
        ("typ", b'element_type'), ("pt", b'element_pt'), ("e", b'element_energy'),
        ("eta", b'element_eta'), ("phi", b'element_phi'),
        ("eta_ecal", b'element_eta_ecal'), ("phi_ecal", b'element_phi_ecal'),
        ("eta_hcal", b'element_eta_hcal'), ("phi_hcal", b'element_phi_hcal'),
        ("trajpoint", b'element_trajpoint'), ("layer", b'element_layer'),
        ("charge", b'element_charge'), ("depth", b'element_depth'),
        ("deltap", b'element_deltap'), ("sigmadeltap", b'element_sigmadeltap'),
        ("px", b'element_px'), ("py", b'element_py'), ("pz", b'element_pz'),
        ("muon_dt_hits", b'element_muon_dt_hits'), ("muon_csc_hits", b'element_muon_csc_hits'),
    ]
    trackingparticle_branches = [
        ("typ", b'trackingparticle_pid'), ("pt", b'trackingparticle_pt'),
        ("e", b'trackingparticle_energy'), ("eta", b'trackingparticle_eta'),
        ("phi", b'trackingparticle_phi'), ("px", b'trackingparticle_px'),
        ("py", b'trackingparticle_py'), ("pz", b'trackingparticle_pz'),
    ]
    simcluster_branches = [
        ("typ", b'simcluster_pid'), ("pt", b'simcluster_pt'),
        ("e", b'simcluster_energy'), ("eta", b'simcluster_eta'),
        ("phi", b'simcluster_phi'), ("px", b'simcluster_px'),
        ("py", b'simcluster_py'), ("pz", b'simcluster_pz'),
    ]
    pfcandidate_branches = [
        ("typ", b'pfcandidate_pdgid'), ("pt", b'pfcandidate_pt'),
        ("e", b'pfcandidate_energy'), ("eta", b'pfcandidate_eta'),
        ("phi", b'pfcandidate_phi'), ("px", b'pfcandidate_px'),
        ("py", b'pfcandidate_py'), ("pz", b'pfcandidate_pz'),
    ]

    g = nx.DiGraph()
    _add_nodes(g, ev, "elem", element_branches)
    _add_nodes(g, ev, "tp", trackingparticle_branches)
    _add_nodes(g, ev, "sc", simcluster_branches)

    #for trackingparticles associated to elements, set a very high edge weight
    for tp, elem in zip(ev[b'trackingparticle_to_element.first'], ev[b'trackingparticle_to_element.second']):
        g.add_edge(("tp", tp), ("elem", elem), weight=99999.0)

    for sc, elem, c in zip(ev[b'simcluster_to_element.first'], ev[b'simcluster_to_element.second'], ev[b'simcluster_to_element_cmp']):
        g.add_edge(("sc", sc), ("elem", elem), weight=c)

    print("contracting nodes: trackingparticle to simcluster")
    nodes_to_remove = []
    for idx_sc, idx_tp in enumerate(ev[b'simcluster_idx_trackingparticle']):
        if idx_tp != -1:
            sc_node = ("sc", idx_sc)
            tp_node = ("tp", idx_tp)
            for elem in g.neighbors(sc_node):
                g.add_edge(tp_node, elem, weight=g.edges[sc_node, elem]["weight"])
            g.nodes[tp_node]["idx_sc"] = idx_sc
            nodes_to_remove.append(sc_node)
    g.remove_nodes_from(nodes_to_remove)

    _add_nodes(g, ev, "pfcand", pfcandidate_branches)
    for iobj, pdgid in enumerate(ev[b'pfcandidate_pdgid']):
        g.nodes[("pfcand", iobj)]["charge"] = get_charge(pdgid)

    for elem, pfcand in zip(ev[b'element_to_candidate.first'], ev[b'element_to_candidate.second']):
        g.add_edge(("elem", elem), ("pfcand", pfcand), weight=1.0)
    return g


def process(args):
    """Convert one PFAnalysis ROOT file into pickled per-event data.

    args: namespace from parse_args(). Uses input, event, outpath,
        plot_candidates, events_per_file, save_full_graph and
        save_normalized_table.

    Output paths are ``args.outpath`` joined with the input file's base name
    up to the first dot:
      - events_per_file > 0: ``<name>_<i>.pkl`` files holding events_per_file
        events each; the last file holds the remainder.
      - otherwise: a single ``<name>.pkl`` with all events.
    """
    infile = args.input
    outpath = os.path.join(args.outpath, os.path.basename(infile).split(".")[0])
    tf = uproot.open(infile)
    tt = tf["ana/pftree"]

    if args.event is None:
        events_to_process = list(range(tt.numentries))
    else:
        events_to_process = [args.event]

    all_data = []
    ifile = 0
    for iev in events_to_process:
        print("processing event {}".format(iev))

        ev = tt.arrays(flatten=True,entrystart=iev,entrystop=iev+1)
        g = _build_event_graph(ev)
        print("Graph created: {} nodes, {} edges".format(len(g.nodes), len(g.edges)))

        g = cleanup_graph(g)

        #make tree visualizations for the leading PFCandidates
        if args.plot_candidates > 0:
            rg = g.reverse()
            candidates = sorted((n for n in g.nodes if n[0]=="pfcand"), key=lambda x: g.nodes[x]["pt"], reverse=True)
            for ncand, node in enumerate(candidates):
                if ncand >= args.plot_candidates:
                    break
                print(node, g.nodes[node]["pt"])
                fig = save_ego_graph(rg, node, 3, False)
                plt.savefig(outpath + "_ev_{}_cand_{}_idx_{}.pdf".format(iev, ncand, node[1]), bbox_inches="tight")
                #close, not just clear: every save_ego_graph call opens a new figure
                plt.close(fig)

        #do one-to-one associations
        Xelem, ycand, ygen, dm_elem_cand, dm_elem_gen = prepare_normalized_table(g)
        data = {}

        if args.save_normalized_table:
            data = {
                "Xelem": Xelem,
                "ycand": ycand,
                "ygen": ygen,
                "dm_elem_cand": dm_elem_cand,
                "dm_elem_gen": dm_elem_gen
            }

        if args.save_full_graph:
            data["full_graph"] = g

        all_data.append(data)

        if args.events_per_file > 0 and len(all_data) == args.events_per_file:
            print(outpath + "_{}.pkl".format(ifile))
            with open(outpath + "_{}.pkl".format(ifile), "wb") as fi:
                pickle.dump(all_data, fi)
            ifile += 1
            all_data = []

    if args.events_per_file > 0:
        #write the last, partially filled chunk instead of silently dropping it
        if all_data:
            print(outpath + "_{}.pkl".format(ifile))
            with open(outpath + "_{}.pkl".format(ifile), "wb") as fi:
                pickle.dump(all_data, fi)
    else:
        print(outpath)
        with open(outpath + ".pkl", "wb") as fi:
            pickle.dump(all_data, fi)
0579 
def parse_args():
    """Read this script's command-line options from sys.argv and return the argparse namespace."""
    import argparse

    cli = argparse.ArgumentParser()
    add = cli.add_argument
    add("--input", type=str, required=True, help="Input file from PFAnalysis")
    add("--event", type=int, default=None, help="event index to process, omit to process all")
    add("--outpath", type=str, default="raw", help="output path")
    add("--plot-candidates", type=int, default=0, help="number of PFCandidates to plot as trees in pt-descending order")
    add("--events-per-file", type=int, default=-1, help="number of events per output file, -1 for all")
    add("--save-full-graph", action="store_true", help="save the full event graph")
    add("--save-normalized-table", action="store_true", help="save the uniquely identified table")
    return cli.parse_args()
0592 
#script entry point: parse the command line and convert the requested input file
if __name__ == "__main__":
    process(parse_args())
0596