Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2025-02-05 03:15:12

0001 #include "FWCore/Framework/interface/Frameworkfwd.h"
0002 #include "FWCore/Framework/interface/stream/EDProducer.h"
0003 
0004 #include "FWCore/Framework/interface/Event.h"
0005 #include "FWCore/Framework/interface/MakerMacros.h"
0006 
0007 #include "FWCore/Framework/interface/makeRefToBaseProdFrom.h"
0008 
0009 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0010 #include "FWCore/Utilities/interface/StreamID.h"
0011 
0012 #include "DataFormats/BTauReco/interface/JetTag.h"
0013 
0014 #include "DataFormats/BTauReco/interface/UnifiedParticleTransformerAK4TagInfo.h"
0015 #include "DataFormats/BTauReco/interface/UnifiedParticleTransformerAK4Features.h"
0016 
0017 #include "PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h"
0018 
0019 using namespace cms::Ort;
0020 
0021 class UnifiedParticleTransformerAK4ONNXJetTagsProducer : public edm::stream::EDProducer<edm::GlobalCache<ONNXRuntime>> {
0022 public:
0023   explicit UnifiedParticleTransformerAK4ONNXJetTagsProducer(const edm::ParameterSet&, const ONNXRuntime*);
0024   ~UnifiedParticleTransformerAK4ONNXJetTagsProducer() override = default;
0025 
0026   static void fillDescriptions(edm::ConfigurationDescriptions&);
0027 
0028   static std::unique_ptr<ONNXRuntime> initializeGlobalCache(const edm::ParameterSet&);
0029   static void globalEndJob(const ONNXRuntime*);
0030 
0031 private:
0032   typedef std::vector<reco::UnifiedParticleTransformerAK4TagInfo> TagInfoCollection;
0033   typedef reco::JetTagCollection JetTagCollection;
0034 
0035   void produce(edm::Event&, const edm::EventSetup&) override;
0036 
0037   void make_inputs(btagbtvdeep::UnifiedParticleTransformerAK4Features features);
0038   void get_input_sizes(const reco::FeaturesTagInfo<btagbtvdeep::UnifiedParticleTransformerAK4Features> taginfo);
0039 
0040   const edm::EDGetTokenT<TagInfoCollection> src_;
0041   std::vector<std::string> flav_names_;
0042   std::vector<std::string> input_names_;
0043   bool use_dynamic_axes_ = false;
0044   std::vector<std::string> output_names_;
0045 
0046   enum InputIndexes {
0047     kChargedCandidates = 0,
0048     kLostTracks = 1,
0049     kNeutralCandidates = 2,
0050     kVertices = 3,
0051     kChargedCandidates4Vec = 4,
0052     kLostTracks4Vec = 5,
0053     kNeutralCandidates4Vec = 6,
0054     kVertices4Vec = 7
0055   };
0056   unsigned n_cpf_;
0057   constexpr static unsigned n_features_cpf_ = 25;
0058   constexpr static unsigned n_pairwise_features_cpf_ = 4;
0059   unsigned n_lt_;
0060   constexpr static unsigned n_features_lt_ = 18;
0061   constexpr static unsigned n_pairwise_features_lt_ = 4;
0062   unsigned n_npf_;
0063   constexpr static unsigned n_features_npf_ = 8;
0064   constexpr static unsigned n_pairwise_features_npf_ = 4;
0065   unsigned n_sv_;
0066   constexpr static unsigned n_features_sv_ = 14;
0067   constexpr static unsigned n_pairwise_features_sv_ = 4;
0068   std::vector<unsigned> input_sizes_;
0069   std::vector<std::vector<int64_t>> input_shapes_;  // shapes of each input group (-1 for dynamic axis)
0070 
0071   // hold the input data
0072   FloatArrays data_;
0073 };
0074 
0075 UnifiedParticleTransformerAK4ONNXJetTagsProducer::UnifiedParticleTransformerAK4ONNXJetTagsProducer(
0076     const edm::ParameterSet& iConfig, const ONNXRuntime* cache)
0077     : src_(consumes<TagInfoCollection>(iConfig.getParameter<edm::InputTag>("src"))),
0078       flav_names_(iConfig.getParameter<std::vector<std::string>>("flav_names")),
0079       input_names_(iConfig.getParameter<std::vector<std::string>>("input_names")),
0080       use_dynamic_axes_(iConfig.getParameter<edm::FileInPath>("model_path").fullPath().find("v2.onnx") !=
0081                         std::string::npos),
0082       output_names_(iConfig.getParameter<std::vector<std::string>>("output_names")) {
0083   // get output names from flav_names
0084   for (const auto& flav_name : flav_names_) {
0085     produces<JetTagCollection>(flav_name);
0086   }
0087 }
0088 
0089 void UnifiedParticleTransformerAK4ONNXJetTagsProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0090   // pfUnifiedParticleTransformerAK4JetTags
0091   edm::ParameterSetDescription desc;
0092   desc.add<edm::InputTag>("src", edm::InputTag("pfUnifiedParticleTransformerAK4TagInfos"));
0093   desc.add<std::vector<std::string>>(
0094       "input_names", {"input_1", "input_2", "input_3", "input_4", "input_5", "input_6", "input_7", "input_8"});
0095   desc.add<edm::FileInPath>("model_path",
0096                             edm::FileInPath("RecoBTag/Combined/data/UParTAK4/PUPPI/V01/UParTAK4_v2.onnx"));
0097   desc.add<std::vector<std::string>>("output_names", {"softmax"});
0098   desc.add<std::vector<std::string>>(
0099       "flav_names",
0100       std::vector<std::string>{"probb",        "probbb",       "problepb",     "probc",         "probs",
0101                                "probu",        "probd",        "probg",        "probele",       "probmu",
0102                                "probtaup1h0p", "probtaup1h1p", "probtaup1h2p", "probtaup3h0p",  "probtaup3h1p",
0103                                "probtaum1h0p", "probtaum1h1p", "probtaum1h2p", "probtaum3h0p",  "probtaum3h1p",
0104                                "ptcorr",       "ptreshigh",    "ptreslow",     "ptnu",          "probemudata",
0105                                "probemumc",    "probdimudata", "probdimumc",   "probmutaudata", "probmutaumc"});
0106 
0107   descriptions.add("pfUnifiedParticleTransformerAK4JetTags", desc);
0108 }
0109 
0110 std::unique_ptr<ONNXRuntime> UnifiedParticleTransformerAK4ONNXJetTagsProducer::initializeGlobalCache(
0111     const edm::ParameterSet& iConfig) {
0112   return std::make_unique<ONNXRuntime>(iConfig.getParameter<edm::FileInPath>("model_path").fullPath());
0113 }
0114 
0115 void UnifiedParticleTransformerAK4ONNXJetTagsProducer::globalEndJob(const ONNXRuntime* cache) {}
0116 
0117 void UnifiedParticleTransformerAK4ONNXJetTagsProducer::produce(edm::Event& iEvent, const edm::EventSetup& iSetup) {
0118   edm::Handle<TagInfoCollection> tag_infos;
0119   iEvent.getByToken(src_, tag_infos);
0120 
0121   // initialize output collection
0122   std::vector<std::unique_ptr<JetTagCollection>> output_tags;
0123   if (!tag_infos->empty()) {
0124     auto jet_ref = tag_infos->begin()->jet();
0125     auto ref2prod = edm::makeRefToBaseProdFrom(jet_ref, iEvent);
0126     for (std::size_t i = 0; i < flav_names_.size(); i++) {
0127       output_tags.emplace_back(std::make_unique<JetTagCollection>(ref2prod));
0128     }
0129   } else {
0130     for (std::size_t i = 0; i < flav_names_.size(); i++) {
0131       output_tags.emplace_back(std::make_unique<JetTagCollection>());
0132     }
0133   }
0134 
0135   for (unsigned jet_n = 0; jet_n < tag_infos->size(); ++jet_n) {
0136     const auto& taginfo = (*tag_infos)[jet_n];
0137     std::vector<float> outputs(flav_names_.size(), -1.0);
0138     if (taginfo.features().is_filled) {
0139       get_input_sizes(taginfo);
0140 
0141       // run prediction with dynamic batch size per event
0142       input_shapes_ = {{(int64_t)1, (int64_t)n_cpf_, (int64_t)n_features_cpf_},
0143                        {(int64_t)1, (int64_t)n_lt_, (int64_t)n_features_lt_},
0144                        {(int64_t)1, (int64_t)n_npf_, (int64_t)n_features_npf_},
0145                        {(int64_t)1, (int64_t)n_sv_, (int64_t)n_features_sv_},
0146                        {(int64_t)1, (int64_t)n_cpf_, (int64_t)n_pairwise_features_cpf_},
0147                        {(int64_t)1, (int64_t)n_lt_, (int64_t)n_pairwise_features_lt_},
0148                        {(int64_t)1, (int64_t)n_npf_, (int64_t)n_pairwise_features_npf_},
0149                        {(int64_t)1, (int64_t)n_sv_, (int64_t)n_pairwise_features_sv_}};
0150 
0151       outputs = globalCache()->run(input_names_, data_, input_shapes_, output_names_, 1)[0];
0152       assert(outputs.size() == flav_names_.size());
0153     }
0154 
0155     const auto& jet_ref = tag_infos->at(jet_n).jet();
0156     for (std::size_t flav_n = 0; flav_n < flav_names_.size(); flav_n++) {
0157       (*(output_tags[flav_n]))[jet_ref] = outputs[flav_n];
0158     }
0159   }
0160 
0161   // put into the event
0162   for (std::size_t flav_n = 0; flav_n < flav_names_.size(); ++flav_n) {
0163     iEvent.put(std::move(output_tags[flav_n]), flav_names_[flav_n]);
0164   }
0165 }
0166 
0167 void UnifiedParticleTransformerAK4ONNXJetTagsProducer::get_input_sizes(
0168     const reco::FeaturesTagInfo<btagbtvdeep::UnifiedParticleTransformerAK4Features> taginfo) {
0169   const auto& features = taginfo.features();
0170 
0171   if (use_dynamic_axes_) {
0172     // Use actual sizes for dynamic axes version
0173     n_cpf_ = std::clamp((unsigned int)features.c_pf_features.size(), (unsigned int)1, (unsigned int)29);
0174     n_lt_ = std::clamp((unsigned int)features.lt_features.size(), (unsigned int)1, (unsigned int)5);
0175     n_npf_ = std::clamp((unsigned int)features.n_pf_features.size(), (unsigned int)1, (unsigned int)25);
0176     n_sv_ = std::clamp((unsigned int)features.sv_features.size(), (unsigned int)1, (unsigned int)5);
0177 
0178   } else {
0179     // Use fixed sizes for original version
0180     n_cpf_ = (unsigned int)29;
0181     n_lt_ = (unsigned int)5;
0182     n_npf_ = (unsigned int)25;
0183     n_sv_ = (unsigned int)5;
0184   }
0185 
0186   input_sizes_ = {
0187       n_cpf_ * n_features_cpf_,
0188       n_lt_ * n_features_lt_,
0189       n_npf_ * n_features_npf_,
0190       n_sv_ * n_features_sv_,
0191       n_cpf_ * n_pairwise_features_cpf_,
0192       n_lt_ * n_pairwise_features_lt_,
0193       n_npf_ * n_pairwise_features_npf_,
0194       n_sv_ * n_pairwise_features_sv_,
0195   };
0196   // init data storage
0197   data_.clear();
0198   for (const auto& len : input_sizes_) {
0199     data_.emplace_back(1 * len, 0);
0200   }
0201 
0202   make_inputs(features);
0203 }
0204 
0205 void UnifiedParticleTransformerAK4ONNXJetTagsProducer::make_inputs(
0206     btagbtvdeep::UnifiedParticleTransformerAK4Features features) {
0207   float* ptr = nullptr;
0208   const float* start = nullptr;
0209   unsigned offset = 0;
0210 
0211   // c_pf candidates
0212   auto max_c_pf_n = std::min(features.c_pf_features.size(), (std::size_t)n_cpf_);
0213   for (std::size_t c_pf_n = 0; c_pf_n < max_c_pf_n; c_pf_n++) {
0214     const auto& c_pf_features = features.c_pf_features.at(c_pf_n);
0215     ptr = &data_[kChargedCandidates][offset + c_pf_n * n_features_cpf_];
0216     start = ptr;
0217     *ptr = c_pf_features.btagPf_trackEtaRel;
0218     *(++ptr) = c_pf_features.btagPf_trackPtRel;
0219     *(++ptr) = c_pf_features.btagPf_trackPPar;
0220     *(++ptr) = c_pf_features.btagPf_trackDeltaR;
0221     *(++ptr) = c_pf_features.btagPf_trackPParRatio;
0222     *(++ptr) = c_pf_features.btagPf_trackSip2dVal;
0223     *(++ptr) = c_pf_features.btagPf_trackSip2dSig;
0224     *(++ptr) = c_pf_features.btagPf_trackSip3dVal;
0225     *(++ptr) = c_pf_features.btagPf_trackSip3dSig;
0226     *(++ptr) = c_pf_features.btagPf_trackJetDistVal;
0227     *(++ptr) = c_pf_features.ptrel;
0228     *(++ptr) = c_pf_features.drminsv;
0229     *(++ptr) = c_pf_features.vtx_ass;
0230     *(++ptr) = c_pf_features.puppiw;
0231     *(++ptr) = c_pf_features.chi2;
0232     *(++ptr) = c_pf_features.quality;
0233     *(++ptr) = c_pf_features.charge;
0234     *(++ptr) = c_pf_features.dz;
0235     *(++ptr) = c_pf_features.btagPf_trackDecayLen;
0236     *(++ptr) = c_pf_features.HadFrac;
0237     *(++ptr) = c_pf_features.CaloFrac;
0238     *(++ptr) = c_pf_features.pdgID;
0239     *(++ptr) = c_pf_features.lostInnerHits;
0240     *(++ptr) = c_pf_features.numberOfPixelHits;
0241     *(++ptr) = c_pf_features.numberOfStripHits;
0242 
0243     assert(start + n_features_cpf_ - 1 == ptr);
0244   }
0245 
0246   // n_lt candidates
0247   auto max_lt_n = std::min(features.lt_features.size(), (std::size_t)n_lt_);
0248   for (std::size_t lt_n = 0; lt_n < max_lt_n; lt_n++) {
0249     const auto& lt_features = features.lt_features.at(lt_n);
0250     ptr = &data_[kLostTracks][offset + lt_n * n_features_lt_];
0251     start = ptr;
0252     *ptr = lt_features.btagPf_trackEtaRel;
0253     *(++ptr) = lt_features.btagPf_trackPtRel;
0254     *(++ptr) = lt_features.btagPf_trackPPar;
0255     *(++ptr) = lt_features.btagPf_trackDeltaR;
0256     *(++ptr) = lt_features.btagPf_trackPParRatio;
0257     *(++ptr) = lt_features.btagPf_trackSip2dVal;
0258     *(++ptr) = lt_features.btagPf_trackSip2dSig;
0259     *(++ptr) = lt_features.btagPf_trackSip3dVal;
0260     *(++ptr) = lt_features.btagPf_trackSip3dSig;
0261     *(++ptr) = lt_features.btagPf_trackJetDistVal;
0262     *(++ptr) = lt_features.drminsv;
0263     *(++ptr) = lt_features.charge;
0264     *(++ptr) = lt_features.puppiw;
0265     *(++ptr) = lt_features.chi2;
0266     *(++ptr) = lt_features.quality;
0267     *(++ptr) = lt_features.lostInnerHits;
0268     *(++ptr) = lt_features.numberOfPixelHits;
0269     *(++ptr) = lt_features.numberOfStripHits;
0270     assert(start + n_features_lt_ - 1 == ptr);
0271   }
0272 
0273   // n_pf candidates
0274   auto max_n_pf_n = std::min(features.n_pf_features.size(), (std::size_t)n_npf_);
0275   for (std::size_t n_pf_n = 0; n_pf_n < max_n_pf_n; n_pf_n++) {
0276     const auto& n_pf_features = features.n_pf_features.at(n_pf_n);
0277     ptr = &data_[kNeutralCandidates][offset + n_pf_n * n_features_npf_];
0278     start = ptr;
0279     *ptr = n_pf_features.ptrel;
0280     *(++ptr) = n_pf_features.etarel;
0281     *(++ptr) = n_pf_features.phirel;
0282     *(++ptr) = n_pf_features.deltaR;
0283     *(++ptr) = n_pf_features.isGamma;
0284     *(++ptr) = n_pf_features.hadFrac;
0285     *(++ptr) = n_pf_features.drminsv;
0286     *(++ptr) = n_pf_features.puppiw;
0287     assert(start + n_features_npf_ - 1 == ptr);
0288   }
0289 
0290   // sv candidates
0291   auto max_sv_n = std::min(features.sv_features.size(), (std::size_t)n_sv_);
0292   for (std::size_t sv_n = 0; sv_n < max_sv_n; sv_n++) {
0293     const auto& sv_features = features.sv_features.at(sv_n);
0294     ptr = &data_[kVertices][offset + sv_n * n_features_sv_];
0295     start = ptr;
0296     *ptr = sv_features.pt;
0297     *(++ptr) = sv_features.deltaR;
0298     *(++ptr) = sv_features.mass;
0299     *(++ptr) = sv_features.etarel;
0300     *(++ptr) = sv_features.phirel;
0301     *(++ptr) = sv_features.ntracks;
0302     *(++ptr) = sv_features.chi2;
0303     *(++ptr) = sv_features.normchi2;
0304     *(++ptr) = sv_features.dxy;
0305     *(++ptr) = sv_features.dxysig;
0306     *(++ptr) = sv_features.d3d;
0307     *(++ptr) = sv_features.d3dsig;
0308     *(++ptr) = sv_features.costhetasvpv;
0309     *(++ptr) = sv_features.enratio;
0310     assert(start + n_features_sv_ - 1 == ptr);
0311   }
0312 
0313   // cpf pairwise features (4-vectors)
0314   auto max_cpf_n = std::min(features.c_pf_features.size(), (std::size_t)n_cpf_);
0315   for (std::size_t cpf_n = 0; cpf_n < max_cpf_n; cpf_n++) {
0316     const auto& cpf_pairwise_features = features.c_pf_features.at(cpf_n);
0317     ptr = &data_[kChargedCandidates4Vec][offset + cpf_n * n_pairwise_features_cpf_];
0318     start = ptr;
0319     *ptr = cpf_pairwise_features.px;
0320     *(++ptr) = cpf_pairwise_features.py;
0321     *(++ptr) = cpf_pairwise_features.pz;
0322     *(++ptr) = cpf_pairwise_features.e;
0323 
0324     assert(start + n_pairwise_features_cpf_ - 1 == ptr);
0325   }
0326 
0327   // lt pairwise features (4-vectors) specific case requiring (pt,eta,phi,e)
0328   auto max_lt_N = std::min(features.lt_features.size(), (std::size_t)n_lt_);
0329   for (std::size_t lt_N = 0; lt_N < max_lt_N; lt_N++) {
0330     const auto& lt_pairwise_features = features.lt_features.at(lt_N);
0331     ptr = &data_[kLostTracks4Vec][offset + lt_N * n_pairwise_features_lt_];
0332     start = ptr;
0333     *ptr = lt_pairwise_features.pt;
0334     *(++ptr) = lt_pairwise_features.eta;
0335     *(++ptr) = lt_pairwise_features.phi;
0336     *(++ptr) = lt_pairwise_features.e;
0337 
0338     assert(start + n_pairwise_features_lt_ - 1 == ptr);
0339   }
0340 
0341   // npf pairwise features (4-vectors)
0342   auto max_npf_n = std::min(features.n_pf_features.size(), (std::size_t)n_npf_);
0343   for (std::size_t npf_n = 0; npf_n < max_npf_n; npf_n++) {
0344     const auto& npf_pairwise_features = features.n_pf_features.at(npf_n);
0345     ptr = &data_[kNeutralCandidates4Vec][offset + npf_n * n_pairwise_features_npf_];
0346     start = ptr;
0347     *ptr = npf_pairwise_features.px;
0348     *(++ptr) = npf_pairwise_features.py;
0349     *(++ptr) = npf_pairwise_features.pz;
0350     *(++ptr) = npf_pairwise_features.e;
0351 
0352     assert(start + n_pairwise_features_npf_ - 1 == ptr);
0353   }
0354 
0355   // sv pairwise features (4-vectors)
0356   auto max_sv_N = std::min(features.sv_features.size(), (std::size_t)n_sv_);
0357   for (std::size_t sv_N = 0; sv_N < max_sv_N; sv_N++) {
0358     const auto& sv_pairwise_features = features.sv_features.at(sv_N);
0359     ptr = &data_[kVertices4Vec][offset + sv_N * n_pairwise_features_sv_];
0360     start = ptr;
0361     *ptr = sv_pairwise_features.px;
0362     *(++ptr) = sv_pairwise_features.py;
0363     *(++ptr) = sv_pairwise_features.pz;
0364     *(++ptr) = sv_pairwise_features.e;
0365 
0366     assert(start + n_pairwise_features_sv_ - 1 == ptr);
0367   }
0368 }
0369 
0370 // Define UnifiedParticleTransformerAK4ONNXJetTagsProducer as a framework plug-in module.
0371 DEFINE_FWK_MODULE(UnifiedParticleTransformerAK4ONNXJetTagsProducer);