Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:24:28

#include "FWCore/Framework/interface/Frameworkfwd.h"
#include "FWCore/Framework/interface/stream/EDProducer.h"

#include "FWCore/Framework/interface/Event.h"
#include "FWCore/Framework/interface/MakerMacros.h"

#include "FWCore/Framework/interface/makeRefToBaseProdFrom.h"

#include "FWCore/ParameterSet/interface/ParameterSet.h"
#include "FWCore/Utilities/interface/Exception.h"
#include "FWCore/Utilities/interface/StreamID.h"

#include "DataFormats/BTauReco/interface/JetTag.h"

#include "DataFormats/BTauReco/interface/DeepBoostedJetTagInfo.h"

#include "PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h"

#include "RecoBTag/FeatureTools/interface/deep_helpers.h"

#include <iostream>
#include <fstream>
#include <algorithm>
#include <numeric>
#include <string>
#include <nlohmann/json.hpp>
0025 
0026 using namespace cms::Ort;
0027 using namespace btagbtvdeep;
0028 
0029 class BoostedJetONNXJetTagsProducer : public edm::stream::EDProducer<edm::GlobalCache<ONNXRuntime>> {
0030 public:
0031   explicit BoostedJetONNXJetTagsProducer(const edm::ParameterSet &, const ONNXRuntime *);
0032   ~BoostedJetONNXJetTagsProducer() override;
0033 
0034   static void fillDescriptions(edm::ConfigurationDescriptions &);
0035 
0036   static std::unique_ptr<ONNXRuntime> initializeGlobalCache(const edm::ParameterSet &);
0037   static void globalEndJob(const ONNXRuntime *);
0038 
0039 private:
0040   typedef std::vector<reco::DeepBoostedJetTagInfo> TagInfoCollection;
0041   typedef reco::JetTagCollection JetTagCollection;
0042 
0043   void produce(edm::Event &, const edm::EventSetup &) override;
0044 
0045   void make_inputs(const reco::DeepBoostedJetTagInfo &taginfo);
0046 
0047   const edm::EDGetTokenT<TagInfoCollection> src_;
0048   std::vector<std::string> flav_names_;               // names of the output scores
0049   edm::EDGetTokenT<edm::View<reco::Jet>> jet_token_;  // jets if function produces a ValueMap
0050   std::vector<std::string> input_names_;              // names of each input group - the ordering is important!
0051   std::vector<std::vector<int64_t>> input_shapes_;    // shapes of each input group (-1 for dynamic axis)
0052   std::vector<unsigned> input_sizes_;                 // total length of each input vector
0053   std::unordered_map<std::string, PreprocessParams> prep_info_map_;  // preprocessing info for each input group
0054 
0055   FloatArrays data_;
0056 
0057   bool debug_ = false;
0058   bool produceValueMap_;
0059   edm::Handle<edm::View<reco::Jet>> jets;
0060 };
0061 
0062 BoostedJetONNXJetTagsProducer::BoostedJetONNXJetTagsProducer(const edm::ParameterSet &iConfig, const ONNXRuntime *cache)
0063     : src_(consumes<TagInfoCollection>(iConfig.getParameter<edm::InputTag>("src"))),
0064       flav_names_(iConfig.getParameter<std::vector<std::string>>("flav_names")),
0065       debug_(iConfig.getUntrackedParameter<bool>("debugMode", false)),
0066       produceValueMap_(iConfig.getUntrackedParameter<bool>("produceValueMap", false)) {
0067   if (produceValueMap_) {
0068     jet_token_ = consumes<edm::View<reco::Jet>>(iConfig.getParameter<edm::InputTag>("jets"));
0069   }
0070 
0071   ParticleNetConstructor(iConfig, true, input_names_, prep_info_map_, input_shapes_, input_sizes_, &data_);
0072 
0073   if (debug_) {
0074     LogDebug("BoostedJetONNXJetTagsProducer") << "<BoostedJetONNXJetTagsProducer::produce>:" << std::endl;
0075     for (unsigned i = 0; i < input_names_.size(); ++i) {
0076       const auto &group_name = input_names_.at(i);
0077       if (!input_shapes_.empty()) {
0078         LogDebug("BoostedJetONNXJetTagsProducer") << group_name << "\nshapes: ";
0079         for (const auto &x : input_shapes_.at(i)) {
0080           LogDebug("BoostedJetONNXJetTagsProducer") << x << ", ";
0081         }
0082       }
0083       LogDebug("BoostedJetONNXJetTagsProducer") << "\nvariables: ";
0084       for (const auto &x : prep_info_map_.at(group_name).var_names) {
0085         LogDebug("BoostedJetONNXJetTagsProducer") << x << ", ";
0086       }
0087       LogDebug("BoostedJetONNXJetTagsProducer") << "\n";
0088     }
0089     LogDebug("BoostedJetONNXJetTagsProducer") << "flav_names: ";
0090     for (const auto &flav_name : flav_names_) {
0091       LogDebug("BoostedJetONNXJetTagsProducer") << flav_name << ", ";
0092     }
0093     LogDebug("BoostedJetONNXJetTagsProducer") << "\n";
0094   }
0095 
0096   // get output names from flav_names
0097   for (const auto &flav_name : flav_names_) {
0098     if (!produceValueMap_) {
0099       produces<JetTagCollection>(flav_name);
0100     } else {
0101       produces<edm::ValueMap<float>>(flav_name);
0102     }
0103   }
0104 }
0105 
0106 BoostedJetONNXJetTagsProducer::~BoostedJetONNXJetTagsProducer() {}
0107 
0108 void BoostedJetONNXJetTagsProducer::fillDescriptions(edm::ConfigurationDescriptions &descriptions) {
0109   // pfDeepBoostedJetTags
0110   edm::ParameterSetDescription desc;
0111   desc.add<edm::InputTag>("src", edm::InputTag("pfDeepBoostedJetTagInfos"));
0112   desc.add<std::string>("preprocess_json", "");
0113   // `preprocessParams` is deprecated -- use the preprocessing json file instead
0114   edm::ParameterSetDescription preprocessParams;
0115   preprocessParams.setAllowAnything();
0116   preprocessParams.setComment("`preprocessParams` is deprecated, please use `preprocess_json` instead.");
0117   desc.addOptional<edm::ParameterSetDescription>("preprocessParams", preprocessParams);
0118   desc.add<edm::FileInPath>("model_path",
0119                             edm::FileInPath("RecoBTag/Combined/data/DeepBoostedJet/V02/full/resnet.onnx"));
0120   desc.add<std::vector<std::string>>("flav_names",
0121                                      std::vector<std::string>{
0122                                          "probTbcq",
0123                                          "probTbqq",
0124                                          "probTbc",
0125                                          "probTbq",
0126                                          "probWcq",
0127                                          "probWqq",
0128                                          "probZbb",
0129                                          "probZcc",
0130                                          "probZqq",
0131                                          "probHbb",
0132                                          "probHcc",
0133                                          "probHqqqq",
0134                                          "probQCDbb",
0135                                          "probQCDcc",
0136                                          "probQCDb",
0137                                          "probQCDc",
0138                                          "probQCDothers",
0139                                      });
0140   desc.add<edm::InputTag>("jets", edm::InputTag(""));
0141   desc.addOptionalUntracked<bool>("produceValueMap", false);
0142   desc.addOptionalUntracked<bool>("debugMode", false);
0143 
0144   descriptions.addWithDefaultLabel(desc);
0145 }
0146 
0147 std::unique_ptr<ONNXRuntime> BoostedJetONNXJetTagsProducer::initializeGlobalCache(const edm::ParameterSet &iConfig) {
0148   return std::make_unique<ONNXRuntime>(iConfig.getParameter<edm::FileInPath>("model_path").fullPath());
0149 }
0150 
0151 void BoostedJetONNXJetTagsProducer::globalEndJob(const ONNXRuntime *cache) {}
0152 
0153 void BoostedJetONNXJetTagsProducer::produce(edm::Event &iEvent, const edm::EventSetup &iSetup) {
0154   edm::Handle<TagInfoCollection> tag_infos;
0155   iEvent.getByToken(src_, tag_infos);
0156   if (produceValueMap_) {
0157     jets = iEvent.getHandle(jet_token_);
0158   }
0159 
0160   // initialize output collection
0161   std::vector<std::unique_ptr<JetTagCollection>> output_tags;
0162   std::vector<std::vector<float>> output_scores(flav_names_.size(), std::vector<float>(tag_infos->size(), -1.0));
0163   if (!tag_infos->empty()) {
0164     auto jet_ref = tag_infos->begin()->jet();
0165     auto ref2prod = edm::makeRefToBaseProdFrom(jet_ref, iEvent);
0166     for (std::size_t i = 0; i < flav_names_.size(); i++) {
0167       output_tags.emplace_back(std::make_unique<JetTagCollection>(ref2prod));
0168     }
0169   } else {
0170     for (std::size_t i = 0; i < flav_names_.size(); i++) {
0171       output_tags.emplace_back(std::make_unique<JetTagCollection>());
0172     }
0173   }
0174 
0175   for (unsigned jet_n = 0; jet_n < tag_infos->size(); ++jet_n) {
0176     const auto &taginfo = (*tag_infos)[jet_n];
0177     std::vector<float> outputs(flav_names_.size(), 0);  // init as all zeros
0178 
0179     if (!taginfo.features().empty()) {
0180       // convert inputs
0181       make_inputs(taginfo);
0182       // run prediction and get outputs
0183       outputs = globalCache()->run(input_names_, data_, input_shapes_)[0];
0184       assert(outputs.size() == flav_names_.size());
0185     }
0186 
0187     const auto &jet_ref = tag_infos->at(jet_n).jet();
0188     for (std::size_t flav_n = 0; flav_n < flav_names_.size(); flav_n++) {
0189       (*(output_tags[flav_n]))[jet_ref] = outputs[flav_n];
0190       output_scores[flav_n][jet_n] = outputs[flav_n];
0191     }
0192   }
0193 
0194   if (debug_) {
0195     LogDebug("produce") << "<BoostedJetONNXJetTagsProducer::produce>:" << std::endl
0196                         << "=== " << iEvent.id().run() << ":" << iEvent.id().luminosityBlock() << ":"
0197                         << iEvent.id().event() << " ===" << std::endl;
0198     for (unsigned jet_n = 0; jet_n < tag_infos->size(); ++jet_n) {
0199       const auto &jet_ref = tag_infos->at(jet_n).jet();
0200       LogDebug("produce") << " - Jet #" << jet_n << ", pt=" << jet_ref->pt() << ", eta=" << jet_ref->eta()
0201                           << ", phi=" << jet_ref->phi() << std::endl;
0202       for (std::size_t flav_n = 0; flav_n < flav_names_.size(); ++flav_n) {
0203         if (!produceValueMap_) {
0204           LogDebug("produce") << "    " << flav_names_.at(flav_n) << " = " << (*(output_tags.at(flav_n)))[jet_ref]
0205                               << std::endl;
0206         } else {
0207           LogDebug("produce") << "    " << flav_names_.at(flav_n) << " = " << output_scores[flav_n][jet_n] << std::endl;
0208         }
0209       }
0210     }
0211   }
0212 
0213   // put into the event
0214   if (!produceValueMap_) {
0215     for (std::size_t flav_n = 0; flav_n < flav_names_.size(); ++flav_n) {
0216       iEvent.put(std::move(output_tags[flav_n]), flav_names_[flav_n]);
0217     }
0218   } else {
0219     for (size_t k = 0; k < output_scores.size(); k++) {
0220       std::unique_ptr<edm::ValueMap<float>> VM(new edm::ValueMap<float>());
0221       edm::ValueMap<float>::Filler filler(*VM);
0222       filler.insert(jets, output_scores.at(k).begin(), output_scores.at(k).end());
0223       filler.fill();
0224       iEvent.put(std::move(VM), flav_names_[k]);
0225     }
0226   }
0227 }
0228 void BoostedJetONNXJetTagsProducer::make_inputs(const reco::DeepBoostedJetTagInfo &taginfo) {
0229   for (unsigned igroup = 0; igroup < input_names_.size(); ++igroup) {
0230     const auto &group_name = input_names_[igroup];
0231     const auto &prep_params = prep_info_map_.at(group_name);
0232     auto &group_values = data_[igroup];
0233     group_values.resize(input_sizes_[igroup]);
0234     // first reset group_values to 0
0235     std::fill(group_values.begin(), group_values.end(), 0);
0236     unsigned curr_pos = 0;
0237     // transform/pad
0238     for (unsigned i = 0; i < prep_params.var_names.size(); ++i) {
0239       const auto &varname = prep_params.var_names[i];
0240       const auto &raw_value = taginfo.features().get(varname);
0241       const auto &info = prep_params.info(varname);
0242       int insize = center_norm_pad(raw_value,
0243                                    info.center,
0244                                    info.norm_factor,
0245                                    prep_params.min_length,
0246                                    prep_params.max_length,
0247                                    group_values,
0248                                    curr_pos,
0249                                    info.pad,
0250                                    info.replace_inf_value,
0251                                    info.lower_bound,
0252                                    info.upper_bound);
0253       curr_pos += insize;
0254       if (i == 0 && (!input_shapes_.empty())) {
0255         input_shapes_[igroup][2] = insize;
0256       }
0257 
0258       if (debug_) {
0259         LogDebug("make_inputs") << "<BoostedJetONNXJetTagsProducer::make_inputs>:" << std::endl
0260                                 << " -- var=" << varname << ", center=" << info.center << ", scale=" << info.norm_factor
0261                                 << ", replace=" << info.replace_inf_value << ", pad=" << info.pad << std::endl;
0262         for (unsigned i = curr_pos - insize; i < curr_pos; i++) {
0263           LogDebug("make_inputs") << group_values[i] << ",";
0264         }
0265         LogDebug("make_inputs") << std::endl;
0266       }
0267     }
0268     group_values.resize(curr_pos);
0269   }
0270 }
0271 
// Register this producer as a framework plug-in module.
0273 DEFINE_FWK_MODULE(BoostedJetONNXJetTagsProducer);