File indexing completed on 2025-06-04 22:36:22
0001 #include "FWCore/Framework/interface/Frameworkfwd.h"
0002 #include "FWCore/Framework/interface/stream/EDProducer.h"
0003
0004 #include "FWCore/Framework/interface/Event.h"
0005 #include "FWCore/Framework/interface/MakerMacros.h"
0006
0007 #include "FWCore/Framework/interface/makeRefToBaseProdFrom.h"
0008
0009 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0010 #include "FWCore/Utilities/interface/StreamID.h"
0011
0012 #include "DataFormats/BTauReco/interface/JetTag.h"
0013
0014 #include "DataFormats/BTauReco/interface/DeepFlavourTagInfo.h"
0015
0016 #include "PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h"
0017
0018 using namespace cms::Ort;
0019
0020 class DeepFlavourONNXJetTagsProducer : public edm::stream::EDProducer<edm::GlobalCache<ONNXRuntime>> {
0021 public:
0022 explicit DeepFlavourONNXJetTagsProducer(const edm::ParameterSet&, const ONNXRuntime*);
0023 ~DeepFlavourONNXJetTagsProducer() override;
0024
0025 static void fillDescriptions(edm::ConfigurationDescriptions&);
0026
0027 static std::unique_ptr<ONNXRuntime> initializeGlobalCache(const edm::ParameterSet&);
0028 static void globalEndJob(const ONNXRuntime*);
0029
0030 private:
0031 typedef std::vector<reco::DeepFlavourTagInfo> TagInfoCollection;
0032 typedef reco::JetTagCollection JetTagCollection;
0033
0034 void produce(edm::Event&, const edm::EventSetup&) override;
0035
0036 void make_inputs(unsigned i_jet, const reco::DeepFlavourTagInfo& taginfo);
0037
0038 const edm::EDGetTokenT<TagInfoCollection> src_;
0039 edm::EDGetTokenT<edm::View<reco::Jet>> jet_;
0040 std::vector<std::string> flav_names_;
0041 std::vector<std::string> input_names_;
0042 std::vector<std::string> output_names_;
0043
0044 enum InputIndexes { kGlobal = 0, kChargedCandidates = 1, kNeutralCandidates = 2, kVertices = 3, kJetPt = 4 };
0045 constexpr static unsigned n_features_global_ = 15;
0046 constexpr static unsigned n_cpf_ = 25;
0047 constexpr static unsigned n_features_cpf_ = 16;
0048 constexpr static unsigned n_npf_ = 25;
0049 constexpr static unsigned n_features_npf_ = 6;
0050 constexpr static unsigned n_sv_ = 4;
0051 constexpr static unsigned n_features_sv_ = 12;
0052 constexpr static unsigned n_features_jetpt_ = 1;
0053 const static std::vector<unsigned> input_sizes_;
0054
0055
0056 FloatArrays data_;
0057 bool produceValueMap_;
0058 edm::Handle<edm::View<reco::Jet>> jets;
0059 };
0060
0061 const std::vector<unsigned> DeepFlavourONNXJetTagsProducer::input_sizes_{
0062 n_features_global_, n_cpf_* n_features_cpf_, n_npf_* n_features_npf_, n_sv_* n_features_sv_, n_features_jetpt_};
0063
0064 DeepFlavourONNXJetTagsProducer::DeepFlavourONNXJetTagsProducer(const edm::ParameterSet& iConfig,
0065 const ONNXRuntime* cache)
0066 : src_(consumes<TagInfoCollection>(iConfig.getParameter<edm::InputTag>("src"))),
0067 flav_names_(iConfig.getParameter<std::vector<std::string>>("flav_names")),
0068 input_names_(iConfig.getParameter<std::vector<std::string>>("input_names")),
0069 output_names_(iConfig.getParameter<std::vector<std::string>>("output_names")),
0070 produceValueMap_(iConfig.getUntrackedParameter<bool>("produceValueMap", false)) {
0071 if (produceValueMap_) {
0072 jet_ = consumes<edm::View<reco::Jet>>(iConfig.getParameter<edm::InputTag>("jets"));
0073 }
0074
0075
0076 for (const auto& flav_name : flav_names_) {
0077 produces<JetTagCollection>(flav_name);
0078 if (produceValueMap_) {
0079 produces<edm::ValueMap<float>>(flav_name);
0080 }
0081 }
0082
0083 assert(input_names_.size() == input_sizes_.size());
0084 data_.reserve(input_sizes_.size());
0085 }
0086
0087 DeepFlavourONNXJetTagsProducer::~DeepFlavourONNXJetTagsProducer() {}
0088
0089 void DeepFlavourONNXJetTagsProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0090
0091 edm::ParameterSetDescription desc;
0092 desc.add<edm::InputTag>("src", edm::InputTag("pfDeepFlavourTagInfos"));
0093 desc.add<std::vector<std::string>>("input_names", {"input_1", "input_2", "input_3", "input_4", "input_5"});
0094 desc.add<edm::FileInPath>(
0095 "model_path", edm::FileInPath("RecoBTag/Combined/data/DeepFlavourV04_12X_training/DeepJet_Run3_122X.onnx"));
0096 desc.add<std::vector<std::string>>("output_names", {"ID_pred/Softmax:0"});
0097 desc.add<std::vector<std::string>>(
0098 "flav_names", std::vector<std::string>{"probb", "probbb", "problepb", "probc", "probuds", "probg"});
0099 desc.add<edm::InputTag>("jets", edm::InputTag("hltAK4PFPuppiJets"));
0100 desc.addOptionalUntracked<bool>("produceValueMap", false);
0101
0102 descriptions.add("pfDeepFlavourJetTags", desc);
0103 }
0104
0105 std::unique_ptr<ONNXRuntime> DeepFlavourONNXJetTagsProducer::initializeGlobalCache(const edm::ParameterSet& iConfig) {
0106 return std::make_unique<ONNXRuntime>(iConfig.getParameter<edm::FileInPath>("model_path").fullPath());
0107 }
0108
0109 void DeepFlavourONNXJetTagsProducer::globalEndJob(const ONNXRuntime* cache) {}
0110
0111 void DeepFlavourONNXJetTagsProducer::produce(edm::Event& iEvent, const edm::EventSetup& iSetup) {
0112 edm::Handle<TagInfoCollection> tag_infos;
0113 iEvent.getByToken(src_, tag_infos);
0114
0115 if (produceValueMap_) {
0116 iEvent.getByToken(jet_, jets);
0117 if (!jets.isValid()) {
0118 edm::LogWarning("DeepFlavourONNXJetTagsProducer") << "Invalid handle in jet input collection";
0119 return;
0120 }
0121 }
0122
0123 std::vector<std::unique_ptr<JetTagCollection>> output_tags;
0124 std::vector<std::vector<float>> output_scores(flav_names_.size(), std::vector<float>(tag_infos->size(), -1.0));
0125 if (!tag_infos->empty()) {
0126
0127 auto jet_ref = tag_infos->begin()->jet();
0128 auto ref2prod = edm::makeRefToBaseProdFrom(jet_ref, iEvent);
0129 for (std::size_t i = 0; i < flav_names_.size(); i++) {
0130 output_tags.emplace_back(std::make_unique<JetTagCollection>(ref2prod));
0131 }
0132
0133
0134 data_.clear();
0135 for (const auto& len : input_sizes_) {
0136 data_.emplace_back(tag_infos->size() * len, 0.0);
0137 }
0138
0139
0140 for (unsigned jet_n = 0; jet_n < tag_infos->size(); ++jet_n) {
0141 const auto& taginfo = (*tag_infos)[jet_n];
0142 make_inputs(jet_n, taginfo);
0143 }
0144
0145
0146 auto outputs = globalCache()->run(input_names_, data_, {}, output_names_, tag_infos->size())[0];
0147 assert(outputs.size() == flav_names_.size() * tag_infos->size());
0148
0149
0150 unsigned i_output = 0;
0151 for (unsigned jet_n = 0; jet_n < tag_infos->size(); ++jet_n) {
0152 const auto& jet_ref = tag_infos->at(jet_n).jet();
0153 for (std::size_t flav_n = 0; flav_n < flav_names_.size(); flav_n++) {
0154 (*(output_tags[flav_n]))[jet_ref] = outputs[i_output];
0155 if (produceValueMap_) {
0156 output_scores[flav_n][jet_n] = outputs[flav_n];
0157 }
0158 ++i_output;
0159 }
0160 }
0161 } else {
0162
0163 for (std::size_t i = 0; i < flav_names_.size(); i++) {
0164 output_tags.emplace_back(std::make_unique<JetTagCollection>());
0165 }
0166 }
0167
0168
0169 for (std::size_t flav_n = 0; flav_n < flav_names_.size(); ++flav_n) {
0170 if (produceValueMap_) {
0171 for (size_t k = 0; k < output_scores.size(); k++) {
0172 std::unique_ptr<edm::ValueMap<float>> VM(new edm::ValueMap<float>());
0173 edm::ValueMap<float>::Filler filler(*VM);
0174 filler.insert(jets, output_scores.at(k).begin(), output_scores.at(k).end());
0175 filler.fill();
0176 iEvent.put(std::move(VM), flav_names_[k]);
0177 }
0178 }
0179 iEvent.put(std::move(output_tags[flav_n]), flav_names_[flav_n]);
0180 }
0181 data_.clear();
0182 }
0183
0184 void DeepFlavourONNXJetTagsProducer::make_inputs(unsigned i_jet, const reco::DeepFlavourTagInfo& taginfo) {
0185 const auto& features = taginfo.features();
0186 float* ptr = nullptr;
0187 const float* start = nullptr;
0188 unsigned offset = 0;
0189
0190
0191 offset = i_jet * input_sizes_[kGlobal];
0192 ptr = &data_[kGlobal][offset];
0193
0194 const auto& jet_features = features.jet_features;
0195 start = ptr;
0196 *ptr = jet_features.pt;
0197 *(++ptr) = jet_features.eta;
0198
0199 *(++ptr) = features.c_pf_features.size();
0200 *(++ptr) = features.n_pf_features.size();
0201 *(++ptr) = features.sv_features.size();
0202 *(++ptr) = features.npv;
0203
0204 const auto& tag_info_features = features.tag_info_features;
0205 *(++ptr) = tag_info_features.trackSumJetEtRatio;
0206 *(++ptr) = tag_info_features.trackSumJetDeltaR;
0207 *(++ptr) = tag_info_features.vertexCategory;
0208 *(++ptr) = tag_info_features.trackSip2dValAboveCharm;
0209 *(++ptr) = tag_info_features.trackSip2dSigAboveCharm;
0210 *(++ptr) = tag_info_features.trackSip3dValAboveCharm;
0211 *(++ptr) = tag_info_features.trackSip3dSigAboveCharm;
0212 *(++ptr) = tag_info_features.jetNSelectedTracks;
0213 *(++ptr) = tag_info_features.jetNTracksEtaRel;
0214 assert(start + n_features_global_ - 1 == ptr);
0215
0216
0217 auto max_c_pf_n = std::min(features.c_pf_features.size(), (std::size_t)25);
0218 offset = i_jet * input_sizes_[kChargedCandidates];
0219 for (std::size_t c_pf_n = 0; c_pf_n < max_c_pf_n; c_pf_n++) {
0220 const auto& c_pf_features = features.c_pf_features.at(c_pf_n);
0221 ptr = &data_[kChargedCandidates][offset + c_pf_n * n_features_cpf_];
0222 start = ptr;
0223 *ptr = c_pf_features.btagPf_trackEtaRel;
0224 *(++ptr) = c_pf_features.btagPf_trackPtRel;
0225 *(++ptr) = c_pf_features.btagPf_trackPPar;
0226 *(++ptr) = c_pf_features.btagPf_trackDeltaR;
0227 *(++ptr) = c_pf_features.btagPf_trackPParRatio;
0228 *(++ptr) = c_pf_features.btagPf_trackSip2dVal;
0229 *(++ptr) = c_pf_features.btagPf_trackSip2dSig;
0230 *(++ptr) = c_pf_features.btagPf_trackSip3dVal;
0231 *(++ptr) = c_pf_features.btagPf_trackSip3dSig;
0232 *(++ptr) = c_pf_features.btagPf_trackJetDistVal;
0233 *(++ptr) = c_pf_features.ptrel;
0234 *(++ptr) = c_pf_features.drminsv;
0235 *(++ptr) = c_pf_features.vtx_ass;
0236 *(++ptr) = c_pf_features.puppiw;
0237 *(++ptr) = c_pf_features.chi2;
0238 *(++ptr) = c_pf_features.quality;
0239 assert(start + n_features_cpf_ - 1 == ptr);
0240 }
0241
0242
0243 auto max_n_pf_n = std::min(features.n_pf_features.size(), (std::size_t)25);
0244 offset = i_jet * input_sizes_[kNeutralCandidates];
0245 for (std::size_t n_pf_n = 0; n_pf_n < max_n_pf_n; n_pf_n++) {
0246 const auto& n_pf_features = features.n_pf_features.at(n_pf_n);
0247 ptr = &data_[kNeutralCandidates][offset + n_pf_n * n_features_npf_];
0248 start = ptr;
0249 *ptr = n_pf_features.ptrel;
0250 *(++ptr) = n_pf_features.deltaR;
0251 *(++ptr) = n_pf_features.isGamma;
0252 *(++ptr) = n_pf_features.hadFrac;
0253 *(++ptr) = n_pf_features.drminsv;
0254 *(++ptr) = n_pf_features.puppiw;
0255 assert(start + n_features_npf_ - 1 == ptr);
0256 }
0257
0258
0259 auto max_sv_n = std::min(features.sv_features.size(), (std::size_t)4);
0260 offset = i_jet * input_sizes_[kVertices];
0261 for (std::size_t sv_n = 0; sv_n < max_sv_n; sv_n++) {
0262 const auto& sv_features = features.sv_features.at(sv_n);
0263 ptr = &data_[kVertices][offset + sv_n * n_features_sv_];
0264 start = ptr;
0265 *ptr = sv_features.pt;
0266 *(++ptr) = sv_features.deltaR;
0267 *(++ptr) = sv_features.mass;
0268 *(++ptr) = sv_features.ntracks;
0269 *(++ptr) = sv_features.chi2;
0270 *(++ptr) = sv_features.normchi2;
0271 *(++ptr) = sv_features.dxy;
0272 *(++ptr) = sv_features.dxysig;
0273 *(++ptr) = sv_features.d3d;
0274 *(++ptr) = sv_features.d3dsig;
0275 *(++ptr) = sv_features.costhetasvpv;
0276 *(++ptr) = sv_features.enratio;
0277 assert(start + n_features_sv_ - 1 == ptr);
0278 }
0279
0280
0281 offset = i_jet * input_sizes_[kJetPt];
0282 data_[kJetPt][offset] = features.jet_features.pt;
0283 }
0284
0285
0286 DEFINE_FWK_MODULE(DeepFlavourONNXJetTagsProducer);