Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2025-06-04 22:36:22

0001 #include "FWCore/Framework/interface/Frameworkfwd.h"
0002 #include "FWCore/Framework/interface/stream/EDProducer.h"
0003 
0004 #include "FWCore/Framework/interface/Event.h"
0005 #include "FWCore/Framework/interface/MakerMacros.h"
0006 
0007 #include "FWCore/Framework/interface/makeRefToBaseProdFrom.h"
0008 
0009 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0010 #include "FWCore/Utilities/interface/StreamID.h"
0011 
0012 #include "DataFormats/BTauReco/interface/JetTag.h"
0013 
0014 #include "DataFormats/BTauReco/interface/DeepFlavourTagInfo.h"
0015 
0016 #include "PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h"
0017 
0018 using namespace cms::Ort;
0019 
0020 class DeepFlavourONNXJetTagsProducer : public edm::stream::EDProducer<edm::GlobalCache<ONNXRuntime>> {
0021 public:
0022   explicit DeepFlavourONNXJetTagsProducer(const edm::ParameterSet&, const ONNXRuntime*);
0023   ~DeepFlavourONNXJetTagsProducer() override;
0024 
0025   static void fillDescriptions(edm::ConfigurationDescriptions&);
0026 
0027   static std::unique_ptr<ONNXRuntime> initializeGlobalCache(const edm::ParameterSet&);
0028   static void globalEndJob(const ONNXRuntime*);
0029 
0030 private:
0031   typedef std::vector<reco::DeepFlavourTagInfo> TagInfoCollection;
0032   typedef reco::JetTagCollection JetTagCollection;
0033 
0034   void produce(edm::Event&, const edm::EventSetup&) override;
0035 
0036   void make_inputs(unsigned i_jet, const reco::DeepFlavourTagInfo& taginfo);
0037 
0038   const edm::EDGetTokenT<TagInfoCollection> src_;
0039   edm::EDGetTokenT<edm::View<reco::Jet>> jet_;
0040   std::vector<std::string> flav_names_;
0041   std::vector<std::string> input_names_;
0042   std::vector<std::string> output_names_;
0043 
0044   enum InputIndexes { kGlobal = 0, kChargedCandidates = 1, kNeutralCandidates = 2, kVertices = 3, kJetPt = 4 };
0045   constexpr static unsigned n_features_global_ = 15;
0046   constexpr static unsigned n_cpf_ = 25;
0047   constexpr static unsigned n_features_cpf_ = 16;
0048   constexpr static unsigned n_npf_ = 25;
0049   constexpr static unsigned n_features_npf_ = 6;
0050   constexpr static unsigned n_sv_ = 4;
0051   constexpr static unsigned n_features_sv_ = 12;
0052   constexpr static unsigned n_features_jetpt_ = 1;
0053   const static std::vector<unsigned> input_sizes_;
0054 
0055   // hold the input data
0056   FloatArrays data_;
0057   bool produceValueMap_;
0058   edm::Handle<edm::View<reco::Jet>> jets;
0059 };
0060 
0061 const std::vector<unsigned> DeepFlavourONNXJetTagsProducer::input_sizes_{
0062     n_features_global_, n_cpf_* n_features_cpf_, n_npf_* n_features_npf_, n_sv_* n_features_sv_, n_features_jetpt_};
0063 
0064 DeepFlavourONNXJetTagsProducer::DeepFlavourONNXJetTagsProducer(const edm::ParameterSet& iConfig,
0065                                                                const ONNXRuntime* cache)
0066     : src_(consumes<TagInfoCollection>(iConfig.getParameter<edm::InputTag>("src"))),
0067       flav_names_(iConfig.getParameter<std::vector<std::string>>("flav_names")),
0068       input_names_(iConfig.getParameter<std::vector<std::string>>("input_names")),
0069       output_names_(iConfig.getParameter<std::vector<std::string>>("output_names")),
0070       produceValueMap_(iConfig.getUntrackedParameter<bool>("produceValueMap", false)) {
0071   if (produceValueMap_) {
0072     jet_ = consumes<edm::View<reco::Jet>>(iConfig.getParameter<edm::InputTag>("jets"));
0073   }
0074 
0075   // get output names from flav_names
0076   for (const auto& flav_name : flav_names_) {
0077     produces<JetTagCollection>(flav_name);
0078     if (produceValueMap_) {
0079       produces<edm::ValueMap<float>>(flav_name);
0080     }
0081   }
0082 
0083   assert(input_names_.size() == input_sizes_.size());
0084   data_.reserve(input_sizes_.size());
0085 }
0086 
0087 DeepFlavourONNXJetTagsProducer::~DeepFlavourONNXJetTagsProducer() {}
0088 
0089 void DeepFlavourONNXJetTagsProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0090   // pfDeepFlavourJetTags
0091   edm::ParameterSetDescription desc;
0092   desc.add<edm::InputTag>("src", edm::InputTag("pfDeepFlavourTagInfos"));
0093   desc.add<std::vector<std::string>>("input_names", {"input_1", "input_2", "input_3", "input_4", "input_5"});
0094   desc.add<edm::FileInPath>(
0095       "model_path", edm::FileInPath("RecoBTag/Combined/data/DeepFlavourV04_12X_training/DeepJet_Run3_122X.onnx"));
0096   desc.add<std::vector<std::string>>("output_names", {"ID_pred/Softmax:0"});
0097   desc.add<std::vector<std::string>>(
0098       "flav_names", std::vector<std::string>{"probb", "probbb", "problepb", "probc", "probuds", "probg"});
0099   desc.add<edm::InputTag>("jets", edm::InputTag("hltAK4PFPuppiJets"));
0100   desc.addOptionalUntracked<bool>("produceValueMap", false);
0101 
0102   descriptions.add("pfDeepFlavourJetTags", desc);
0103 }
0104 
0105 std::unique_ptr<ONNXRuntime> DeepFlavourONNXJetTagsProducer::initializeGlobalCache(const edm::ParameterSet& iConfig) {
0106   return std::make_unique<ONNXRuntime>(iConfig.getParameter<edm::FileInPath>("model_path").fullPath());
0107 }
0108 
0109 void DeepFlavourONNXJetTagsProducer::globalEndJob(const ONNXRuntime* cache) {}
0110 
0111 void DeepFlavourONNXJetTagsProducer::produce(edm::Event& iEvent, const edm::EventSetup& iSetup) {
0112   edm::Handle<TagInfoCollection> tag_infos;
0113   iEvent.getByToken(src_, tag_infos);
0114 
0115   if (produceValueMap_) {
0116     iEvent.getByToken(jet_, jets);
0117     if (!jets.isValid()) {
0118       edm::LogWarning("DeepFlavourONNXJetTagsProducer") << "Invalid handle in jet input collection";
0119       return;
0120     }
0121   }
0122 
0123   std::vector<std::unique_ptr<JetTagCollection>> output_tags;
0124   std::vector<std::vector<float>> output_scores(flav_names_.size(), std::vector<float>(tag_infos->size(), -1.0));
0125   if (!tag_infos->empty()) {
0126     // initialize output collection
0127     auto jet_ref = tag_infos->begin()->jet();
0128     auto ref2prod = edm::makeRefToBaseProdFrom(jet_ref, iEvent);
0129     for (std::size_t i = 0; i < flav_names_.size(); i++) {
0130       output_tags.emplace_back(std::make_unique<JetTagCollection>(ref2prod));
0131     }
0132 
0133     // init data storage
0134     data_.clear();
0135     for (const auto& len : input_sizes_) {
0136       data_.emplace_back(tag_infos->size() * len, 0.0);
0137     }
0138 
0139     // convert inputs
0140     for (unsigned jet_n = 0; jet_n < tag_infos->size(); ++jet_n) {
0141       const auto& taginfo = (*tag_infos)[jet_n];
0142       make_inputs(jet_n, taginfo);
0143     }
0144 
0145     // run prediction
0146     auto outputs = globalCache()->run(input_names_, data_, {}, output_names_, tag_infos->size())[0];
0147     assert(outputs.size() == flav_names_.size() * tag_infos->size());
0148 
0149     // get the outputs
0150     unsigned i_output = 0;
0151     for (unsigned jet_n = 0; jet_n < tag_infos->size(); ++jet_n) {
0152       const auto& jet_ref = tag_infos->at(jet_n).jet();
0153       for (std::size_t flav_n = 0; flav_n < flav_names_.size(); flav_n++) {
0154         (*(output_tags[flav_n]))[jet_ref] = outputs[i_output];
0155         if (produceValueMap_) {
0156           output_scores[flav_n][jet_n] = outputs[flav_n];
0157         }
0158         ++i_output;
0159       }
0160     }
0161   } else {
0162     // create empty output collection
0163     for (std::size_t i = 0; i < flav_names_.size(); i++) {
0164       output_tags.emplace_back(std::make_unique<JetTagCollection>());
0165     }
0166   }
0167 
0168   // put into the event
0169   for (std::size_t flav_n = 0; flav_n < flav_names_.size(); ++flav_n) {
0170     if (produceValueMap_) {
0171       for (size_t k = 0; k < output_scores.size(); k++) {
0172         std::unique_ptr<edm::ValueMap<float>> VM(new edm::ValueMap<float>());
0173         edm::ValueMap<float>::Filler filler(*VM);
0174         filler.insert(jets, output_scores.at(k).begin(), output_scores.at(k).end());
0175         filler.fill();
0176         iEvent.put(std::move(VM), flav_names_[k]);
0177       }
0178     }
0179     iEvent.put(std::move(output_tags[flav_n]), flav_names_[flav_n]);
0180   }
0181   data_.clear();
0182 }
0183 
0184 void DeepFlavourONNXJetTagsProducer::make_inputs(unsigned i_jet, const reco::DeepFlavourTagInfo& taginfo) {
0185   const auto& features = taginfo.features();
0186   float* ptr = nullptr;
0187   const float* start = nullptr;
0188   unsigned offset = 0;
0189 
0190   // jet and other global features
0191   offset = i_jet * input_sizes_[kGlobal];
0192   ptr = &data_[kGlobal][offset];
0193   // jet variables
0194   const auto& jet_features = features.jet_features;
0195   start = ptr;
0196   *ptr = jet_features.pt;
0197   *(++ptr) = jet_features.eta;
0198   // number of elements in different collections
0199   *(++ptr) = features.c_pf_features.size();
0200   *(++ptr) = features.n_pf_features.size();
0201   *(++ptr) = features.sv_features.size();
0202   *(++ptr) = features.npv;
0203   // variables from ShallowTagInfo
0204   const auto& tag_info_features = features.tag_info_features;
0205   *(++ptr) = tag_info_features.trackSumJetEtRatio;
0206   *(++ptr) = tag_info_features.trackSumJetDeltaR;
0207   *(++ptr) = tag_info_features.vertexCategory;
0208   *(++ptr) = tag_info_features.trackSip2dValAboveCharm;
0209   *(++ptr) = tag_info_features.trackSip2dSigAboveCharm;
0210   *(++ptr) = tag_info_features.trackSip3dValAboveCharm;
0211   *(++ptr) = tag_info_features.trackSip3dSigAboveCharm;
0212   *(++ptr) = tag_info_features.jetNSelectedTracks;
0213   *(++ptr) = tag_info_features.jetNTracksEtaRel;
0214   assert(start + n_features_global_ - 1 == ptr);
0215 
0216   // c_pf candidates
0217   auto max_c_pf_n = std::min(features.c_pf_features.size(), (std::size_t)25);
0218   offset = i_jet * input_sizes_[kChargedCandidates];
0219   for (std::size_t c_pf_n = 0; c_pf_n < max_c_pf_n; c_pf_n++) {
0220     const auto& c_pf_features = features.c_pf_features.at(c_pf_n);
0221     ptr = &data_[kChargedCandidates][offset + c_pf_n * n_features_cpf_];
0222     start = ptr;
0223     *ptr = c_pf_features.btagPf_trackEtaRel;
0224     *(++ptr) = c_pf_features.btagPf_trackPtRel;
0225     *(++ptr) = c_pf_features.btagPf_trackPPar;
0226     *(++ptr) = c_pf_features.btagPf_trackDeltaR;
0227     *(++ptr) = c_pf_features.btagPf_trackPParRatio;
0228     *(++ptr) = c_pf_features.btagPf_trackSip2dVal;
0229     *(++ptr) = c_pf_features.btagPf_trackSip2dSig;
0230     *(++ptr) = c_pf_features.btagPf_trackSip3dVal;
0231     *(++ptr) = c_pf_features.btagPf_trackSip3dSig;
0232     *(++ptr) = c_pf_features.btagPf_trackJetDistVal;
0233     *(++ptr) = c_pf_features.ptrel;
0234     *(++ptr) = c_pf_features.drminsv;
0235     *(++ptr) = c_pf_features.vtx_ass;
0236     *(++ptr) = c_pf_features.puppiw;
0237     *(++ptr) = c_pf_features.chi2;
0238     *(++ptr) = c_pf_features.quality;
0239     assert(start + n_features_cpf_ - 1 == ptr);
0240   }
0241 
0242   // n_pf candidates
0243   auto max_n_pf_n = std::min(features.n_pf_features.size(), (std::size_t)25);
0244   offset = i_jet * input_sizes_[kNeutralCandidates];
0245   for (std::size_t n_pf_n = 0; n_pf_n < max_n_pf_n; n_pf_n++) {
0246     const auto& n_pf_features = features.n_pf_features.at(n_pf_n);
0247     ptr = &data_[kNeutralCandidates][offset + n_pf_n * n_features_npf_];
0248     start = ptr;
0249     *ptr = n_pf_features.ptrel;
0250     *(++ptr) = n_pf_features.deltaR;
0251     *(++ptr) = n_pf_features.isGamma;
0252     *(++ptr) = n_pf_features.hadFrac;
0253     *(++ptr) = n_pf_features.drminsv;
0254     *(++ptr) = n_pf_features.puppiw;
0255     assert(start + n_features_npf_ - 1 == ptr);
0256   }
0257 
0258   // sv candidates
0259   auto max_sv_n = std::min(features.sv_features.size(), (std::size_t)4);
0260   offset = i_jet * input_sizes_[kVertices];
0261   for (std::size_t sv_n = 0; sv_n < max_sv_n; sv_n++) {
0262     const auto& sv_features = features.sv_features.at(sv_n);
0263     ptr = &data_[kVertices][offset + sv_n * n_features_sv_];
0264     start = ptr;
0265     *ptr = sv_features.pt;
0266     *(++ptr) = sv_features.deltaR;
0267     *(++ptr) = sv_features.mass;
0268     *(++ptr) = sv_features.ntracks;
0269     *(++ptr) = sv_features.chi2;
0270     *(++ptr) = sv_features.normchi2;
0271     *(++ptr) = sv_features.dxy;
0272     *(++ptr) = sv_features.dxysig;
0273     *(++ptr) = sv_features.d3d;
0274     *(++ptr) = sv_features.d3dsig;
0275     *(++ptr) = sv_features.costhetasvpv;
0276     *(++ptr) = sv_features.enratio;
0277     assert(start + n_features_sv_ - 1 == ptr);
0278   }
0279 
0280   // last input: jet pt
0281   offset = i_jet * input_sizes_[kJetPt];
0282   data_[kJetPt][offset] = features.jet_features.pt;
0283 }
0284 
// Define this producer as a framework plug-in module.
DEFINE_FWK_MODULE(DeepFlavourONNXJetTagsProducer);