Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2022-06-30 23:17:13

#include "FWCore/Framework/interface/Frameworkfwd.h"
#include "FWCore/Framework/interface/stream/EDProducer.h"

#include "FWCore/Framework/interface/Event.h"
#include "FWCore/Framework/interface/MakerMacros.h"

#include "FWCore/Framework/interface/makeRefToBaseProdFrom.h"

#include "FWCore/ParameterSet/interface/ParameterSet.h"
#include "FWCore/Utilities/interface/Exception.h"
#include "FWCore/Utilities/interface/StreamID.h"

#include "DataFormats/BTauReco/interface/JetTag.h"

#include "DataFormats/BTauReco/interface/DeepBoostedJetTagInfo.h"

#include "PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h"

#include "RecoBTag/FeatureTools/interface/deep_helpers.h"

#include <algorithm>
#include <fstream>
#include <iostream>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include <nlohmann/json.hpp>

using namespace cms::Ort;
using namespace btagbtvdeep;
0028 
0029 class BoostedJetONNXJetTagsProducer : public edm::stream::EDProducer<edm::GlobalCache<ONNXRuntime>> {
0030 public:
0031   explicit BoostedJetONNXJetTagsProducer(const edm::ParameterSet &, const ONNXRuntime *);
0032   ~BoostedJetONNXJetTagsProducer() override;
0033 
0034   static void fillDescriptions(edm::ConfigurationDescriptions &);
0035 
0036   static std::unique_ptr<ONNXRuntime> initializeGlobalCache(const edm::ParameterSet &);
0037   static void globalEndJob(const ONNXRuntime *);
0038 
0039 private:
0040   typedef std::vector<reco::DeepBoostedJetTagInfo> TagInfoCollection;
0041   typedef reco::JetTagCollection JetTagCollection;
0042 
0043   void produce(edm::Event &, const edm::EventSetup &) override;
0044 
0045   void make_inputs(const reco::DeepBoostedJetTagInfo &taginfo);
0046 
0047   const edm::EDGetTokenT<TagInfoCollection> src_;
0048   std::vector<std::string> flav_names_;             // names of the output scores
0049   std::vector<std::string> input_names_;            // names of each input group - the ordering is important!
0050   std::vector<std::vector<int64_t>> input_shapes_;  // shapes of each input group (-1 for dynamic axis)
0051   std::vector<unsigned> input_sizes_;               // total length of each input vector
0052   std::unordered_map<std::string, PreprocessParams> prep_info_map_;  // preprocessing info for each input group
0053 
0054   FloatArrays data_;
0055 
0056   bool debug_ = false;
0057 };
0058 
0059 BoostedJetONNXJetTagsProducer::BoostedJetONNXJetTagsProducer(const edm::ParameterSet &iConfig, const ONNXRuntime *cache)
0060     : src_(consumes<TagInfoCollection>(iConfig.getParameter<edm::InputTag>("src"))),
0061       flav_names_(iConfig.getParameter<std::vector<std::string>>("flav_names")),
0062       debug_(iConfig.getUntrackedParameter<bool>("debugMode", false)) {
0063   ParticleNetConstructor(iConfig, true, input_names_, prep_info_map_, input_shapes_, input_sizes_, &data_);
0064 
0065   if (debug_) {
0066     LogDebug("BoostedJetONNXJetTagsProducer") << "<BoostedJetONNXJetTagsProducer::produce>:" << std::endl;
0067     for (unsigned i = 0; i < input_names_.size(); ++i) {
0068       const auto &group_name = input_names_.at(i);
0069       if (!input_shapes_.empty()) {
0070         LogDebug("BoostedJetONNXJetTagsProducer") << group_name << "\nshapes: ";
0071         for (const auto &x : input_shapes_.at(i)) {
0072           LogDebug("BoostedJetONNXJetTagsProducer") << x << ", ";
0073         }
0074       }
0075       LogDebug("BoostedJetONNXJetTagsProducer") << "\nvariables: ";
0076       for (const auto &x : prep_info_map_.at(group_name).var_names) {
0077         LogDebug("BoostedJetONNXJetTagsProducer") << x << ", ";
0078       }
0079       LogDebug("BoostedJetONNXJetTagsProducer") << "\n";
0080     }
0081     LogDebug("BoostedJetONNXJetTagsProducer") << "flav_names: ";
0082     for (const auto &flav_name : flav_names_) {
0083       LogDebug("BoostedJetONNXJetTagsProducer") << flav_name << ", ";
0084     }
0085     LogDebug("BoostedJetONNXJetTagsProducer") << "\n";
0086   }
0087 
0088   // get output names from flav_names
0089   for (const auto &flav_name : flav_names_) {
0090     produces<JetTagCollection>(flav_name);
0091   }
0092 }
0093 
0094 BoostedJetONNXJetTagsProducer::~BoostedJetONNXJetTagsProducer() {}
0095 
0096 void BoostedJetONNXJetTagsProducer::fillDescriptions(edm::ConfigurationDescriptions &descriptions) {
0097   // pfDeepBoostedJetTags
0098   edm::ParameterSetDescription desc;
0099   desc.add<edm::InputTag>("src", edm::InputTag("pfDeepBoostedJetTagInfos"));
0100   desc.add<std::string>("preprocess_json", "");
0101   // `preprocessParams` is deprecated -- use the preprocessing json file instead
0102   edm::ParameterSetDescription preprocessParams;
0103   preprocessParams.setAllowAnything();
0104   preprocessParams.setComment("`preprocessParams` is deprecated, please use `preprocess_json` instead.");
0105   desc.addOptional<edm::ParameterSetDescription>("preprocessParams", preprocessParams);
0106   desc.add<edm::FileInPath>("model_path",
0107                             edm::FileInPath("RecoBTag/Combined/data/DeepBoostedJet/V02/full/resnet.onnx"));
0108   desc.add<std::vector<std::string>>("flav_names",
0109                                      std::vector<std::string>{
0110                                          "probTbcq",
0111                                          "probTbqq",
0112                                          "probTbc",
0113                                          "probTbq",
0114                                          "probWcq",
0115                                          "probWqq",
0116                                          "probZbb",
0117                                          "probZcc",
0118                                          "probZqq",
0119                                          "probHbb",
0120                                          "probHcc",
0121                                          "probHqqqq",
0122                                          "probQCDbb",
0123                                          "probQCDcc",
0124                                          "probQCDb",
0125                                          "probQCDc",
0126                                          "probQCDothers",
0127                                      });
0128   desc.addOptionalUntracked<bool>("debugMode", false);
0129 
0130   descriptions.addWithDefaultLabel(desc);
0131 }
0132 
0133 std::unique_ptr<ONNXRuntime> BoostedJetONNXJetTagsProducer::initializeGlobalCache(const edm::ParameterSet &iConfig) {
0134   return std::make_unique<ONNXRuntime>(iConfig.getParameter<edm::FileInPath>("model_path").fullPath());
0135 }
0136 
0137 void BoostedJetONNXJetTagsProducer::globalEndJob(const ONNXRuntime *cache) {}
0138 
0139 void BoostedJetONNXJetTagsProducer::produce(edm::Event &iEvent, const edm::EventSetup &iSetup) {
0140   edm::Handle<TagInfoCollection> tag_infos;
0141   iEvent.getByToken(src_, tag_infos);
0142 
0143   // initialize output collection
0144   std::vector<std::unique_ptr<JetTagCollection>> output_tags;
0145   if (!tag_infos->empty()) {
0146     auto jet_ref = tag_infos->begin()->jet();
0147     auto ref2prod = edm::makeRefToBaseProdFrom(jet_ref, iEvent);
0148     for (std::size_t i = 0; i < flav_names_.size(); i++) {
0149       output_tags.emplace_back(std::make_unique<JetTagCollection>(ref2prod));
0150     }
0151   } else {
0152     for (std::size_t i = 0; i < flav_names_.size(); i++) {
0153       output_tags.emplace_back(std::make_unique<JetTagCollection>());
0154     }
0155   }
0156 
0157   for (unsigned jet_n = 0; jet_n < tag_infos->size(); ++jet_n) {
0158     const auto &taginfo = (*tag_infos)[jet_n];
0159     std::vector<float> outputs(flav_names_.size(), 0);  // init as all zeros
0160 
0161     if (!taginfo.features().empty()) {
0162       // convert inputs
0163       make_inputs(taginfo);
0164       // run prediction and get outputs
0165       outputs = globalCache()->run(input_names_, data_, input_shapes_)[0];
0166       assert(outputs.size() == flav_names_.size());
0167     }
0168 
0169     const auto &jet_ref = tag_infos->at(jet_n).jet();
0170     for (std::size_t flav_n = 0; flav_n < flav_names_.size(); flav_n++) {
0171       (*(output_tags[flav_n]))[jet_ref] = outputs[flav_n];
0172     }
0173   }
0174 
0175   if (debug_) {
0176     LogDebug("produce") << "<BoostedJetONNXJetTagsProducer::produce>:" << std::endl
0177                         << "=== " << iEvent.id().run() << ":" << iEvent.id().luminosityBlock() << ":"
0178                         << iEvent.id().event() << " ===" << std::endl;
0179     for (unsigned jet_n = 0; jet_n < tag_infos->size(); ++jet_n) {
0180       const auto &jet_ref = tag_infos->at(jet_n).jet();
0181       LogDebug("produce") << " - Jet #" << jet_n << ", pt=" << jet_ref->pt() << ", eta=" << jet_ref->eta()
0182                           << ", phi=" << jet_ref->phi() << std::endl;
0183       for (std::size_t flav_n = 0; flav_n < flav_names_.size(); ++flav_n) {
0184         LogDebug("produce") << "    " << flav_names_.at(flav_n) << " = " << (*(output_tags.at(flav_n)))[jet_ref]
0185                             << std::endl;
0186       }
0187     }
0188   }
0189 
0190   // put into the event
0191   for (std::size_t flav_n = 0; flav_n < flav_names_.size(); ++flav_n) {
0192     iEvent.put(std::move(output_tags[flav_n]), flav_names_[flav_n]);
0193   }
0194 }
0195 
0196 void BoostedJetONNXJetTagsProducer::make_inputs(const reco::DeepBoostedJetTagInfo &taginfo) {
0197   for (unsigned igroup = 0; igroup < input_names_.size(); ++igroup) {
0198     const auto &group_name = input_names_[igroup];
0199     const auto &prep_params = prep_info_map_.at(group_name);
0200     auto &group_values = data_[igroup];
0201     group_values.resize(input_sizes_[igroup]);
0202     // first reset group_values to 0
0203     std::fill(group_values.begin(), group_values.end(), 0);
0204     unsigned curr_pos = 0;
0205     // transform/pad
0206     for (unsigned i = 0; i < prep_params.var_names.size(); ++i) {
0207       const auto &varname = prep_params.var_names[i];
0208       const auto &raw_value = taginfo.features().get(varname);
0209       const auto &info = prep_params.info(varname);
0210       int insize = center_norm_pad(raw_value,
0211                                    info.center,
0212                                    info.norm_factor,
0213                                    prep_params.min_length,
0214                                    prep_params.max_length,
0215                                    group_values,
0216                                    curr_pos,
0217                                    info.pad,
0218                                    info.replace_inf_value,
0219                                    info.lower_bound,
0220                                    info.upper_bound);
0221       curr_pos += insize;
0222       if (i == 0 && (!input_shapes_.empty())) {
0223         input_shapes_[igroup][2] = insize;
0224       }
0225 
0226       if (debug_) {
0227         LogDebug("make_inputs") << "<BoostedJetONNXJetTagsProducer::make_inputs>:" << std::endl
0228                                 << " -- var=" << varname << ", center=" << info.center << ", scale=" << info.norm_factor
0229                                 << ", replace=" << info.replace_inf_value << ", pad=" << info.pad << std::endl;
0230         for (unsigned i = curr_pos - insize; i < curr_pos; i++) {
0231           LogDebug("make_inputs") << group_values[i] << ",";
0232         }
0233         LogDebug("make_inputs") << std::endl;
0234       }
0235     }
0236     group_values.resize(curr_pos);
0237   }
0238 }
0239 
// Register BoostedJetONNXJetTagsProducer as a framework plug-in module.
DEFINE_FWK_MODULE(BoostedJetONNXJetTagsProducer);