File indexing completed on 2023-10-25 10:01:46
#include "RecoParticleFlow/PFProducer/interface/MLPFModel.h"

#include <algorithm>
#include <cmath>
#include <iterator>

#include "FWCore/MessageLogger/interface/MessageLogger.h"
#include "FWCore/Utilities/interface/isFinite.h"
#include "DataFormats/ParticleFlowReco/interface/PFCluster.h"
#include "DataFormats/ParticleFlowReco/interface/PFBlock.h"
#include "DataFormats/ParticleFlowReco/interface/PFBlockElementSuperCluster.h"
#include "DataFormats/ParticleFlowReco/interface/PFBlockElementGsfTrack.h"
#include "DataFormats/ParticleFlowReco/interface/PFBlockElementTrack.h"
#include "DataFormats/ParticleFlowReco/interface/PFBlockElementBrem.h"
#include "DataFormats/ParticleFlowReco/interface/PFBlockElementCluster.h"
#include "DataFormats/EgammaReco/interface/SuperCluster.h"
#include "DataFormats/EgammaReco/interface/ElectronSeed.h"
0013
0014 namespace reco::mlpf {
0015
0016
0017 std::array<float, NUM_ELEMENT_FEATURES> getElementProperties(const reco::PFBlockElement& orig) {
0018 const auto type = orig.type();
0019 float pt = 0.0;
0020 float deltap = 0.0;
0021 float sigmadeltap = 0.0;
0022 float px = 0.0;
0023 float py = 0.0;
0024 float pz = 0.0;
0025 float eta = 0.0;
0026 float phi = 0.0;
0027 float energy = 0.0;
0028 float corr_energy = 0.0;
0029 float trajpoint = 0.0;
0030 float eta_ecal = 0.0;
0031 float phi_ecal = 0.0;
0032 float eta_hcal = 0.0;
0033 float phi_hcal = 0.0;
0034 float charge = 0;
0035 float layer = 0;
0036 float depth = 0;
0037 float muon_dt_hits = 0.0;
0038 float muon_csc_hits = 0.0;
0039 float muon_type = 0.0;
0040 float cluster_flags = 0.0;
0041 float gsf_electronseed_trkorecal = 0.0;
0042 float num_hits = 0.0;
0043
0044 if (type == reco::PFBlockElement::TRACK) {
0045 const auto& matched_pftrack = orig.trackRefPF();
0046 if (matched_pftrack.isNonnull()) {
0047 const auto& atECAL = matched_pftrack->extrapolatedPoint(reco::PFTrajectoryPoint::ECALShowerMax);
0048 const auto& atHCAL = matched_pftrack->extrapolatedPoint(reco::PFTrajectoryPoint::HCALEntrance);
0049 if (atHCAL.isValid()) {
0050 eta_hcal = atHCAL.positionREP().eta();
0051 phi_hcal = atHCAL.positionREP().phi();
0052 }
0053 if (atECAL.isValid()) {
0054 eta_ecal = atECAL.positionREP().eta();
0055 phi_ecal = atECAL.positionREP().phi();
0056 }
0057 }
0058 const auto& ref = ((const reco::PFBlockElementTrack*)&orig)->trackRef();
0059 pt = ref->pt();
0060 px = ref->px();
0061 py = ref->py();
0062 pz = ref->pz();
0063 eta = ref->eta();
0064 phi = ref->phi();
0065 energy = ref->p();
0066 charge = ref->charge();
0067 num_hits = ref->recHitsSize();
0068
0069 reco::MuonRef muonRef = orig.muonRef();
0070 if (muonRef.isNonnull()) {
0071 reco::TrackRef standAloneMu = muonRef->standAloneMuon();
0072 if (standAloneMu.isNonnull()) {
0073 muon_dt_hits = standAloneMu->hitPattern().numberOfValidMuonDTHits();
0074 muon_csc_hits = standAloneMu->hitPattern().numberOfValidMuonCSCHits();
0075 }
0076 muon_type = muonRef->type();
0077 }
0078
0079 } else if (type == reco::PFBlockElement::BREM) {
0080 const auto* orig2 = (const reco::PFBlockElementBrem*)&orig;
0081 const auto& ref = orig2->GsftrackRef();
0082 trajpoint = orig2->indTrajPoint();
0083 if (ref.isNonnull()) {
0084 deltap = orig2->DeltaP();
0085 sigmadeltap = orig2->SigmaDeltaP();
0086 pt = ref->pt();
0087 px = ref->px();
0088 py = ref->py();
0089 pz = ref->pz();
0090 eta = ref->eta();
0091 phi = ref->phi();
0092 energy = ref->p();
0093 charge = ref->charge();
0094 }
0095
0096 const auto& gsfextraref = ref->extra();
0097 if (gsfextraref.isAvailable() && gsfextraref->seedRef().isAvailable()) {
0098 reco::ElectronSeedRef seedref = gsfextraref->seedRef().castTo<reco::ElectronSeedRef>();
0099 if (seedref.isAvailable()) {
0100 if (seedref->isEcalDriven()) {
0101 gsf_electronseed_trkorecal = 1.0;
0102 } else if (seedref->isTrackerDriven()) {
0103 gsf_electronseed_trkorecal = 2.0;
0104 }
0105 }
0106 }
0107
0108 } else if (type == reco::PFBlockElement::GSF) {
0109
0110 const auto* orig2 = (const reco::PFBlockElementGsfTrack*)&orig;
0111 const auto& vec = orig2->Pin();
0112 pt = vec.pt();
0113 px = vec.px();
0114 py = vec.py();
0115 pz = vec.pz();
0116 eta = vec.eta();
0117 phi = vec.phi();
0118 energy = vec.energy();
0119
0120 const auto& vec2 = orig2->Pout();
0121 eta_ecal = vec2.eta();
0122 phi_ecal = vec2.phi();
0123
0124 if (!orig2->GsftrackRefPF().isNull()) {
0125 charge = orig2->GsftrackRefPF()->charge();
0126 num_hits = orig2->GsftrackRefPF()->PFRecBrem().size();
0127 }
0128
0129 const auto& ref = orig2->GsftrackRef();
0130
0131 const auto& gsfextraref = ref->extra();
0132 if (gsfextraref.isAvailable() && gsfextraref->seedRef().isAvailable()) {
0133 reco::ElectronSeedRef seedref = gsfextraref->seedRef().castTo<reco::ElectronSeedRef>();
0134 if (seedref.isAvailable()) {
0135 if (seedref->isEcalDriven()) {
0136 gsf_electronseed_trkorecal = 1.0;
0137 } else if (seedref->isTrackerDriven()) {
0138 gsf_electronseed_trkorecal = 2.0;
0139 }
0140 }
0141 };
0142
0143 } else if (type == reco::PFBlockElement::ECAL || type == reco::PFBlockElement::PS1 ||
0144 type == reco::PFBlockElement::PS2 || type == reco::PFBlockElement::HCAL ||
0145 type == reco::PFBlockElement::HO || type == reco::PFBlockElement::HFHAD ||
0146 type == reco::PFBlockElement::HFEM) {
0147 const auto& ref = ((const reco::PFBlockElementCluster*)&orig)->clusterRef();
0148 if (ref.isNonnull()) {
0149 cluster_flags = ref->flags();
0150 eta = ref->eta();
0151 phi = ref->phi();
0152 pt = ref->pt();
0153 px = ref->position().x();
0154 py = ref->position().y();
0155 pz = ref->position().z();
0156 energy = ref->energy();
0157 corr_energy = ref->correctedEnergy();
0158 layer = ref->layer();
0159 depth = ref->depth();
0160 num_hits = ref->recHitFractions().size();
0161 }
0162 } else if (type == reco::PFBlockElement::SC) {
0163 const auto& clref = ((const reco::PFBlockElementSuperCluster*)&orig)->superClusterRef();
0164 if (clref.isNonnull()) {
0165 cluster_flags = clref->flags();
0166 eta = clref->eta();
0167 phi = clref->phi();
0168 px = clref->position().x();
0169 py = clref->position().y();
0170 pz = clref->position().z();
0171 energy = clref->energy();
0172 num_hits = clref->clustersSize();
0173 }
0174 }
0175
0176 float typ_idx = static_cast<float>(elem_type_encoding.at(orig.type()));
0177
0178
0179 return {{typ_idx,
0180 pt,
0181 eta,
0182 phi,
0183 energy,
0184 layer,
0185 depth,
0186 charge,
0187 trajpoint,
0188 eta_ecal,
0189 phi_ecal,
0190 eta_hcal,
0191 phi_hcal,
0192 muon_dt_hits,
0193 muon_csc_hits,
0194 muon_type,
0195 px,
0196 py,
0197 pz,
0198 deltap,
0199 sigmadeltap,
0200 gsf_electronseed_trkorecal,
0201 num_hits,
0202 cluster_flags,
0203 corr_energy}};
0204 }
0205
0206
0207 float normalize(float in) {
0208 if (std::abs(in) > 1e4f) {
0209 return 0.0;
0210 } else if (edm::isNotFinite(in)) {
0211 return 0.0;
0212 }
0213 return in;
0214 }
0215
/// Returns the index of the largest value in `vec`.
///
/// Ties resolve to the first occurrence (std::max_element semantics). An empty vector yields 0,
/// which is not a valid index, so callers must pass a non-empty vector.
///
/// @param vec  values to scan
/// @return position of the first maximal element
int argMax(std::vector<float> const& vec) {
  // Qualified explicitly instead of relying on argument-dependent lookup through the iterator type.
  const auto maxIt = std::max_element(vec.begin(), vec.end());
  return static_cast<int>(std::distance(vec.begin(), maxIt));
}
0219
0220 reco::PFCandidate makeCandidate(int pred_pid,
0221 int pred_charge,
0222 float pred_pt,
0223 float pred_eta,
0224 float pred_sin_phi,
0225 float pred_cos_phi,
0226 float pred_e) {
0227 float pred_phi = std::atan2(pred_sin_phi, pred_cos_phi);
0228
0229
0230 reco::PFCandidate::Charge charge = 0;
0231 if (pred_pid == 11 || pred_pid == 13 || pred_pid == 211) {
0232 charge = pred_charge > 0 ? +1 : -1;
0233 }
0234
0235 math::PtEtaPhiELorentzVectorD p4(pred_pt, pred_eta, pred_phi, pred_e);
0236
0237 reco::PFCandidate::ParticleType particleType(reco::PFCandidate::X);
0238 if (pred_pid == 211)
0239 particleType = reco::PFCandidate::h;
0240 else if (pred_pid == 130)
0241 particleType = reco::PFCandidate::h0;
0242 else if (pred_pid == 22)
0243 particleType = reco::PFCandidate::gamma;
0244 else if (pred_pid == 11)
0245 particleType = reco::PFCandidate::e;
0246 else if (pred_pid == 13)
0247 particleType = reco::PFCandidate::mu;
0248 else if (pred_pid == 1)
0249 particleType = reco::PFCandidate::h_HF;
0250 else if (pred_pid == 2)
0251 particleType = reco::PFCandidate::egamma_HF;
0252
0253 reco::PFCandidate cand(charge, math::XYZTLorentzVector(p4.X(), p4.Y(), p4.Z(), p4.E()), particleType);
0254 cand.setMass(0.0);
0255 if (pred_pid == 211)
0256 cand.setMass(PI_MASS);
0257
0258
0259
0260 return cand;
0261 }
0262
0263 const std::vector<const reco::PFBlockElement*> getPFElements(const reco::PFBlockCollection& blocks) {
0264 std::vector<reco::PFCandidate> pOutputCandidateCollection;
0265
0266 std::vector<const reco::PFBlockElement*> all_elements;
0267 for (const auto& block : blocks) {
0268 const auto& elems = block.elements();
0269 for (const auto& elem : elems) {
0270 if (all_elements.size() < NUM_MAX_ELEMENTS_BATCH) {
0271 all_elements.push_back(&elem);
0272 } else {
0273
0274 edm::LogError("MLPFProducer") << "too many input PFElements for predefined model size: " << elems.size();
0275 break;
0276 }
0277 }
0278 }
0279 return all_elements;
0280 }
0281
0282
0283
0284 void setCandidateRefs(reco::PFCandidate& cand,
0285 const std::vector<const reco::PFBlockElement*> elems,
0286 size_t ielem_originator) {
0287 const reco::PFBlockElement* elem = elems[ielem_originator];
0288
0289
0290 if (std::abs(cand.pdgId()) == 211 && elem->type() == reco::PFBlockElement::TRACK && elem->trackRef().isNonnull()) {
0291 const auto* eltTrack = dynamic_cast<const reco::PFBlockElementTrack*>(elem);
0292 cand.setTrackRef(eltTrack->trackRef());
0293 cand.setVertex(eltTrack->trackRef()->vertex());
0294 cand.setPositionAtECALEntrance(eltTrack->positionAtECALEntrance());
0295 }
0296
0297
0298 if (std::abs(cand.pdgId()) == 13) {
0299 const auto* eltTrack = dynamic_cast<const reco::PFBlockElementTrack*>(elem);
0300 const auto& muonRef = eltTrack->muonRef();
0301 cand.setTrackRef(muonRef->track());
0302 cand.setMuonTrackType(muonRef->muonBestTrackType());
0303 cand.setVertex(muonRef->track()->vertex());
0304 cand.setMuonRef(muonRef);
0305 }
0306
0307 if (std::abs(cand.pdgId()) == 11 && elem->type() == reco::PFBlockElement::GSF) {
0308 const auto* eltTrack = dynamic_cast<const reco::PFBlockElementGsfTrack*>(elem);
0309 const auto& ref = eltTrack->GsftrackRef();
0310 cand.setGsfTrackRef(ref);
0311 }
0312 }
0313
0314 };