Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2021-02-14 13:33:40

#include "FWCore/Framework/interface/Frameworkfwd.h"
#include "FWCore/Framework/interface/stream/EDProducer.h"

#include "FWCore/Framework/interface/Event.h"
#include "FWCore/Framework/interface/MakerMacros.h"

#include "FWCore/Framework/interface/makeRefToBaseProdFrom.h"

#include "FWCore/ParameterSet/interface/ParameterSet.h"
#include "FWCore/Utilities/interface/Exception.h"
#include "FWCore/Utilities/interface/StreamID.h"

#include "DataFormats/BTauReco/interface/JetTag.h"

#include "DataFormats/BTauReco/interface/DeepBoostedJetTagInfo.h"

#include "PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h"

#include "RecoBTag/FeatureTools/interface/deep_helpers.h"

#include <iostream>
#include <fstream>
#include <algorithm>
#include <numeric>
#include <nlohmann/json.hpp>
0025 
0026 using namespace cms::Ort;
0027 using namespace btagbtvdeep;
0028 
0029 class BoostedJetONNXJetTagsProducer : public edm::stream::EDProducer<edm::GlobalCache<ONNXRuntime>> {
0030 public:
0031   explicit BoostedJetONNXJetTagsProducer(const edm::ParameterSet &, const ONNXRuntime *);
0032   ~BoostedJetONNXJetTagsProducer() override;
0033 
0034   static void fillDescriptions(edm::ConfigurationDescriptions &);
0035 
0036   static std::unique_ptr<ONNXRuntime> initializeGlobalCache(const edm::ParameterSet &);
0037   static void globalEndJob(const ONNXRuntime *);
0038 
0039 private:
0040   typedef std::vector<reco::DeepBoostedJetTagInfo> TagInfoCollection;
0041   typedef reco::JetTagCollection JetTagCollection;
0042 
0043   void produce(edm::Event &, const edm::EventSetup &) override;
0044 
0045   std::vector<float> center_norm_pad(const std::vector<float> &input,
0046                                      float center,
0047                                      float scale,
0048                                      unsigned min_length,
0049                                      unsigned max_length,
0050                                      float pad_value = 0,
0051                                      float replace_inf_value = 0,
0052                                      float min = 0,
0053                                      float max = -1);
0054   void make_inputs(const reco::DeepBoostedJetTagInfo &taginfo);
0055 
0056   const edm::EDGetTokenT<TagInfoCollection> src_;
0057   std::vector<std::string> flav_names_;             // names of the output scores
0058   std::vector<std::string> input_names_;            // names of each input group - the ordering is important!
0059   std::vector<std::vector<int64_t>> input_shapes_;  // shapes of each input group (-1 for dynamic axis)
0060   std::vector<unsigned> input_sizes_;               // total length of each input vector
0061   std::unordered_map<std::string, PreprocessParams> prep_info_map_;  // preprocessing info for each input group
0062 
0063   FloatArrays data_;
0064 
0065   bool debug_ = false;
0066 };
0067 
0068 BoostedJetONNXJetTagsProducer::BoostedJetONNXJetTagsProducer(const edm::ParameterSet &iConfig, const ONNXRuntime *cache)
0069     : src_(consumes<TagInfoCollection>(iConfig.getParameter<edm::InputTag>("src"))),
0070       flav_names_(iConfig.getParameter<std::vector<std::string>>("flav_names")),
0071       debug_(iConfig.getUntrackedParameter<bool>("debugMode", false)) {
0072   // load preprocessing info
0073   auto json_path = iConfig.getParameter<std::string>("preprocess_json");
0074   if (!json_path.empty()) {
0075     // use preprocessing json file if available
0076     std::ifstream ifs(edm::FileInPath(json_path).fullPath());
0077     nlohmann::json js = nlohmann::json::parse(ifs);
0078     js.at("input_names").get_to(input_names_);
0079     for (const auto &group_name : input_names_) {
0080       const auto &group_pset = js.at(group_name);
0081       auto &prep_params = prep_info_map_[group_name];
0082       group_pset.at("var_names").get_to(prep_params.var_names);
0083       if (group_pset.contains("var_length")) {
0084         prep_params.min_length = group_pset.at("var_length");
0085         prep_params.max_length = prep_params.min_length;
0086       } else {
0087         prep_params.min_length = group_pset.at("min_length");
0088         prep_params.max_length = group_pset.at("max_length");
0089         input_shapes_.push_back({1, (int64_t)prep_params.var_names.size(), -1});
0090       }
0091       const auto &var_info_pset = group_pset.at("var_infos");
0092       for (const auto &var_name : prep_params.var_names) {
0093         const auto &var_pset = var_info_pset.at(var_name);
0094         double median = var_pset.at("median");
0095         double norm_factor = var_pset.at("norm_factor");
0096         double replace_inf_value = var_pset.at("replace_inf_value");
0097         double lower_bound = var_pset.at("lower_bound");
0098         double upper_bound = var_pset.at("upper_bound");
0099         double pad = var_pset.contains("pad") ? double(var_pset.at("pad")) : 0;
0100         prep_params.var_info_map[var_name] =
0101             PreprocessParams::VarInfo(median, norm_factor, replace_inf_value, lower_bound, upper_bound, pad);
0102       }
0103 
0104       // create data storage with a fixed size vector initilized w/ 0
0105       const auto &len = input_sizes_.emplace_back(prep_params.max_length * prep_params.var_names.size());
0106       data_.emplace_back(len, 0);
0107     }
0108   } else {
0109     // otherwise use the PSet in the python config file
0110     const auto &prep_pset = iConfig.getParameterSet("preprocessParams");
0111     input_names_ = prep_pset.getParameter<std::vector<std::string>>("input_names");
0112     for (const auto &group_name : input_names_) {
0113       const auto &group_pset = prep_pset.getParameterSet(group_name);
0114       auto &prep_params = prep_info_map_[group_name];
0115       prep_params.var_names = group_pset.getParameter<std::vector<std::string>>("var_names");
0116       prep_params.min_length = group_pset.getParameter<unsigned>("var_length");
0117       prep_params.max_length = prep_params.min_length;
0118       const auto &var_info_pset = group_pset.getParameterSet("var_infos");
0119       for (const auto &var_name : prep_params.var_names) {
0120         const auto &var_pset = var_info_pset.getParameterSet(var_name);
0121         double median = var_pset.getParameter<double>("median");
0122         double norm_factor = var_pset.getParameter<double>("norm_factor");
0123         double replace_inf_value = var_pset.getParameter<double>("replace_inf_value");
0124         double lower_bound = var_pset.getParameter<double>("lower_bound");
0125         double upper_bound = var_pset.getParameter<double>("upper_bound");
0126         prep_params.var_info_map[var_name] =
0127             PreprocessParams::VarInfo(median, norm_factor, replace_inf_value, lower_bound, upper_bound, 0);
0128       }
0129 
0130       // create data storage with a fixed size vector initiliazed w/ 0
0131       const auto &len = input_sizes_.emplace_back(prep_params.max_length * prep_params.var_names.size());
0132       data_.emplace_back(len, 0);
0133     }
0134   }
0135 
0136   if (debug_) {
0137     for (unsigned i = 0; i < input_names_.size(); ++i) {
0138       const auto &group_name = input_names_.at(i);
0139       if (!input_shapes_.empty()) {
0140         std::cout << group_name << "\nshapes: ";
0141         for (const auto &x : input_shapes_.at(i)) {
0142           std::cout << x << ", ";
0143         }
0144       }
0145       std::cout << "\nvariables: ";
0146       for (const auto &x : prep_info_map_.at(group_name).var_names) {
0147         std::cout << x << ", ";
0148       }
0149       std::cout << "\n";
0150     }
0151     std::cout << "flav_names: ";
0152     for (const auto &flav_name : flav_names_) {
0153       std::cout << flav_name << ", ";
0154     }
0155     std::cout << "\n";
0156   }
0157 
0158   // get output names from flav_names
0159   for (const auto &flav_name : flav_names_) {
0160     produces<JetTagCollection>(flav_name);
0161   }
0162 }
0163 
0164 BoostedJetONNXJetTagsProducer::~BoostedJetONNXJetTagsProducer() {}
0165 
0166 void BoostedJetONNXJetTagsProducer::fillDescriptions(edm::ConfigurationDescriptions &descriptions) {
0167   // pfDeepBoostedJetTags
0168   edm::ParameterSetDescription desc;
0169   desc.add<edm::InputTag>("src", edm::InputTag("pfDeepBoostedJetTagInfos"));
0170   desc.add<std::string>("preprocess_json", "");
0171   // `preprocessParams` is deprecated -- use the preprocessing json file instead
0172   edm::ParameterSetDescription preprocessParams;
0173   preprocessParams.setAllowAnything();
0174   preprocessParams.setComment("`preprocessParams` is deprecated, please use `preprocess_json` instead.");
0175   desc.addOptional<edm::ParameterSetDescription>("preprocessParams", preprocessParams);
0176   desc.add<edm::FileInPath>("model_path",
0177                             edm::FileInPath("RecoBTag/Combined/data/DeepBoostedJet/V02/full/resnet.onnx"));
0178   desc.add<std::vector<std::string>>("flav_names",
0179                                      std::vector<std::string>{
0180                                          "probTbcq",
0181                                          "probTbqq",
0182                                          "probTbc",
0183                                          "probTbq",
0184                                          "probWcq",
0185                                          "probWqq",
0186                                          "probZbb",
0187                                          "probZcc",
0188                                          "probZqq",
0189                                          "probHbb",
0190                                          "probHcc",
0191                                          "probHqqqq",
0192                                          "probQCDbb",
0193                                          "probQCDcc",
0194                                          "probQCDb",
0195                                          "probQCDc",
0196                                          "probQCDothers",
0197                                      });
0198   desc.addOptionalUntracked<bool>("debugMode", false);
0199 
0200   descriptions.addWithDefaultLabel(desc);
0201 }
0202 
0203 std::unique_ptr<ONNXRuntime> BoostedJetONNXJetTagsProducer::initializeGlobalCache(const edm::ParameterSet &iConfig) {
0204   return std::make_unique<ONNXRuntime>(iConfig.getParameter<edm::FileInPath>("model_path").fullPath());
0205 }
0206 
0207 void BoostedJetONNXJetTagsProducer::globalEndJob(const ONNXRuntime *cache) {}
0208 
0209 void BoostedJetONNXJetTagsProducer::produce(edm::Event &iEvent, const edm::EventSetup &iSetup) {
0210   edm::Handle<TagInfoCollection> tag_infos;
0211   iEvent.getByToken(src_, tag_infos);
0212 
0213   // initialize output collection
0214   std::vector<std::unique_ptr<JetTagCollection>> output_tags;
0215   if (!tag_infos->empty()) {
0216     auto jet_ref = tag_infos->begin()->jet();
0217     auto ref2prod = edm::makeRefToBaseProdFrom(jet_ref, iEvent);
0218     for (std::size_t i = 0; i < flav_names_.size(); i++) {
0219       output_tags.emplace_back(std::make_unique<JetTagCollection>(ref2prod));
0220     }
0221   } else {
0222     for (std::size_t i = 0; i < flav_names_.size(); i++) {
0223       output_tags.emplace_back(std::make_unique<JetTagCollection>());
0224     }
0225   }
0226 
0227   for (unsigned jet_n = 0; jet_n < tag_infos->size(); ++jet_n) {
0228     const auto &taginfo = (*tag_infos)[jet_n];
0229     std::vector<float> outputs(flav_names_.size(), 0);  // init as all zeros
0230 
0231     if (!taginfo.features().empty()) {
0232       // convert inputs
0233       make_inputs(taginfo);
0234       // run prediction and get outputs
0235       outputs = globalCache()->run(input_names_, data_, input_shapes_)[0];
0236       assert(outputs.size() == flav_names_.size());
0237     }
0238 
0239     const auto &jet_ref = tag_infos->at(jet_n).jet();
0240     for (std::size_t flav_n = 0; flav_n < flav_names_.size(); flav_n++) {
0241       (*(output_tags[flav_n]))[jet_ref] = outputs[flav_n];
0242     }
0243   }
0244 
0245   if (debug_) {
0246     std::cout << "=== " << iEvent.id().run() << ":" << iEvent.id().luminosityBlock() << ":" << iEvent.id().event()
0247               << " ===" << std::endl;
0248     for (unsigned jet_n = 0; jet_n < tag_infos->size(); ++jet_n) {
0249       const auto &jet_ref = tag_infos->at(jet_n).jet();
0250       std::cout << " - Jet #" << jet_n << ", pt=" << jet_ref->pt() << ", eta=" << jet_ref->eta()
0251                 << ", phi=" << jet_ref->phi() << std::endl;
0252       for (std::size_t flav_n = 0; flav_n < flav_names_.size(); ++flav_n) {
0253         std::cout << "    " << flav_names_.at(flav_n) << " = " << (*(output_tags.at(flav_n)))[jet_ref] << std::endl;
0254       }
0255     }
0256   }
0257 
0258   // put into the event
0259   for (std::size_t flav_n = 0; flav_n < flav_names_.size(); ++flav_n) {
0260     iEvent.put(std::move(output_tags[flav_n]), flav_names_[flav_n]);
0261   }
0262 }
0263 
0264 std::vector<float> BoostedJetONNXJetTagsProducer::center_norm_pad(const std::vector<float> &input,
0265                                                                   float center,
0266                                                                   float norm_factor,
0267                                                                   unsigned min_length,
0268                                                                   unsigned max_length,
0269                                                                   float pad_value,
0270                                                                   float replace_inf_value,
0271                                                                   float min,
0272                                                                   float max) {
0273   // do variable shifting/scaling/padding/clipping in one go
0274 
0275   assert(min <= pad_value && pad_value <= max);
0276   assert(min_length <= max_length);
0277 
0278   unsigned target_length = std::clamp((unsigned)input.size(), min_length, max_length);
0279   std::vector<float> out(target_length, pad_value);
0280   for (unsigned i = 0; i < input.size() && i < target_length; ++i) {
0281     out[i] = std::clamp((catch_infs(input[i], replace_inf_value) - center) * norm_factor, min, max);
0282   }
0283   return out;
0284 }
0285 
0286 void BoostedJetONNXJetTagsProducer::make_inputs(const reco::DeepBoostedJetTagInfo &taginfo) {
0287   for (unsigned igroup = 0; igroup < input_names_.size(); ++igroup) {
0288     const auto &group_name = input_names_[igroup];
0289     const auto &prep_params = prep_info_map_.at(group_name);
0290     auto &group_values = data_[igroup];
0291     group_values.resize(input_sizes_[igroup]);
0292     // first reset group_values to 0
0293     std::fill(group_values.begin(), group_values.end(), 0);
0294     unsigned curr_pos = 0;
0295     // transform/pad
0296     for (unsigned i = 0; i < prep_params.var_names.size(); ++i) {
0297       const auto &varname = prep_params.var_names[i];
0298       const auto &raw_value = taginfo.features().get(varname);
0299       const auto &info = prep_params.info(varname);
0300       auto val = center_norm_pad(raw_value,
0301                                  info.center,
0302                                  info.norm_factor,
0303                                  prep_params.min_length,
0304                                  prep_params.max_length,
0305                                  info.pad,
0306                                  info.replace_inf_value,
0307                                  info.lower_bound,
0308                                  info.upper_bound);
0309       std::copy(val.begin(), val.end(), group_values.begin() + curr_pos);
0310       curr_pos += val.size();
0311       if (i == 0 && (!input_shapes_.empty())) {
0312         input_shapes_[igroup][2] = val.size();
0313       }
0314 
0315       if (debug_) {
0316         std::cout << " -- var=" << varname << ", center=" << info.center << ", scale=" << info.norm_factor
0317                   << ", replace=" << info.replace_inf_value << ", pad=" << info.pad << std::endl;
0318         for (const auto &v : val) {
0319           std::cout << v << ",";
0320         }
0321         std::cout << std::endl;
0322       }
0323     }
0324     group_values.resize(curr_pos);
0325   }
0326 }
0327 
// Define BoostedJetONNXJetTagsProducer as a framework plug-in module.
DEFINE_FWK_MODULE(BoostedJetONNXJetTagsProducer);