Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2022-02-12 09:07:53

#include "FWCore/Framework/interface/Frameworkfwd.h"
#include "FWCore/Framework/interface/stream/EDProducer.h"
#include "FWCore/Framework/interface/Event.h"
#include "FWCore/Framework/interface/MakerMacros.h"

#include "DataFormats/ParticleFlowCandidate/interface/PFCandidate.h"
#include "PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h"
#include "RecoParticleFlow/PFProducer/interface/MLPFModel.h"

#include "DataFormats/ParticleFlowReco/interface/PFBlockElementTrack.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

using namespace cms::Ort;
0013 
// Uncomment the MLPF_DEBUG define below to enable detailed debug printout and extra consistency assertions in MLPFProducer.
0015 //#define MLPF_DEBUG
0016 
0017 class MLPFProducer : public edm::stream::EDProducer<edm::GlobalCache<ONNXRuntime>> {
0018 public:
0019   explicit MLPFProducer(const edm::ParameterSet&, const ONNXRuntime*);
0020 
0021   void produce(edm::Event& event, const edm::EventSetup& setup) override;
0022   static void fillDescriptions(edm::ConfigurationDescriptions& descriptions);
0023 
0024   // static methods for handling the global cache
0025   static std::unique_ptr<ONNXRuntime> initializeGlobalCache(const edm::ParameterSet&);
0026   static void globalEndJob(const ONNXRuntime*);
0027 
0028 private:
0029   const edm::EDPutTokenT<reco::PFCandidateCollection> pfCandidatesPutToken_;
0030   const edm::EDGetTokenT<reco::PFBlockCollection> inputTagBlocks_;
0031 };
0032 
0033 MLPFProducer::MLPFProducer(const edm::ParameterSet& cfg, const ONNXRuntime* cache)
0034     : pfCandidatesPutToken_{produces<reco::PFCandidateCollection>()},
0035       inputTagBlocks_(consumes<reco::PFBlockCollection>(cfg.getParameter<edm::InputTag>("src"))) {}
0036 
0037 void MLPFProducer::produce(edm::Event& event, const edm::EventSetup& setup) {
0038   using namespace reco::mlpf;
0039 
0040   const auto& blocks = event.get(inputTagBlocks_);
0041   const auto& all_elements = getPFElements(blocks);
0042 
0043   std::vector<const reco::PFBlockElement*> selected_elements;
0044   unsigned int num_elements_total = 0;
0045   for (const auto* pelem : all_elements) {
0046     if (pelem->type() == reco::PFBlockElement::PS1 || pelem->type() == reco::PFBlockElement::PS2) {
0047       continue;
0048     }
0049     num_elements_total += 1;
0050     selected_elements.push_back(pelem);
0051   }
0052   const auto tensor_size = LSH_BIN_SIZE * std::max(2u, (num_elements_total / LSH_BIN_SIZE + 1));
0053 
0054 #ifdef MLPF_DEBUG
0055   assert(num_elements_total < NUM_MAX_ELEMENTS_BATCH);
0056   //tensor size must be a multiple of the bin size and larger than the number of elements
0057   assert(tensor_size <= NUM_MAX_ELEMENTS_BATCH);
0058   assert(tensor_size % LSH_BIN_SIZE == 0);
0059 #endif
0060 
0061 #ifdef MLPF_DEBUG
0062   std::cout << "tensor_size=" << tensor_size << std::endl;
0063 #endif
0064 
0065   //Fill the input tensor (batch, elems, features) = (1, tensor_size, NUM_ELEMENT_FEATURES)
0066   std::vector<std::vector<float>> inputs(1, std::vector<float>(NUM_ELEMENT_FEATURES * tensor_size, 0.0));
0067   unsigned int ielem = 0;
0068   for (const auto* pelem : selected_elements) {
0069     if (ielem > tensor_size) {
0070       continue;
0071     }
0072 
0073     const auto& elem = *pelem;
0074 
0075     //prepare the input array from the PFElement
0076     const auto& props = getElementProperties(elem);
0077 
0078     //copy features to the input array
0079     for (unsigned int iprop = 0; iprop < NUM_ELEMENT_FEATURES; iprop++) {
0080       inputs[0][ielem * NUM_ELEMENT_FEATURES + iprop] = normalize(props[iprop]);
0081     }
0082     ielem += 1;
0083   }
0084 
0085   //run the GNN inference, given the inputs and the output.
0086   const auto& outputs = globalCache()->run({"x:0"}, inputs, {{1, tensor_size, NUM_ELEMENT_FEATURES}});
0087   const auto& output = outputs[0];
0088 #ifdef MLPF_DEBUG
0089   assert(output.size() == tensor_size * NUM_OUTPUT_FEATURES);
0090 #endif
0091 
0092   std::vector<reco::PFCandidate> pOutputCandidateCollection;
0093   for (size_t ielem = 0; ielem < num_elements_total; ielem++) {
0094     std::vector<float> pred_id_probas(IDX_CLASS + 1, 0.0);
0095     const reco::PFBlockElement* elem = selected_elements[ielem];
0096 
0097     for (unsigned int idx_id = 0; idx_id <= IDX_CLASS; idx_id++) {
0098       auto pred_proba = output[ielem * NUM_OUTPUT_FEATURES + idx_id];
0099 #ifdef MLPF_DEBUG
0100       assert(!std::isnan(pred_proba));
0101 #endif
0102       pred_id_probas[idx_id] = pred_proba;
0103     }
0104 
0105     auto imax = argMax(pred_id_probas);
0106 
0107     //get the most probable class PDGID
0108     int pred_pid = pdgid_encoding[imax];
0109 
0110 #ifdef MLPF_DEBUG
0111     std::cout << "ielem=" << ielem << " inputs:";
0112     for (unsigned int iprop = 0; iprop < NUM_ELEMENT_FEATURES; iprop++) {
0113       std::cout << iprop << "=" << inputs[0][ielem * NUM_ELEMENT_FEATURES + iprop] << " ";
0114     }
0115     std::cout << std::endl;
0116     std::cout << "ielem=" << ielem << " pred: pid=" << pred_pid << std::endl;
0117 #endif
0118 
0119     //a particle was predicted for this PFElement, otherwise it was a spectator
0120     if (pred_pid != 0) {
0121       //muons and charged hadrons should only come from tracks, otherwise we won't have track references to pass downstream
0122       if (((pred_pid == 13) || (pred_pid == 211)) && elem->type() != reco::PFBlockElement::TRACK) {
0123         pred_pid = 130;
0124       }
0125 
0126       if (elem->type() == reco::PFBlockElement::TRACK) {
0127         const auto* eltTrack = dynamic_cast<const reco::PFBlockElementTrack*>(elem);
0128 
0129         //a track with no muon ref should not produce a muon candidate, instead we interpret it as a charged hadron
0130         if (pred_pid == 13 && eltTrack->muonRef().isNull()) {
0131           pred_pid = 211;
0132         }
0133 
0134         //tracks from displaced vertices need reference debugging downstream as well, so we just treat them as neutrals for the moment
0135         if ((pred_pid == 211) && (eltTrack->isLinkedToDisplacedVertex())) {
0136           pred_pid = 130;
0137         }
0138       }
0139 
0140       //get the predicted momentum components
0141       float pred_pt = output[ielem * NUM_OUTPUT_FEATURES + IDX_PT];
0142       float pred_eta = output[ielem * NUM_OUTPUT_FEATURES + IDX_ETA];
0143       float pred_sin_phi = output[ielem * NUM_OUTPUT_FEATURES + IDX_SIN_PHI];
0144       float pred_cos_phi = output[ielem * NUM_OUTPUT_FEATURES + IDX_COS_PHI];
0145       float pred_e = output[ielem * NUM_OUTPUT_FEATURES + IDX_ENERGY];
0146       float pred_charge = output[ielem * NUM_OUTPUT_FEATURES + IDX_CHARGE];
0147 
0148       auto cand = makeCandidate(pred_pid, pred_charge, pred_pt, pred_eta, pred_sin_phi, pred_cos_phi, pred_e);
0149       setCandidateRefs(cand, selected_elements, ielem);
0150       pOutputCandidateCollection.push_back(cand);
0151 
0152 #ifdef MLPF_DEBUG
0153       std::cout << "ielem=" << ielem << " cand: pid=" << cand.pdgId() << " E=" << cand.energy() << " pt=" << cand.pt()
0154                 << " eta=" << cand.eta() << " phi=" << cand.phi() << " charge=" << cand.charge() << std::endl;
0155 #endif
0156     }
0157   }  //loop over PFElements
0158 
0159   event.emplace(pfCandidatesPutToken_, pOutputCandidateCollection);
0160 }
0161 
0162 std::unique_ptr<ONNXRuntime> MLPFProducer::initializeGlobalCache(const edm::ParameterSet& params) {
0163   return std::make_unique<ONNXRuntime>(params.getParameter<edm::FileInPath>("model_path").fullPath());
0164 }
0165 
0166 void MLPFProducer::globalEndJob(const ONNXRuntime* cache) {}
0167 
0168 void MLPFProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0169   edm::ParameterSetDescription desc;
0170   desc.add<edm::InputTag>("src", edm::InputTag("particleFlowBlock"));
0171   desc.add<edm::FileInPath>(
0172       "model_path",
0173       edm::FileInPath(
0174           "RecoParticleFlow/PFProducer/data/mlpf/"
0175           "mlpf_2021_11_16__no_einsum__all_data_cms-best-of-asha-scikit_20211026_042043_178263.workergpu010.onnx"));
0176   descriptions.addWithDefaultLabel(desc);
0177 }
0178 
// Register MLPFProducer as a framework module (macro from FWCore/Framework/interface/MakerMacros.h).
DEFINE_FWK_MODULE(MLPFProducer);