Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2023-02-04 03:48:10

0001 /*
0002  * \class DeepTauId
0003  *
0004  * Tau identification using Deep NN.
0005  *
0006  * \author Konstantin Androsov, INFN Pisa
0007  *         Christian Veelken, Tallinn
0008  */
0009 
0010 #include "RecoTauTag/RecoTau/interface/DeepTauBase.h"
0011 #include "RecoTauTag/RecoTau/interface/DeepTauScaling.h"
0012 #include "FWCore/Utilities/interface/isFinite.h"
0013 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0014 #include "DataFormats/TauReco/interface/PFTauTransverseImpactParameterAssociation.h"
0015 
0016 #include <fstream>
0017 #include "oneapi/tbb/concurrent_unordered_set.h"
0018 
namespace deep_tau {
  // Length of the DeepTau output vector.
  // NOTE(review): presumably one score per output class of the network — confirm against the model.
  constexpr int NumberOfOutputs{4};
}  // namespace deep_tau
0022 
0023 namespace {
0024 
  // Input-variable indices for the v2 DeepTau network inputs.
  //
  // Each *BlockInputs namespace declares an unscoped enum `vars` whose enumerators are consecutive
  // integers starting at 0 (rho = 0) in declaration order; the last enumerator, NumberOfInputs,
  // therefore equals the number of variables of that block.
  // NOTE(review): these indices presumably have to match the input ordering used when the model was
  // trained — do not insert, remove or reorder enumerators without confirming against the training.
  namespace dnn_inputs_v2 {
    // NOTE(review): presumably the side lengths of the inner/outer cell grids and the number of
    // convolution output features — confirm at the use sites (outside this view).
    constexpr int number_of_inner_cell = 11;
    constexpr int number_of_outer_cell = 21;
    constexpr int number_of_conv_features = 64;

    // Variables describing the tau candidate as a whole (kinematics, isolation sums, impact
    // parameters, flight length, shape variables). NumberOfInputs == 47.
    namespace TauBlockInputs {
      enum vars {
        rho = 0,
        tau_pt,
        tau_eta,
        tau_phi,
        tau_mass,
        tau_E_over_pt,
        tau_charge,
        tau_n_charged_prongs,
        tau_n_neutral_prongs,
        chargedIsoPtSum,
        chargedIsoPtSumdR03_over_dR05,
        footprintCorrection,
        neutralIsoPtSum,
        neutralIsoPtSumWeight_over_neutralIsoPtSum,
        neutralIsoPtSumWeightdR03_over_neutralIsoPtSum,
        neutralIsoPtSumdR03_over_dR05,
        photonPtSumOutsideSignalCone,
        puCorrPtSum,
        tau_dxy_pca_x,
        tau_dxy_pca_y,
        tau_dxy_pca_z,
        tau_dxy_valid,
        tau_dxy,
        tau_dxy_sig,
        tau_ip3d_valid,
        tau_ip3d,
        tau_ip3d_sig,
        tau_dz,
        tau_dz_sig_valid,
        tau_dz_sig,
        tau_flightLength_x,
        tau_flightLength_y,
        tau_flightLength_z,
        tau_flightLength_sig,
        tau_pt_weighted_deta_strip,
        tau_pt_weighted_dphi_strip,
        tau_pt_weighted_dr_signal,
        tau_pt_weighted_dr_iso,
        tau_leadingTrackNormChi2,
        tau_e_ratio_valid,
        tau_e_ratio,
        tau_gj_angle_diff_valid,
        tau_gj_angle_diff,
        tau_n_photons,
        tau_emFraction,
        tau_inside_ecal_crack,
        leadChargedCand_etaAtEcalEntrance_minus_tau_eta,
        NumberOfInputs
      };
      // Entries of `vars` (tau_phi = 3 and tau_dxy_pca_x/y/z = 18..20) that are removed from the full list.
      // Deliberately non-const. NOTE(review): check the users outside this view (they may reorder it)
      // before making it const.
      std::vector<int> varsToDrop = {
          tau_phi, tau_dxy_pca_x, tau_dxy_pca_y, tau_dxy_pca_z};  // indices of vars to be dropped in the full var enum
    }  // namespace TauBlockInputs

    // e/gamma block: tau-level variables, then PF-electron (pfCand_ele_*), reconstructed electron
    // (ele_*) and PF-photon (pfCand_gamma_*) variables. NumberOfInputs == 86.
    namespace EgammaBlockInputs {
      enum vars {
        rho = 0,
        tau_pt,
        tau_eta,
        tau_inside_ecal_crack,
        pfCand_ele_valid,
        pfCand_ele_rel_pt,
        pfCand_ele_deta,
        pfCand_ele_dphi,
        pfCand_ele_pvAssociationQuality,
        pfCand_ele_puppiWeight,
        pfCand_ele_charge,
        pfCand_ele_lostInnerHits,
        pfCand_ele_numberOfPixelHits,
        pfCand_ele_vertex_dx,
        pfCand_ele_vertex_dy,
        pfCand_ele_vertex_dz,
        pfCand_ele_vertex_dx_tauFL,
        pfCand_ele_vertex_dy_tauFL,
        pfCand_ele_vertex_dz_tauFL,
        pfCand_ele_hasTrackDetails,
        pfCand_ele_dxy,
        pfCand_ele_dxy_sig,
        pfCand_ele_dz,
        pfCand_ele_dz_sig,
        pfCand_ele_track_chi2_ndof,
        pfCand_ele_track_ndof,
        ele_valid,
        ele_rel_pt,
        ele_deta,
        ele_dphi,
        ele_cc_valid,
        ele_cc_ele_rel_energy,
        ele_cc_gamma_rel_energy,
        ele_cc_n_gamma,
        ele_rel_trackMomentumAtVtx,
        ele_rel_trackMomentumAtCalo,
        ele_rel_trackMomentumOut,
        ele_rel_trackMomentumAtEleClus,
        ele_rel_trackMomentumAtVtxWithConstraint,
        ele_rel_ecalEnergy,
        ele_ecalEnergy_sig,
        ele_eSuperClusterOverP,
        ele_eSeedClusterOverP,
        ele_eSeedClusterOverPout,
        ele_eEleClusterOverPout,
        ele_deltaEtaSuperClusterTrackAtVtx,
        ele_deltaEtaSeedClusterTrackAtCalo,
        ele_deltaEtaEleClusterTrackAtCalo,
        ele_deltaPhiEleClusterTrackAtCalo,
        ele_deltaPhiSuperClusterTrackAtVtx,
        ele_deltaPhiSeedClusterTrackAtCalo,
        ele_mvaInput_earlyBrem,
        ele_mvaInput_lateBrem,
        ele_mvaInput_sigmaEtaEta,
        ele_mvaInput_hadEnergy,
        ele_mvaInput_deltaEta,
        ele_gsfTrack_normalizedChi2,
        ele_gsfTrack_numberOfValidHits,
        ele_rel_gsfTrack_pt,
        ele_gsfTrack_pt_sig,
        ele_has_closestCtfTrack,
        ele_closestCtfTrack_normalizedChi2,
        ele_closestCtfTrack_numberOfValidHits,
        pfCand_gamma_valid,
        pfCand_gamma_rel_pt,
        pfCand_gamma_deta,
        pfCand_gamma_dphi,
        pfCand_gamma_pvAssociationQuality,
        pfCand_gamma_fromPV,
        pfCand_gamma_puppiWeight,
        pfCand_gamma_puppiWeightNoLep,
        pfCand_gamma_lostInnerHits,
        pfCand_gamma_numberOfPixelHits,
        pfCand_gamma_vertex_dx,
        pfCand_gamma_vertex_dy,
        pfCand_gamma_vertex_dz,
        pfCand_gamma_vertex_dx_tauFL,
        pfCand_gamma_vertex_dy_tauFL,
        pfCand_gamma_vertex_dz_tauFL,
        pfCand_gamma_hasTrackDetails,
        pfCand_gamma_dxy,
        pfCand_gamma_dxy_sig,
        pfCand_gamma_dz,
        pfCand_gamma_dz_sig,
        pfCand_gamma_track_chi2_ndof,
        pfCand_gamma_track_ndof,
        NumberOfInputs
      };
    }  // namespace EgammaBlockInputs

    // Muon block: tau-level variables, then PF-muon (pfCand_muon_*) and reconstructed muon (muon_*)
    // variables, including per-station match/hit counts for DT, CSC and RPC. NumberOfInputs == 64.
    namespace MuonBlockInputs {
      enum vars {
        rho = 0,
        tau_pt,
        tau_eta,
        tau_inside_ecal_crack,
        pfCand_muon_valid,
        pfCand_muon_rel_pt,
        pfCand_muon_deta,
        pfCand_muon_dphi,
        pfCand_muon_pvAssociationQuality,
        pfCand_muon_fromPV,
        pfCand_muon_puppiWeight,
        pfCand_muon_charge,
        pfCand_muon_lostInnerHits,
        pfCand_muon_numberOfPixelHits,
        pfCand_muon_vertex_dx,
        pfCand_muon_vertex_dy,
        pfCand_muon_vertex_dz,
        pfCand_muon_vertex_dx_tauFL,
        pfCand_muon_vertex_dy_tauFL,
        pfCand_muon_vertex_dz_tauFL,
        pfCand_muon_hasTrackDetails,
        pfCand_muon_dxy,
        pfCand_muon_dxy_sig,
        pfCand_muon_dz,
        pfCand_muon_dz_sig,
        pfCand_muon_track_chi2_ndof,
        pfCand_muon_track_ndof,
        muon_valid,
        muon_rel_pt,
        muon_deta,
        muon_dphi,
        muon_dxy,
        muon_dxy_sig,
        muon_normalizedChi2_valid,
        muon_normalizedChi2,
        muon_numberOfValidHits,
        muon_segmentCompatibility,
        muon_caloCompatibility,
        muon_pfEcalEnergy_valid,
        muon_rel_pfEcalEnergy,
        muon_n_matches_DT_1,
        muon_n_matches_DT_2,
        muon_n_matches_DT_3,
        muon_n_matches_DT_4,
        muon_n_matches_CSC_1,
        muon_n_matches_CSC_2,
        muon_n_matches_CSC_3,
        muon_n_matches_CSC_4,
        muon_n_matches_RPC_1,
        muon_n_matches_RPC_2,
        muon_n_matches_RPC_3,
        muon_n_matches_RPC_4,
        muon_n_hits_DT_1,
        muon_n_hits_DT_2,
        muon_n_hits_DT_3,
        muon_n_hits_DT_4,
        muon_n_hits_CSC_1,
        muon_n_hits_CSC_2,
        muon_n_hits_CSC_3,
        muon_n_hits_CSC_4,
        muon_n_hits_RPC_1,
        muon_n_hits_RPC_2,
        muon_n_hits_RPC_3,
        muon_n_hits_RPC_4,
        NumberOfInputs
      };
    }  // namespace MuonBlockInputs

    // Hadron block: tau-level variables, then PF charged-hadron (pfCand_chHad_*) and neutral-hadron
    // (pfCand_nHad_*) variables. NumberOfInputs == 38.
    namespace HadronBlockInputs {
      enum vars {
        rho = 0,
        tau_pt,
        tau_eta,
        tau_inside_ecal_crack,
        pfCand_chHad_valid,
        pfCand_chHad_rel_pt,
        pfCand_chHad_deta,
        pfCand_chHad_dphi,
        pfCand_chHad_leadChargedHadrCand,
        pfCand_chHad_pvAssociationQuality,
        pfCand_chHad_fromPV,
        pfCand_chHad_puppiWeight,
        pfCand_chHad_puppiWeightNoLep,
        pfCand_chHad_charge,
        pfCand_chHad_lostInnerHits,
        pfCand_chHad_numberOfPixelHits,
        pfCand_chHad_vertex_dx,
        pfCand_chHad_vertex_dy,
        pfCand_chHad_vertex_dz,
        pfCand_chHad_vertex_dx_tauFL,
        pfCand_chHad_vertex_dy_tauFL,
        pfCand_chHad_vertex_dz_tauFL,
        pfCand_chHad_hasTrackDetails,
        pfCand_chHad_dxy,
        pfCand_chHad_dxy_sig,
        pfCand_chHad_dz,
        pfCand_chHad_dz_sig,
        pfCand_chHad_track_chi2_ndof,
        pfCand_chHad_track_ndof,
        pfCand_chHad_hcalFraction,
        pfCand_chHad_rawCaloFraction,
        pfCand_nHad_valid,
        pfCand_nHad_rel_pt,
        pfCand_nHad_deta,
        pfCand_nHad_dphi,
        pfCand_nHad_puppiWeight,
        pfCand_nHad_puppiWeightNoLep,
        pfCand_nHad_hcalFraction,
        NumberOfInputs
      };
    }  // namespace HadronBlockInputs
  }  // namespace dnn_inputs_v2
0290 
0291   float getTauID(const pat::Tau& tau, const std::string& tauID, float default_value = -999., bool assert_input = true) {
0292     static tbb::concurrent_unordered_set<std::string> isFirstWarning;
0293     if (tau.isTauIDAvailable(tauID)) {
0294       return tau.tauID(tauID);
0295     } else {
0296       if (assert_input) {
0297         throw cms::Exception("DeepTauId")
0298             << "Exception in <getTauID>: No tauID '" << tauID << "' available in pat::Tau given as function argument.";
0299       }
0300       if (isFirstWarning.insert(tauID).second) {
0301         edm::LogWarning("DeepTauID") << "Warning in <getTauID>: No tauID '" << tauID
0302                                      << "' available in pat::Tau given as function argument."
0303                                      << " Using default_value = " << default_value << " instead." << std::endl;
0304       }
0305       return default_value;
0306     }
0307   }
0308 
0309   struct TauFunc {
0310     const reco::TauDiscriminatorContainer* basicTauDiscriminatorCollection;
0311     const reco::TauDiscriminatorContainer* basicTauDiscriminatordR03Collection;
0312     const edm::AssociationVector<reco::PFTauRefProd, std::vector<reco::PFTauTransverseImpactParameterRef>>*
0313         pfTauTransverseImpactParameters;
0314 
0315     using BasicDiscr = deep_tau::DeepTauBase::BasicDiscriminator;
0316     std::map<BasicDiscr, size_t> indexMap;
0317     std::map<BasicDiscr, size_t> indexMapdR03;
0318 
0319     const float getChargedIsoPtSum(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0320       return (*basicTauDiscriminatorCollection)[tau_ref].rawValues.at(indexMap.at(BasicDiscr::ChargedIsoPtSum));
0321     }
0322     const float getChargedIsoPtSum(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0323       return getTauID(tau, "chargedIsoPtSum");
0324     }
0325     const float getChargedIsoPtSumdR03(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0326       return (*basicTauDiscriminatordR03Collection)[tau_ref].rawValues.at(indexMapdR03.at(BasicDiscr::ChargedIsoPtSum));
0327     }
0328     const float getChargedIsoPtSumdR03(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0329       return getTauID(tau, "chargedIsoPtSumdR03");
0330     }
0331     const float getFootprintCorrectiondR03(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0332       return (*basicTauDiscriminatordR03Collection)[tau_ref].rawValues.at(
0333           indexMapdR03.at(BasicDiscr::FootprintCorrection));
0334     }
0335     const float getFootprintCorrectiondR03(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0336       return getTauID(tau, "footprintCorrectiondR03");
0337     }
0338     const float getFootprintCorrection(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0339       return getTauID(tau, "footprintCorrection");
0340     }
0341     const float getNeutralIsoPtSum(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0342       return (*basicTauDiscriminatorCollection)[tau_ref].rawValues.at(indexMap.at(BasicDiscr::NeutralIsoPtSum));
0343     }
0344     const float getNeutralIsoPtSum(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0345       return getTauID(tau, "neutralIsoPtSum");
0346     }
0347     const float getNeutralIsoPtSumdR03(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0348       return (*basicTauDiscriminatordR03Collection)[tau_ref].rawValues.at(indexMapdR03.at(BasicDiscr::NeutralIsoPtSum));
0349     }
0350     const float getNeutralIsoPtSumdR03(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0351       return getTauID(tau, "neutralIsoPtSumdR03");
0352     }
0353     const float getNeutralIsoPtSumWeight(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0354       return (*basicTauDiscriminatorCollection)[tau_ref].rawValues.at(indexMap.at(BasicDiscr::NeutralIsoPtSumWeight));
0355     }
0356     const float getNeutralIsoPtSumWeight(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0357       return getTauID(tau, "neutralIsoPtSumWeight");
0358     }
0359     const float getNeutralIsoPtSumdR03Weight(const reco::PFTau& tau,
0360                                              const edm::RefToBase<reco::BaseTau> tau_ref) const {
0361       return (*basicTauDiscriminatordR03Collection)[tau_ref].rawValues.at(
0362           indexMapdR03.at(BasicDiscr::NeutralIsoPtSumWeight));
0363     }
0364     const float getNeutralIsoPtSumdR03Weight(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0365       return getTauID(tau, "neutralIsoPtSumWeightdR03");
0366     }
0367     const float getPhotonPtSumOutsideSignalCone(const reco::PFTau& tau,
0368                                                 const edm::RefToBase<reco::BaseTau> tau_ref) const {
0369       return (*basicTauDiscriminatorCollection)[tau_ref].rawValues.at(
0370           indexMap.at(BasicDiscr::PhotonPtSumOutsideSignalCone));
0371     }
0372     const float getPhotonPtSumOutsideSignalCone(const pat::Tau& tau,
0373                                                 const edm::RefToBase<reco::BaseTau> tau_ref) const {
0374       return getTauID(tau, "photonPtSumOutsideSignalCone");
0375     }
0376     const float getPhotonPtSumOutsideSignalConedR03(const reco::PFTau& tau,
0377                                                     const edm::RefToBase<reco::BaseTau> tau_ref) const {
0378       return (*basicTauDiscriminatordR03Collection)[tau_ref].rawValues.at(
0379           indexMapdR03.at(BasicDiscr::PhotonPtSumOutsideSignalCone));
0380     }
0381     const float getPhotonPtSumOutsideSignalConedR03(const pat::Tau& tau,
0382                                                     const edm::RefToBase<reco::BaseTau> tau_ref) const {
0383       return getTauID(tau, "photonPtSumOutsideSignalConedR03");
0384     }
0385     const float getPuCorrPtSum(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0386       return (*basicTauDiscriminatorCollection)[tau_ref].rawValues.at(indexMap.at(BasicDiscr::PUcorrPtSum));
0387     }
0388     const float getPuCorrPtSum(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0389       return getTauID(tau, "puCorrPtSum");
0390     }
0391 
0392     auto getdxyPCA(const reco::PFTau& tau, const size_t tau_index) const {
0393       return pfTauTransverseImpactParameters->value(tau_index)->dxy_PCA();
0394     }
0395     auto getdxyPCA(const pat::Tau& tau, const size_t tau_index) const { return tau.dxy_PCA(); }
0396     auto getdxy(const reco::PFTau& tau, const size_t tau_index) const {
0397       return pfTauTransverseImpactParameters->value(tau_index)->dxy();
0398     }
0399     auto getdxy(const pat::Tau& tau, const size_t tau_index) const { return tau.dxy(); }
0400     auto getdxyError(const reco::PFTau& tau, const size_t tau_index) const {
0401       return pfTauTransverseImpactParameters->value(tau_index)->dxy_error();
0402     }
0403     auto getdxyError(const pat::Tau& tau, const size_t tau_index) const { return tau.dxy_error(); }
0404     auto getdxySig(const reco::PFTau& tau, const size_t tau_index) const {
0405       return pfTauTransverseImpactParameters->value(tau_index)->dxy_Sig();
0406     }
0407     auto getdxySig(const pat::Tau& tau, const size_t tau_index) const { return tau.dxy_Sig(); }
0408     auto getip3d(const reco::PFTau& tau, const size_t tau_index) const {
0409       return pfTauTransverseImpactParameters->value(tau_index)->ip3d();
0410     }
0411     auto getip3d(const pat::Tau& tau, const size_t tau_index) const { return tau.ip3d(); }
0412     auto getip3dError(const reco::PFTau& tau, const size_t tau_index) const {
0413       return pfTauTransverseImpactParameters->value(tau_index)->ip3d_error();
0414     }
0415     auto getip3dError(const pat::Tau& tau, const size_t tau_index) const { return tau.ip3d_error(); }
0416     auto getip3dSig(const reco::PFTau& tau, const size_t tau_index) const {
0417       return pfTauTransverseImpactParameters->value(tau_index)->ip3d_Sig();
0418     }
0419     auto getip3dSig(const pat::Tau& tau, const size_t tau_index) const { return tau.ip3d_Sig(); }
0420     auto getHasSecondaryVertex(const reco::PFTau& tau, const size_t tau_index) const {
0421       return pfTauTransverseImpactParameters->value(tau_index)->hasSecondaryVertex();
0422     }
0423     auto getHasSecondaryVertex(const pat::Tau& tau, const size_t tau_index) const { return tau.hasSecondaryVertex(); }
0424     auto getFlightLength(const reco::PFTau& tau, const size_t tau_index) const {
0425       return pfTauTransverseImpactParameters->value(tau_index)->flightLength();
0426     }
0427     auto getFlightLength(const pat::Tau& tau, const size_t tau_index) const { return tau.flightLength(); }
0428     auto getFlightLengthSig(const reco::PFTau& tau, const size_t tau_index) const {
0429       return pfTauTransverseImpactParameters->value(tau_index)->flightLengthSig();
0430     }
0431     auto getFlightLengthSig(const pat::Tau& tau, const size_t tau_index) const { return tau.flightLengthSig(); }
0432 
0433     auto getLeadingTrackNormChi2(const reco::PFTau& tau) { return reco::tau::lead_track_chi2(tau); }
0434     auto getLeadingTrackNormChi2(const pat::Tau& tau) { return tau.leadingTrackNormChi2(); }
0435     auto getEmFraction(const pat::Tau& tau) { return tau.emFraction_MVA(); }
0436     auto getEmFraction(const reco::PFTau& tau) { return tau.emFraction(); }
0437     auto getEtaAtEcalEntrance(const pat::Tau& tau) { return tau.etaAtEcalEntranceLeadChargedCand(); }
0438     auto getEtaAtEcalEntrance(const reco::PFTau& tau) {
0439       return tau.leadPFChargedHadrCand()->positionAtECALEntrance().eta();
0440     }
0441     auto getEcalEnergyLeadingChargedHadr(const reco::PFTau& tau) { return tau.leadPFChargedHadrCand()->ecalEnergy(); }
0442     auto getEcalEnergyLeadingChargedHadr(const pat::Tau& tau) { return tau.ecalEnergyLeadChargedHadrCand(); }
0443     auto getHcalEnergyLeadingChargedHadr(const reco::PFTau& tau) { return tau.leadPFChargedHadrCand()->hcalEnergy(); }
0444     auto getHcalEnergyLeadingChargedHadr(const pat::Tau& tau) { return tau.hcalEnergyLeadChargedHadrCand(); }
0445 
0446     template <typename PreDiscrType>
0447     bool passPrediscriminants(const PreDiscrType prediscriminants,
0448                               const size_t andPrediscriminants,
0449                               const edm::RefToBase<reco::BaseTau> tau_ref) {
0450       bool passesPrediscriminants = (andPrediscriminants ? 1 : 0);
0451       // check tau passes prediscriminants
0452       size_t nPrediscriminants = prediscriminants.size();
0453       for (size_t iDisc = 0; iDisc < nPrediscriminants; ++iDisc) {
0454         // current discriminant result for this tau
0455         double discResult = (*prediscriminants[iDisc].handle)[tau_ref];
0456         uint8_t thisPasses = (discResult > prediscriminants[iDisc].cut) ? 1 : 0;
0457 
0458         // if we are using the AND option, as soon as one fails,
0459         // the result is FAIL and we can quit looping.
0460         // if we are using the OR option as soon as one passes,
0461         // the result is pass and we can quit looping
0462 
0463         // truth table
0464         //        |   result (thisPasses)
0465         //        |     F     |     T
0466         //-----------------------------------
0467         // AND(T) | res=fails |  continue
0468         //        |  break    |
0469         //-----------------------------------
0470         // OR (F) |  continue | res=passes
0471         //        |           |  break
0472 
0473         if (thisPasses ^ andPrediscriminants)  //XOR
0474         {
0475           passesPrediscriminants = (andPrediscriminants ? 0 : 1);  //NOR
0476           break;
0477         }
0478       }
0479       return passesPrediscriminants;
0480     }
0481   };
0482 
0483   namespace candFunc {
0484     auto getTauDz(const reco::PFCandidate& cand) { return cand.bestTrack()->dz(); }
0485     auto getTauDz(const pat::PackedCandidate& cand) { return cand.dz(); }
0486     auto getTauDZSigValid(const reco::PFCandidate& cand) {
0487       return cand.bestTrack() != nullptr && std::isnormal(cand.bestTrack()->dz()) && std::isnormal(cand.dzError()) &&
0488              cand.dzError() > 0;
0489     }
0490     auto getTauDZSigValid(const pat::PackedCandidate& cand) {
0491       return cand.hasTrackDetails() && std::isnormal(cand.dz()) && std::isnormal(cand.dzError()) && cand.dzError() > 0;
0492     }
0493     auto getTauDxy(const reco::PFCandidate& cand) { return cand.bestTrack()->dxy(); }
0494     auto getTauDxy(const pat::PackedCandidate& cand) { return cand.dxy(); }
0495     auto getPvAssocationQuality(const reco::PFCandidate& cand) { return 0.7013f; }
0496     auto getPvAssocationQuality(const pat::PackedCandidate& cand) { return cand.pvAssociationQuality(); }
0497     auto getPuppiWeight(const reco::PFCandidate& cand, const float aod_value) { return aod_value; }
0498     auto getPuppiWeight(const pat::PackedCandidate& cand, const float aod_value) { return cand.puppiWeight(); }
0499     auto getPuppiWeightNoLep(const reco::PFCandidate& cand, const float aod_value) { return aod_value; }
0500     auto getPuppiWeightNoLep(const pat::PackedCandidate& cand, const float aod_value) {
0501       return cand.puppiWeightNoLep();
0502     }
0503     auto getLostInnerHits(const reco::PFCandidate& cand, float default_value) {
0504       return cand.bestTrack() != nullptr
0505                  ? cand.bestTrack()->hitPattern().numberOfLostHits(reco::HitPattern::MISSING_INNER_HITS)
0506                  : default_value;
0507     }
0508     auto getLostInnerHits(const pat::PackedCandidate& cand, float default_value) { return cand.lostInnerHits(); }
0509     auto getNumberOfPixelHits(const reco::PFCandidate& cand, float default_value) {
0510       return cand.bestTrack() != nullptr
0511                  ? cand.bestTrack()->hitPattern().numberOfLostHits(reco::HitPattern::MISSING_INNER_HITS)
0512                  : default_value;
0513     }
0514     auto getNumberOfPixelHits(const pat::PackedCandidate& cand, float default_value) {
0515       return cand.numberOfPixelHits();
0516     }
0517     auto getHasTrackDetails(const reco::PFCandidate& cand) { return cand.bestTrack() != nullptr; }
0518     auto getHasTrackDetails(const pat::PackedCandidate& cand) { return cand.hasTrackDetails(); }
0519     auto getPseudoTrack(const reco::PFCandidate& cand) { return *cand.bestTrack(); }
0520     auto getPseudoTrack(const pat::PackedCandidate& cand) { return cand.pseudoTrack(); }
0521     auto getFromPV(const reco::PFCandidate& cand) { return 0.9994f; }
0522     auto getFromPV(const pat::PackedCandidate& cand) { return cand.fromPV(); }
0523     auto getHCalFraction(const reco::PFCandidate& cand, bool disable_hcalFraction_workaround) {
0524       return cand.rawHcalEnergy() / (cand.rawHcalEnergy() + cand.rawEcalEnergy());
0525     }
0526     auto getHCalFraction(const pat::PackedCandidate& cand, bool disable_hcalFraction_workaround) {
0527       float hcal_fraction = 0.;
0528       if (disable_hcalFraction_workaround) {
0529         // CV: use consistent definition for pfCand_chHad_hcalFraction
0530         //     in DeepTauId.cc code and in TauMLTools/Production/plugins/TauTupleProducer.cc
0531         hcal_fraction = cand.hcalFraction();
0532       } else {
0533         // CV: backwards compatibility with DeepTau training v2p1 used during Run 2
0534         if (cand.pdgId() == 1 || cand.pdgId() == 130) {
0535           hcal_fraction = cand.hcalFraction();
0536         } else if (cand.isIsolatedChargedHadron()) {
0537           hcal_fraction = cand.rawHcalFraction();
0538         }
0539       }
0540       return hcal_fraction;
0541     }
0542     auto getRawCaloFraction(const reco::PFCandidate& cand) {
0543       return (cand.rawEcalEnergy() + cand.rawHcalEnergy()) / cand.energy();
0544     }
0545     auto getRawCaloFraction(const pat::PackedCandidate& cand) { return cand.rawCaloFraction(); }
0546   };  // namespace candFunc
0547 
0548   template <typename LVector1, typename LVector2>
0549   float dEta(const LVector1& p4, const LVector2& tau_p4) {
0550     return static_cast<float>(p4.eta() - tau_p4.eta());
0551   }
0552 
0553   template <typename LVector1, typename LVector2>
0554   float dPhi(const LVector1& p4_1, const LVector2& p4_2) {
0555     return static_cast<float>(reco::deltaPhi(p4_2.phi(), p4_1.phi()));
0556   }
0557 
0558   struct MuonHitMatchV2 {
0559     static constexpr size_t n_muon_stations = 4;
0560     static constexpr int first_station_id = 1;
0561     static constexpr int last_station_id = first_station_id + n_muon_stations - 1;
0562     using CountArray = std::array<unsigned, n_muon_stations>;
0563     using CountMap = std::map<int, CountArray>;
0564 
0565     const std::vector<int>& consideredSubdets() {
0566       static const std::vector<int> subdets = {MuonSubdetId::DT, MuonSubdetId::CSC, MuonSubdetId::RPC};
0567       return subdets;
0568     }
0569 
0570     const std::string& subdetName(int subdet) {
0571       static const std::map<int, std::string> subdet_names = {
0572           {MuonSubdetId::DT, "DT"}, {MuonSubdetId::CSC, "CSC"}, {MuonSubdetId::RPC, "RPC"}};
0573       if (!subdet_names.count(subdet))
0574         throw cms::Exception("MuonHitMatch") << "Subdet name for subdet id " << subdet << " not found.";
0575       return subdet_names.at(subdet);
0576     }
0577 
0578     size_t getStationIndex(int station, bool throw_exception) const {
0579       if (station < first_station_id || station > last_station_id) {
0580         if (throw_exception)
0581           throw cms::Exception("MuonHitMatch") << "Station id is out of range";
0582         return std::numeric_limits<size_t>::max();
0583       }
0584       return static_cast<size_t>(station - 1);
0585     }
0586 
0587     MuonHitMatchV2(const pat::Muon& muon) {
0588       for (int subdet : consideredSubdets()) {
0589         n_matches[subdet].fill(0);
0590         n_hits[subdet].fill(0);
0591       }
0592 
0593       countMatches(muon, n_matches);
0594       countHits(muon, n_hits);
0595     }
0596 
0597     void countMatches(const pat::Muon& muon, CountMap& n_matches) {
0598       for (const auto& segment : muon.matches()) {
0599         if (segment.segmentMatches.empty() && segment.rpcMatches.empty())
0600           continue;
0601         if (n_matches.count(segment.detector())) {
0602           const size_t station_index = getStationIndex(segment.station(), true);
0603           ++n_matches.at(segment.detector()).at(station_index);
0604         }
0605       }
0606     }
0607 
0608     void countHits(const pat::Muon& muon, CountMap& n_hits) {
0609       if (muon.outerTrack().isNonnull()) {
0610         const auto& hit_pattern = muon.outerTrack()->hitPattern();
0611         for (int hit_index = 0; hit_index < hit_pattern.numberOfAllHits(reco::HitPattern::TRACK_HITS); ++hit_index) {
0612           auto hit_id = hit_pattern.getHitPattern(reco::HitPattern::TRACK_HITS, hit_index);
0613           if (hit_id == 0)
0614             break;
0615           if (hit_pattern.muonHitFilter(hit_id) && (hit_pattern.getHitType(hit_id) == TrackingRecHit::valid ||
0616                                                     hit_pattern.getHitType(hit_id) == TrackingRecHit::bad)) {
0617             const size_t station_index = getStationIndex(hit_pattern.getMuonStation(hit_id), false);
0618             if (station_index < n_muon_stations) {
0619               CountArray* muon_n_hits = nullptr;
0620               if (hit_pattern.muonDTHitFilter(hit_id))
0621                 muon_n_hits = &n_hits.at(MuonSubdetId::DT);
0622               else if (hit_pattern.muonCSCHitFilter(hit_id))
0623                 muon_n_hits = &n_hits.at(MuonSubdetId::CSC);
0624               else if (hit_pattern.muonRPCHitFilter(hit_id))
0625                 muon_n_hits = &n_hits.at(MuonSubdetId::RPC);
0626 
0627               if (muon_n_hits)
0628                 ++muon_n_hits->at(station_index);
0629             }
0630           }
0631         }
0632       }
0633     }
0634 
0635     unsigned nMatches(int subdet, int station) const {
0636       if (!n_matches.count(subdet))
0637         throw cms::Exception("MuonHitMatch") << "Subdet " << subdet << " not found.";
0638       const size_t station_index = getStationIndex(station, true);
0639       return n_matches.at(subdet).at(station_index);
0640     }
0641 
0642     unsigned nHits(int subdet, int station) const {
0643       if (!n_hits.count(subdet))
0644         throw cms::Exception("MuonHitMatch") << "Subdet " << subdet << " not found.";
0645       const size_t station_index = getStationIndex(station, true);
0646       return n_hits.at(subdet).at(station_index);
0647     }
0648 
0649     unsigned countMuonStationsWithMatches(int first_station, int last_station) const {
0650       static const std::map<int, std::vector<bool>> masks = {
0651           {MuonSubdetId::DT, {false, false, false, false}},
0652           {MuonSubdetId::CSC, {true, false, false, false}},
0653           {MuonSubdetId::RPC, {false, false, false, false}},
0654       };
0655       const size_t first_station_index = getStationIndex(first_station, true);
0656       const size_t last_station_index = getStationIndex(last_station, true);
0657       unsigned cnt = 0;
0658       for (size_t n = first_station_index; n <= last_station_index; ++n) {
0659         for (const auto& match : n_matches) {
0660           if (!masks.at(match.first).at(n) && match.second.at(n) > 0)
0661             ++cnt;
0662         }
0663       }
0664       return cnt;
0665     }
0666 
0667     unsigned countMuonStationsWithHits(int first_station, int last_station) const {
0668       static const std::map<int, std::vector<bool>> masks = {
0669           {MuonSubdetId::DT, {false, false, false, false}},
0670           {MuonSubdetId::CSC, {false, false, false, false}},
0671           {MuonSubdetId::RPC, {false, false, false, false}},
0672       };
0673 
0674       const size_t first_station_index = getStationIndex(first_station, true);
0675       const size_t last_station_index = getStationIndex(last_station, true);
0676       unsigned cnt = 0;
0677       for (size_t n = first_station_index; n <= last_station_index; ++n) {
0678         for (const auto& hit : n_hits) {
0679           if (!masks.at(hit.first).at(n) && hit.second.at(n) > 0)
0680             ++cnt;
0681         }
0682       }
0683       return cnt;
0684     }
0685 
0686   private:
0687     CountMap n_matches, n_hits;
0688   };
0689 
  /// Category of an object placed into a grid cell.
  /// The PfCand_* values are assigned to PF candidates by |pdgId| (see the
  /// GetCellObjectType specialisation for reco::Candidate). Electron and Muon are used for
  /// pat::Electron / pat::Muon. Other covers any PF candidate with an unlisted pdgId.
  /// NOTE: enumerator order defines the key order of the per-cell std::map, so do not reorder.
  enum class CellObjectType {
    PfCand_electron,
    PfCand_muon,
    PfCand_chargedHadron,
    PfCand_neutralHadron,
    PfCand_gamma,
    Electron,
    Muon,
    Other
  };
0700 
0701   template <typename Object>
0702   CellObjectType GetCellObjectType(const Object&);
0703   template <>
0704   CellObjectType GetCellObjectType(const pat::Electron&) {
0705     return CellObjectType::Electron;
0706   }
0707   template <>
0708   CellObjectType GetCellObjectType(const pat::Muon&) {
0709     return CellObjectType::Muon;
0710   }
0711 
0712   template <>
0713   CellObjectType GetCellObjectType(reco::Candidate const& cand) {
0714     static const std::map<int, CellObjectType> obj_types = {{11, CellObjectType::PfCand_electron},
0715                                                             {13, CellObjectType::PfCand_muon},
0716                                                             {22, CellObjectType::PfCand_gamma},
0717                                                             {130, CellObjectType::PfCand_neutralHadron},
0718                                                             {211, CellObjectType::PfCand_chargedHadron}};
0719 
0720     auto iter = obj_types.find(std::abs(cand.pdgId()));
0721     if (iter == obj_types.end())
0722       return CellObjectType::Other;
0723     return iter->second;
0724   }
0725 
  /// Contents of one grid cell: at most one size_t value per object category.
  /// NOTE(review): the value is presumably the index of the selected object in its input
  /// collection; confirm at the site that fills the grid.
  using Cell = std::map<CellObjectType, size_t>;
0727   struct CellIndex {
0728     int eta, phi;
0729 
0730     bool operator<(const CellIndex& other) const {
0731       if (eta != other.eta)
0732         return eta < other.eta;
0733       return phi < other.phi;
0734     }
0735   };
0736 
0737   class CellGrid {
0738   public:
0739     using Map = std::map<CellIndex, Cell>;
0740     using const_iterator = Map::const_iterator;
0741 
0742     CellGrid(unsigned n_cells_eta,
0743              unsigned n_cells_phi,
0744              double cell_size_eta,
0745              double cell_size_phi,
0746              bool disable_CellIndex_workaround)
0747         : nCellsEta(n_cells_eta),
0748           nCellsPhi(n_cells_phi),
0749           nTotal(nCellsEta * nCellsPhi),
0750           cellSizeEta(cell_size_eta),
0751           cellSizePhi(cell_size_phi),
0752           disable_CellIndex_workaround_(disable_CellIndex_workaround) {
0753       if (nCellsEta % 2 != 1 || nCellsEta < 1)
0754         throw cms::Exception("DeepTauId") << "Invalid number of eta cells.";
0755       if (nCellsPhi % 2 != 1 || nCellsPhi < 1)
0756         throw cms::Exception("DeepTauId") << "Invalid number of phi cells.";
0757       if (cellSizeEta <= 0 || cellSizePhi <= 0)
0758         throw cms::Exception("DeepTauId") << "Invalid cell size.";
0759     }
0760 
0761     int maxEtaIndex() const { return static_cast<int>((nCellsEta - 1) / 2); }
0762     int maxPhiIndex() const { return static_cast<int>((nCellsPhi - 1) / 2); }
0763     double maxDeltaEta() const { return cellSizeEta * (0.5 + maxEtaIndex()); }
0764     double maxDeltaPhi() const { return cellSizePhi * (0.5 + maxPhiIndex()); }
0765     int getEtaTensorIndex(const CellIndex& cellIndex) const { return cellIndex.eta + maxEtaIndex(); }
0766     int getPhiTensorIndex(const CellIndex& cellIndex) const { return cellIndex.phi + maxPhiIndex(); }
0767 
0768     bool tryGetCellIndex(double deltaEta, double deltaPhi, CellIndex& cellIndex) const {
0769       const auto getCellIndex = [this](double x, double maxX, double size, int& index) {
0770         const double absX = std::abs(x);
0771         if (absX > maxX)
0772           return false;
0773         double absIndex;
0774         if (disable_CellIndex_workaround_) {
0775           // CV: use consistent definition for CellIndex
0776           //     in DeepTauId.cc code and new DeepTau trainings
0777           absIndex = std::floor(absX / size + 0.5);
0778         } else {
0779           // CV: backwards compatibility with DeepTau training v2p1 used during Run 2
0780           absIndex = std::floor(std::abs(absX / size - 0.5));
0781         }
0782         index = static_cast<int>(std::copysign(absIndex, x));
0783         return true;
0784       };
0785 
0786       return getCellIndex(deltaEta, maxDeltaEta(), cellSizeEta, cellIndex.eta) &&
0787              getCellIndex(deltaPhi, maxDeltaPhi(), cellSizePhi, cellIndex.phi);
0788     }
0789 
0790     size_t num_valid_cells() const { return cells.size(); }
0791     Cell& operator[](const CellIndex& cellIndex) { return cells[cellIndex]; }
0792     const Cell& at(const CellIndex& cellIndex) const { return cells.at(cellIndex); }
0793     size_t count(const CellIndex& cellIndex) const { return cells.count(cellIndex); }
0794     const_iterator find(const CellIndex& cellIndex) const { return cells.find(cellIndex); }
0795     const_iterator begin() const { return cells.begin(); }
0796     const_iterator end() const { return cells.end(); }
0797 
0798   public:
0799     const unsigned nCellsEta, nCellsPhi, nTotal;
0800     const double cellSizeEta, cellSizePhi;
0801 
0802   private:
0803     std::map<CellIndex, Cell> cells;
0804     const bool disable_CellIndex_workaround_;
0805   };
0806 }  // anonymous namespace
0807 
0808 using bd = deep_tau::DeepTauBase::BasicDiscriminator;
0809 const std::map<bd, std::string> deep_tau::DeepTauBase::stringFromDiscriminator_{
0810     {bd::ChargedIsoPtSum, "ChargedIsoPtSum"},
0811     {bd::NeutralIsoPtSum, "NeutralIsoPtSum"},
0812     {bd::NeutralIsoPtSumWeight, "NeutralIsoPtSumWeight"},
0813     {bd::FootprintCorrection, "TauFootprintCorrection"},
0814     {bd::PhotonPtSumOutsideSignalCone, "PhotonPtSumOutsideSignalCone"},
0815     {bd::PUcorrPtSum, "PUcorrPtSum"}};
0816 const std::vector<bd> deep_tau::DeepTauBase::requiredBasicDiscriminators_ = {bd::ChargedIsoPtSum,
0817                                                                              bd::NeutralIsoPtSum,
0818                                                                              bd::NeutralIsoPtSumWeight,
0819                                                                              bd::PhotonPtSumOutsideSignalCone,
0820                                                                              bd::PUcorrPtSum};
0821 const std::vector<bd> deep_tau::DeepTauBase::requiredBasicDiscriminatorsdR03_ = {bd::ChargedIsoPtSum,
0822                                                                                  bd::NeutralIsoPtSum,
0823                                                                                  bd::NeutralIsoPtSumWeight,
0824                                                                                  bd::PhotonPtSumOutsideSignalCone,
0825                                                                                  bd::FootprintCorrection};
0826 
0827 class DeepTauId : public deep_tau::DeepTauBase {
0828 public:
0829   static constexpr float default_value = -999.;
0830 
0831   static const OutputCollection& GetOutputs() {
0832     static constexpr size_t e_index = 0, mu_index = 1, tau_index = 2, jet_index = 3;
0833     static const OutputCollection outputs_ = {
0834         {"VSe", Output({tau_index}, {e_index, tau_index})},
0835         {"VSmu", Output({tau_index}, {mu_index, tau_index})},
0836         {"VSjet", Output({tau_index}, {jet_index, tau_index})},
0837     };
0838     return outputs_;
0839   }
0840 
0841   const std::map<BasicDiscriminator, size_t> matchDiscriminatorIndices(
0842       edm::Event& event,
0843       edm::EDGetTokenT<reco::TauDiscriminatorContainer> discriminatorContainerToken,
0844       std::vector<BasicDiscriminator> requiredDiscr) {
0845     std::map<std::string, size_t> discrIndexMapStr;
0846     auto const aHandle = event.getHandle(discriminatorContainerToken);
0847     auto const aProv = aHandle.provenance();
0848     if (aProv == nullptr)
0849       aHandle.whyFailed()->raise();
0850     const auto& psetsFromProvenance = edm::parameterSet(aProv->stable(), event.processHistory());
0851     auto const idlist = psetsFromProvenance.getParameter<std::vector<edm::ParameterSet>>("IDdefinitions");
0852     for (size_t j = 0; j < idlist.size(); ++j) {
0853       std::string idname = idlist[j].getParameter<std::string>("IDname");
0854       if (discrIndexMapStr.count(idname)) {
0855         throw cms::Exception("DeepTauId")
0856             << "basic discriminator " << idname << " appears more than once in the input.";
0857       }
0858       discrIndexMapStr[idname] = j;
0859     }
0860 
0861     //translate to a map of <BasicDiscriminator, index> and check if all discriminators are present
0862     std::map<BasicDiscriminator, size_t> discrIndexMap;
0863     for (size_t i = 0; i < requiredDiscr.size(); i++) {
0864       if (discrIndexMapStr.find(stringFromDiscriminator_.at(requiredDiscr[i])) == discrIndexMapStr.end())
0865         throw cms::Exception("DeepTauId") << "Basic Discriminator " << stringFromDiscriminator_.at(requiredDiscr[i])
0866                                           << " was not provided in the config file.";
0867       else
0868         discrIndexMap[requiredDiscr[i]] = discrIndexMapStr[stringFromDiscriminator_.at(requiredDiscr[i])];
0869     }
0870     return discrIndexMap;
0871   }
0872 
0873   static void fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0874     edm::ParameterSetDescription desc;
0875     desc.add<edm::InputTag>("electrons", edm::InputTag("slimmedElectrons"));
0876     desc.add<edm::InputTag>("muons", edm::InputTag("slimmedMuons"));
0877     desc.add<edm::InputTag>("taus", edm::InputTag("slimmedTaus"));
0878     desc.add<edm::InputTag>("pfcands", edm::InputTag("packedPFCandidates"));
0879     desc.add<edm::InputTag>("vertices", edm::InputTag("offlineSlimmedPrimaryVertices"));
0880     desc.add<edm::InputTag>("rho", edm::InputTag("fixedGridRhoAll"));
0881     desc.add<std::vector<std::string>>("graph_file",
0882                                        {"RecoTauTag/TrainingFiles/data/DeepTauId/deepTau_2017v2p6_e6.pb"});
0883     desc.add<bool>("mem_mapped", false);
0884     desc.add<unsigned>("year", 2017);
0885     desc.add<unsigned>("version", 2);
0886     desc.add<unsigned>("sub_version", 1);
0887     desc.add<int>("debug_level", 0);
0888     desc.add<bool>("disable_dxy_pca", false);
0889     desc.add<bool>("disable_hcalFraction_workaround", false);
0890     desc.add<bool>("disable_CellIndex_workaround", false);
0891     desc.add<bool>("save_inputs", false);
0892     desc.add<bool>("is_online", false);
0893 
0894     desc.add<std::vector<std::string>>("VSeWP", {"-1."});
0895     desc.add<std::vector<std::string>>("VSmuWP", {"-1."});
0896     desc.add<std::vector<std::string>>("VSjetWP", {"-1."});
0897 
0898     desc.addUntracked<edm::InputTag>("basicTauDiscriminators", edm::InputTag("basicTauDiscriminators"));
0899     desc.addUntracked<edm::InputTag>("basicTauDiscriminatorsdR03", edm::InputTag("basicTauDiscriminatorsdR03"));
0900     desc.add<edm::InputTag>("pfTauTransverseImpactParameters", edm::InputTag("hpsPFTauTransverseImpactParameters"));
0901 
0902     {
0903       edm::ParameterSetDescription pset_Prediscriminants;
0904       pset_Prediscriminants.add<std::string>("BooleanOperator", "and");
0905       {
0906         edm::ParameterSetDescription psd1;
0907         psd1.add<double>("cut");
0908         psd1.add<edm::InputTag>("Producer");
0909         pset_Prediscriminants.addOptional<edm::ParameterSetDescription>("decayMode", psd1);
0910       }
0911       desc.add<edm::ParameterSetDescription>("Prediscriminants", pset_Prediscriminants);
0912     }
0913 
0914     descriptions.add("DeepTau", desc);
0915   }
0916 
0917 public:
0918   explicit DeepTauId(const edm::ParameterSet& cfg, const deep_tau::DeepTauCache* cache)
0919       : DeepTauBase(cfg, GetOutputs(), cache),
0920         electrons_token_(consumes<std::vector<pat::Electron>>(cfg.getParameter<edm::InputTag>("electrons"))),
0921         muons_token_(consumes<std::vector<pat::Muon>>(cfg.getParameter<edm::InputTag>("muons"))),
0922         rho_token_(consumes<double>(cfg.getParameter<edm::InputTag>("rho"))),
0923         basicTauDiscriminators_inputToken_(consumes<reco::TauDiscriminatorContainer>(
0924             cfg.getUntrackedParameter<edm::InputTag>("basicTauDiscriminators"))),
0925         basicTauDiscriminatorsdR03_inputToken_(consumes<reco::TauDiscriminatorContainer>(
0926             cfg.getUntrackedParameter<edm::InputTag>("basicTauDiscriminatorsdR03"))),
0927         pfTauTransverseImpactParameters_token_(
0928             consumes<edm::AssociationVector<reco::PFTauRefProd, std::vector<reco::PFTauTransverseImpactParameterRef>>>(
0929                 cfg.getParameter<edm::InputTag>("pfTauTransverseImpactParameters"))),
0930         year_(cfg.getParameter<unsigned>("year")),
0931         version_(cfg.getParameter<unsigned>("version")),
0932         sub_version_(cfg.getParameter<unsigned>("sub_version")),
0933         debug_level(cfg.getParameter<int>("debug_level")),
0934         disable_dxy_pca_(cfg.getParameter<bool>("disable_dxy_pca")),
0935         disable_hcalFraction_workaround_(cfg.getParameter<bool>("disable_hcalFraction_workaround")),
0936         disable_CellIndex_workaround_(cfg.getParameter<bool>("disable_CellIndex_workaround")),
0937         save_inputs_(cfg.getParameter<bool>("save_inputs")),
0938         json_file_(nullptr),
0939         file_counter_(0) {
0940     if (version_ == 2) {
0941       using namespace dnn_inputs_v2;
0942       namespace sc = deep_tau::Scaling;
0943       tauInputs_indices_.resize(TauBlockInputs::NumberOfInputs);
0944       std::iota(std::begin(tauInputs_indices_), std::end(tauInputs_indices_), 0);
0945 
0946       if (sub_version_ == 1) {
0947         tauBlockTensor_ = std::make_unique<tensorflow::Tensor>(
0948             tensorflow::DT_FLOAT, tensorflow::TensorShape{1, TauBlockInputs::NumberOfInputs});
0949         scalingParamsMap_ = &sc::scalingParamsMap_v2p1;
0950       } else if (sub_version_ == 5) {
0951         std::sort(TauBlockInputs::varsToDrop.begin(), TauBlockInputs::varsToDrop.end());
0952         for (auto v : TauBlockInputs::varsToDrop) {
0953           tauInputs_indices_.at(v) = -1;  // set index to -1
0954           for (std::size_t i = v + 1; i < TauBlockInputs::NumberOfInputs; ++i)
0955             tauInputs_indices_.at(i) -= 1;  // shift all the following indices by 1
0956         }
0957         tauBlockTensor_ = std::make_unique<tensorflow::Tensor>(
0958             tensorflow::DT_FLOAT,
0959             tensorflow::TensorShape{1,
0960                                     static_cast<int>(TauBlockInputs::NumberOfInputs) -
0961                                         static_cast<int>(TauBlockInputs::varsToDrop.size())});
0962         if (year_ == 2026) {
0963           scalingParamsMap_ = &sc::scalingParamsMap_PhaseIIv2p5;
0964         } else {
0965           scalingParamsMap_ = &sc::scalingParamsMap_v2p5;
0966         }
0967       } else
0968         throw cms::Exception("DeepTauId") << "subversion " << sub_version_ << " is not supported.";
0969 
0970       std::map<std::vector<bool>, std::vector<sc::FeatureT>> GridFeatureTypes_map = {
0971           {{false}, {sc::FeatureT::TauFlat, sc::FeatureT::GridGlobal}},  // feature types without inner/outer grid split
0972           {{false, true},
0973            {sc::FeatureT::PfCand_electron,
0974             sc::FeatureT::PfCand_muon,  // feature types with inner/outer grid split
0975             sc::FeatureT::PfCand_chHad,
0976             sc::FeatureT::PfCand_nHad,
0977             sc::FeatureT::PfCand_gamma,
0978             sc::FeatureT::Electron,
0979             sc::FeatureT::Muon}}};
0980 
0981       // check that sizes of mean/std/lim_min/lim_max vectors are equal between each other
0982       for (const auto& p : GridFeatureTypes_map) {
0983         for (auto is_inner : p.first) {
0984           for (auto featureType : p.second) {
0985             const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(featureType, is_inner));
0986             if (!(sp.mean_.size() == sp.std_.size() && sp.mean_.size() == sp.lim_min_.size() &&
0987                   sp.mean_.size() == sp.lim_max_.size()))
0988               throw cms::Exception("DeepTauId") << "sizes of scaling parameter vectors do not match between each other";
0989           }
0990         }
0991       }
0992 
0993       for (size_t n = 0; n < 2; ++n) {
0994         const bool is_inner = n == 0;
0995         const auto n_cells = is_inner ? number_of_inner_cell : number_of_outer_cell;
0996         eGammaTensor_[is_inner] = std::make_unique<tensorflow::Tensor>(
0997             tensorflow::DT_FLOAT, tensorflow::TensorShape{1, 1, 1, EgammaBlockInputs::NumberOfInputs});
0998         muonTensor_[is_inner] = std::make_unique<tensorflow::Tensor>(
0999             tensorflow::DT_FLOAT, tensorflow::TensorShape{1, 1, 1, MuonBlockInputs::NumberOfInputs});
1000         hadronsTensor_[is_inner] = std::make_unique<tensorflow::Tensor>(
1001             tensorflow::DT_FLOAT, tensorflow::TensorShape{1, 1, 1, HadronBlockInputs::NumberOfInputs});
1002         convTensor_[is_inner] = std::make_unique<tensorflow::Tensor>(
1003             tensorflow::DT_FLOAT, tensorflow::TensorShape{1, n_cells, n_cells, number_of_conv_features});
1004         zeroOutputTensor_[is_inner] = std::make_unique<tensorflow::Tensor>(
1005             tensorflow::DT_FLOAT, tensorflow::TensorShape{1, 1, 1, number_of_conv_features});
1006 
1007         eGammaTensor_[is_inner]->flat<float>().setZero();
1008         muonTensor_[is_inner]->flat<float>().setZero();
1009         hadronsTensor_[is_inner]->flat<float>().setZero();
1010 
1011         setCellConvFeatures(*zeroOutputTensor_[is_inner], getPartialPredictions(is_inner), 0, 0, 0);
1012       }
1013     } else {
1014       throw cms::Exception("DeepTauId") << "version " << version_ << " is not supported.";
1015     }
1016   }
1017 
1018   static std::unique_ptr<deep_tau::DeepTauCache> initializeGlobalCache(const edm::ParameterSet& cfg) {
1019     return DeepTauBase::initializeGlobalCache(cfg);
1020   }
1021 
1022   static void globalEndJob(const deep_tau::DeepTauCache* cache_) { return DeepTauBase::globalEndJob(cache_); }
1023 
1024 private:
1025   static constexpr float pi = M_PI;
1026 
1027   template <typename T>
1028   static float getValue(T value) {
1029     return std::isnormal(value) ? static_cast<float>(value) : 0.f;
1030   }
1031 
1032   template <typename T>
1033   static float getValueLinear(T value, float min_value, float max_value, bool positive) {
1034     const float fixed_value = getValue(value);
1035     const float clamped_value = std::clamp(fixed_value, min_value, max_value);
1036     float transformed_value = (clamped_value - min_value) / (max_value - min_value);
1037     if (!positive)
1038       transformed_value = transformed_value * 2 - 1;
1039     return transformed_value;
1040   }
1041 
1042   template <typename T>
1043   static float getValueNorm(T value, float mean, float sigma, float n_sigmas_max = 5) {
1044     const float fixed_value = getValue(value);
1045     const float norm_value = (fixed_value - mean) / sigma;
1046     return std::clamp(norm_value, -n_sigmas_max, n_sigmas_max);
1047   }
1048 
1049   static bool isAbove(double value, double min) { return std::isnormal(value) && value > min; }
1050 
1051   static bool calculateElectronClusterVarsV2(const pat::Electron& ele,
1052                                              float& cc_ele_energy,
1053                                              float& cc_gamma_energy,
1054                                              int& cc_n_gamma) {
1055     cc_ele_energy = cc_gamma_energy = 0;
1056     cc_n_gamma = 0;
1057     const auto& superCluster = ele.superCluster();
1058     if (superCluster.isNonnull() && superCluster.isAvailable() && superCluster->clusters().isNonnull() &&
1059         superCluster->clusters().isAvailable()) {
1060       for (auto iter = superCluster->clustersBegin(); iter != superCluster->clustersEnd(); ++iter) {
1061         const float energy = static_cast<float>((*iter)->energy());
1062         if (iter == superCluster->clustersBegin())
1063           cc_ele_energy += energy;
1064         else {
1065           cc_gamma_energy += energy;
1066           ++cc_n_gamma;
1067         }
1068       }
1069       return true;
1070     } else
1071       return false;
1072   }
1073 
1074   inline void checkInputs(const tensorflow::Tensor& inputs,
1075                           const std::string& block_name,
1076                           int n_inputs,
1077                           const CellGrid* grid = nullptr) const {
1078     if (debug_level >= 1) {
1079       std::cout << "<checkInputs>: block_name = " << block_name << std::endl;
1080       if (block_name == "input_tau") {
1081         for (int input_index = 0; input_index < n_inputs; ++input_index) {
1082           float input = inputs.matrix<float>()(0, input_index);
1083           if (edm::isNotFinite(input)) {
1084             throw cms::Exception("DeepTauId")
1085                 << "in the " << block_name
1086                 << ", input is not finite, i.e. infinite or NaN, for input_index = " << input_index;
1087           }
1088           if (debug_level >= 2) {
1089             std::cout << block_name << "[var = " << input_index << "] = " << std::setprecision(5) << std::fixed << input
1090                       << std::endl;
1091           }
1092         }
1093       } else {
1094         assert(grid);
1095         int n_eta, n_phi;
1096         if (block_name.find("input_inner") != std::string::npos) {
1097           n_eta = 5;
1098           n_phi = 5;
1099         } else if (block_name.find("input_outer") != std::string::npos) {
1100           n_eta = 10;
1101           n_phi = 10;
1102         } else
1103           assert(0);
1104         int eta_phi_index = 0;
1105         for (int eta = -n_eta; eta <= n_eta; ++eta) {
1106           for (int phi = -n_phi; phi <= n_phi; ++phi) {
1107             const CellIndex cell_index{eta, phi};
1108             const auto cell_iter = grid->find(cell_index);
1109             if (cell_iter != grid->end()) {
1110               for (int input_index = 0; input_index < n_inputs; ++input_index) {
1111                 float input = inputs.tensor<float, 4>()(eta_phi_index, 0, 0, input_index);
1112                 if (edm::isNotFinite(input)) {
1113                   throw cms::Exception("DeepTauId")
1114                       << "in the " << block_name << ", input is not finite, i.e. infinite or NaN, for eta = " << eta
1115                       << ", phi = " << phi << ", input_index = " << input_index;
1116                 }
1117                 if (debug_level >= 2) {
1118                   std::cout << block_name << "[eta = " << eta << "][phi = " << phi << "][var = " << input_index
1119                             << "] = " << std::setprecision(5) << std::fixed << input << std::endl;
1120                 }
1121               }
1122               eta_phi_index += 1;
1123             }
1124           }
1125         }
1126       }
1127     }
1128   }
1129 
1130   inline void saveInputs(const tensorflow::Tensor& inputs,
1131                          const std::string& block_name,
1132                          int n_inputs,
1133                          const CellGrid* grid = nullptr) {
1134     if (debug_level >= 1) {
1135       std::cout << "<saveInputs>: block_name = " << block_name << std::endl;
1136     }
1137     if (!is_first_block_)
1138       (*json_file_) << ", ";
1139     (*json_file_) << "\"" << block_name << "\": [";
1140     if (block_name == "input_tau") {
1141       for (int input_index = 0; input_index < n_inputs; ++input_index) {
1142         float input = inputs.matrix<float>()(0, input_index);
1143         if (input_index != 0)
1144           (*json_file_) << ", ";
1145         (*json_file_) << input;
1146       }
1147     } else {
1148       assert(grid);
1149       int n_eta, n_phi;
1150       if (block_name.find("input_inner") != std::string::npos) {
1151         n_eta = 5;
1152         n_phi = 5;
1153       } else if (block_name.find("input_outer") != std::string::npos) {
1154         n_eta = 10;
1155         n_phi = 10;
1156       } else
1157         assert(0);
1158       int eta_phi_index = 0;
1159       for (int eta = -n_eta; eta <= n_eta; ++eta) {
1160         if (eta != -n_eta)
1161           (*json_file_) << ", ";
1162         (*json_file_) << "[";
1163         for (int phi = -n_phi; phi <= n_phi; ++phi) {
1164           if (phi != -n_phi)
1165             (*json_file_) << ", ";
1166           (*json_file_) << "[";
1167           const CellIndex cell_index{eta, phi};
1168           const auto cell_iter = grid->find(cell_index);
1169           for (int input_index = 0; input_index < n_inputs; ++input_index) {
1170             float input = 0.;
1171             if (cell_iter != grid->end()) {
1172               input = inputs.tensor<float, 4>()(eta_phi_index, 0, 0, input_index);
1173             }
1174             if (input_index != 0)
1175               (*json_file_) << ", ";
1176             (*json_file_) << input;
1177           }
1178           if (cell_iter != grid->end()) {
1179             eta_phi_index += 1;
1180           }
1181           (*json_file_) << "]";
1182         }
1183         (*json_file_) << "]";
1184       }
1185     }
1186     (*json_file_) << "]";
1187     is_first_block_ = false;
1188   }
1189 
1190 private:
1191   tensorflow::Tensor getPredictions(edm::Event& event, edm::Handle<TauCollection> taus) override {
1192     // Empty dummy vectors
1193     const std::vector<pat::Electron> electron_collection_default;
1194     const std::vector<pat::Muon> muon_collection_default;
1195     const reco::TauDiscriminatorContainer basicTauDiscriminators_default;
1196     const reco::TauDiscriminatorContainer basicTauDiscriminatorsdR03_default;
1197     const edm::AssociationVector<reco::PFTauRefProd, std::vector<reco::PFTauTransverseImpactParameterRef>>
1198         pfTauTransverseImpactParameters_default;
1199 
1200     const std::vector<pat::Electron>* electron_collection;
1201     const std::vector<pat::Muon>* muon_collection;
1202     const reco::TauDiscriminatorContainer* basicTauDiscriminators;
1203     const reco::TauDiscriminatorContainer* basicTauDiscriminatorsdR03;
1204     const edm::AssociationVector<reco::PFTauRefProd, std::vector<reco::PFTauTransverseImpactParameterRef>>*
1205         pfTauTransverseImpactParameters;
1206 
1207     if (!is_online_) {
1208       electron_collection = &event.get(electrons_token_);
1209       muon_collection = &event.get(muons_token_);
1210       pfTauTransverseImpactParameters = &pfTauTransverseImpactParameters_default;
1211       basicTauDiscriminators = &basicTauDiscriminators_default;
1212       basicTauDiscriminatorsdR03 = &basicTauDiscriminatorsdR03_default;
1213     } else {
1214       electron_collection = &electron_collection_default;
1215       muon_collection = &muon_collection_default;
1216       pfTauTransverseImpactParameters = &event.get(pfTauTransverseImpactParameters_token_);
1217       basicTauDiscriminators = &event.get(basicTauDiscriminators_inputToken_);
1218       basicTauDiscriminatorsdR03 = &event.get(basicTauDiscriminatorsdR03_inputToken_);
1219 
1220       // Get indices for discriminators
1221       if (!discrIndicesMapped_) {
1222         basicDiscrIndexMap_ =
1223             matchDiscriminatorIndices(event, basicTauDiscriminators_inputToken_, requiredBasicDiscriminators_);
1224         basicDiscrdR03IndexMap_ =
1225             matchDiscriminatorIndices(event, basicTauDiscriminatorsdR03_inputToken_, requiredBasicDiscriminatorsdR03_);
1226         discrIndicesMapped_ = true;
1227       }
1228     }
1229 
1230     TauFunc tauIDs = {basicTauDiscriminators,
1231                       basicTauDiscriminatorsdR03,
1232                       pfTauTransverseImpactParameters,
1233                       basicDiscrIndexMap_,
1234                       basicDiscrdR03IndexMap_};
1235 
1236     edm::Handle<edm::View<reco::Candidate>> pfCands;
1237     event.getByToken(pfcandToken_, pfCands);
1238 
1239     edm::Handle<reco::VertexCollection> vertices;
1240     event.getByToken(vtxToken_, vertices);
1241 
1242     edm::Handle<double> rho;
1243     event.getByToken(rho_token_, rho);
1244 
1245     auto const& eventnr = event.id().event();
1246 
1247     tensorflow::Tensor predictions(tensorflow::DT_FLOAT, {static_cast<int>(taus->size()), deep_tau::NumberOfOutputs});
1248 
1249     for (size_t tau_index = 0; tau_index < taus->size(); ++tau_index) {
1250       const edm::RefToBase<reco::BaseTau> tauRef = taus->refAt(tau_index);
1251 
1252       std::vector<tensorflow::Tensor> pred_vector;
1253 
1254       bool passesPrediscriminants;
1255       if (is_online_) {
1256         passesPrediscriminants = tauIDs.passPrediscriminants<std::vector<TauDiscInfo<reco::PFTauDiscriminator>>>(
1257             recoPrediscriminants_, andPrediscriminants_, tauRef);
1258       } else {
1259         passesPrediscriminants = tauIDs.passPrediscriminants<std::vector<TauDiscInfo<pat::PATTauDiscriminator>>>(
1260             patPrediscriminants_, andPrediscriminants_, tauRef);
1261       }
1262 
1263       if (passesPrediscriminants) {
1264         if (version_ == 2) {
1265           if (is_online_) {
1266             getPredictionsV2<reco::PFCandidate, reco::PFTau>(taus->at(tau_index),
1267                                                              tau_index,
1268                                                              tauRef,
1269                                                              electron_collection,
1270                                                              muon_collection,
1271                                                              *pfCands,
1272                                                              vertices->at(0),
1273                                                              *rho,
1274                                                              eventnr,
1275                                                              pred_vector,
1276                                                              tauIDs);
1277           } else
1278             getPredictionsV2<pat::PackedCandidate, pat::Tau>(taus->at(tau_index),
1279                                                              tau_index,
1280                                                              tauRef,
1281                                                              electron_collection,
1282                                                              muon_collection,
1283                                                              *pfCands,
1284                                                              vertices->at(0),
1285                                                              *rho,
1286                                                              eventnr,
1287                                                              pred_vector,
1288                                                              tauIDs);
1289         } else {
1290           throw cms::Exception("DeepTauId") << "version " << version_ << " is not supported.";
1291         }
1292 
1293         for (int k = 0; k < deep_tau::NumberOfOutputs; ++k) {
1294           const float pred = pred_vector[0].flat<float>()(k);
1295           if (!(pred >= 0 && pred <= 1))
1296             throw cms::Exception("DeepTauId")
1297                 << "invalid prediction = " << pred << " for tau_index = " << tau_index << ", pred_index = " << k;
1298           predictions.matrix<float>()(tau_index, k) = pred;
1299         }
1300       } else {
1301         // This else statement was added as a part of the DeepTau@HLT development. It does not affect the current state
1302         // of offline DeepTauId code as there the preselection is not used (it was added in the DeepTau@HLT). It returns
1303         // default values for deepTau score if the preselection failed. Before this statement the values given for this tau
1304         // were random. k == 2 corresponds to the tau score and all other k values to e, mu and jets. By defining in this way
1305         // the final score is -1.
1306         for (int k = 0; k < deep_tau::NumberOfOutputs; ++k) {
1307           predictions.matrix<float>()(tau_index, k) = (k == 2) ? -1.f : 2.f;
1308         }
1309       }
1310     }
1311     return predictions;
1312   }
1313 
1314   template <typename CandidateCastType, typename TauCastType>
1315   void getPredictionsV2(TauCollection::const_reference& tau,
1316                         const size_t tau_index,
1317                         const edm::RefToBase<reco::BaseTau> tau_ref,
1318                         const std::vector<pat::Electron>* electrons,
1319                         const std::vector<pat::Muon>* muons,
1320                         const edm::View<reco::Candidate>& pfCands,
1321                         const reco::Vertex& pv,
1322                         double rho,
1323                         const edm::EventNumber_t& eventnr,
1324                         std::vector<tensorflow::Tensor>& pred_vector,
1325                         TauFunc tau_funcs) {
1326     using namespace dnn_inputs_v2;
1327     if (debug_level >= 2) {
1328       std::cout << "<DeepTauId::getPredictionsV2 (moduleLabel = " << moduleDescription().moduleLabel()
1329                 << ")>:" << std::endl;
1330       std::cout << " tau: pT = " << tau.pt() << ", eta = " << tau.eta() << ", phi = " << tau.phi()
1331                 << ", eventnr = " << eventnr << std::endl;
1332     }
1333     CellGrid inner_grid(number_of_inner_cell, number_of_inner_cell, 0.02, 0.02, disable_CellIndex_workaround_);
1334     CellGrid outer_grid(number_of_outer_cell, number_of_outer_cell, 0.05, 0.05, disable_CellIndex_workaround_);
1335     fillGrids(dynamic_cast<const TauCastType&>(tau), *electrons, inner_grid, outer_grid);
1336     fillGrids(dynamic_cast<const TauCastType&>(tau), *muons, inner_grid, outer_grid);
1337     fillGrids(dynamic_cast<const TauCastType&>(tau), pfCands, inner_grid, outer_grid);
1338 
1339     createTauBlockInputs<CandidateCastType>(
1340         dynamic_cast<const TauCastType&>(tau), tau_index, tau_ref, pv, rho, tau_funcs);
1341     checkInputs(*tauBlockTensor_, "input_tau", static_cast<int>(tauBlockTensor_->shape().dim_size(1)));
1342     createConvFeatures<CandidateCastType>(dynamic_cast<const TauCastType&>(tau),
1343                                           tau_index,
1344                                           tau_ref,
1345                                           pv,
1346                                           rho,
1347                                           electrons,
1348                                           muons,
1349                                           pfCands,
1350                                           inner_grid,
1351                                           tau_funcs,
1352                                           true);
1353     checkInputs(*eGammaTensor_[true], "input_inner_egamma", EgammaBlockInputs::NumberOfInputs, &inner_grid);
1354     checkInputs(*muonTensor_[true], "input_inner_muon", MuonBlockInputs::NumberOfInputs, &inner_grid);
1355     checkInputs(*hadronsTensor_[true], "input_inner_hadrons", HadronBlockInputs::NumberOfInputs, &inner_grid);
1356     createConvFeatures<CandidateCastType>(dynamic_cast<const TauCastType&>(tau),
1357                                           tau_index,
1358                                           tau_ref,
1359                                           pv,
1360                                           rho,
1361                                           electrons,
1362                                           muons,
1363                                           pfCands,
1364                                           outer_grid,
1365                                           tau_funcs,
1366                                           false);
1367     checkInputs(*eGammaTensor_[false], "input_outer_egamma", EgammaBlockInputs::NumberOfInputs, &outer_grid);
1368     checkInputs(*muonTensor_[false], "input_outer_muon", MuonBlockInputs::NumberOfInputs, &outer_grid);
1369     checkInputs(*hadronsTensor_[false], "input_outer_hadrons", HadronBlockInputs::NumberOfInputs, &outer_grid);
1370 
1371     if (save_inputs_) {
1372       std::string json_file_name = "DeepTauId_" + std::to_string(eventnr) + "_" + std::to_string(tau_index) + ".json";
1373       json_file_ = new std::ofstream(json_file_name.data());
1374       is_first_block_ = true;
1375       (*json_file_) << "{";
1376       saveInputs(*tauBlockTensor_, "input_tau", static_cast<int>(tauBlockTensor_->shape().dim_size(1)));
1377       saveInputs(
1378           *eGammaTensor_[true], "input_inner_egamma", dnn_inputs_v2::EgammaBlockInputs::NumberOfInputs, &inner_grid);
1379       saveInputs(*muonTensor_[true], "input_inner_muon", dnn_inputs_v2::MuonBlockInputs::NumberOfInputs, &inner_grid);
1380       saveInputs(
1381           *hadronsTensor_[true], "input_inner_hadrons", dnn_inputs_v2::HadronBlockInputs::NumberOfInputs, &inner_grid);
1382       saveInputs(
1383           *eGammaTensor_[false], "input_outer_egamma", dnn_inputs_v2::EgammaBlockInputs::NumberOfInputs, &outer_grid);
1384       saveInputs(*muonTensor_[false], "input_outer_muon", dnn_inputs_v2::MuonBlockInputs::NumberOfInputs, &outer_grid);
1385       saveInputs(
1386           *hadronsTensor_[false], "input_outer_hadrons", dnn_inputs_v2::HadronBlockInputs::NumberOfInputs, &outer_grid);
1387       (*json_file_) << "}";
1388       delete json_file_;
1389       ++file_counter_;
1390     }
1391 
1392     tensorflow::run(&(cache_->getSession("core")),
1393                     {{"input_tau", *tauBlockTensor_},
1394                      {"input_inner", *convTensor_.at(true)},
1395                      {"input_outer", *convTensor_.at(false)}},
1396                     {"main_output/Softmax"},
1397                     &pred_vector);
1398     if (debug_level >= 1) {
1399       std::cout << "output = { ";
1400       for (int idx = 0; idx < deep_tau::NumberOfOutputs; ++idx) {
1401         if (idx > 0)
1402           std::cout << ", ";
1403         std::string label;
1404         if (idx == 0)
1405           label = "e";
1406         else if (idx == 1)
1407           label = "mu";
1408         else if (idx == 2)
1409           label = "tau";
1410         else if (idx == 3)
1411           label = "jet";
1412         else
1413           assert(0);
1414         std::cout << label << " = " << pred_vector[0].flat<float>()(idx);
1415       }
1416       std::cout << " }" << std::endl;
1417     }
1418   }
1419 
1420   template <typename Collection, typename TauCastType>
1421   void fillGrids(const TauCastType& tau, const Collection& objects, CellGrid& inner_grid, CellGrid& outer_grid) {
1422     static constexpr double outer_dR2 = 0.25;  //0.5^2
1423     const double inner_radius = getInnerSignalConeRadius(tau.polarP4().pt());
1424     const double inner_dR2 = std::pow(inner_radius, 2);
1425 
1426     const auto addObject = [&](size_t n, double deta, double dphi, CellGrid& grid) {
1427       const auto& obj = objects.at(n);
1428       const CellObjectType obj_type = GetCellObjectType(obj);
1429       if (obj_type == CellObjectType::Other)
1430         return;
1431       CellIndex cell_index;
1432       if (grid.tryGetCellIndex(deta, dphi, cell_index)) {
1433         Cell& cell = grid[cell_index];
1434         auto iter = cell.find(obj_type);
1435         if (iter != cell.end()) {
1436           const auto& prev_obj = objects.at(iter->second);
1437           if (obj.polarP4().pt() > prev_obj.polarP4().pt())
1438             iter->second = n;
1439         } else {
1440           cell[obj_type] = n;
1441         }
1442       }
1443     };
1444 
1445     for (size_t n = 0; n < objects.size(); ++n) {
1446       const auto& obj = objects.at(n);
1447       const double deta = obj.polarP4().eta() - tau.polarP4().eta();
1448       const double dphi = reco::deltaPhi(obj.polarP4().phi(), tau.polarP4().phi());
1449       const double dR2 = std::pow(deta, 2) + std::pow(dphi, 2);
1450       if (dR2 < inner_dR2)
1451         addObject(n, deta, dphi, inner_grid);
1452       if (dR2 < outer_dR2)
1453         addObject(n, deta, dphi, outer_grid);
1454     }
1455   }
1456 
1457   tensorflow::Tensor getPartialPredictions(bool is_inner) {
1458     std::vector<tensorflow::Tensor> pred_vector;
1459     if (is_inner) {
1460       tensorflow::run(&(cache_->getSession("inner")),
1461                       {
1462                           {"input_inner_egamma", *eGammaTensor_.at(is_inner)},
1463                           {"input_inner_muon", *muonTensor_.at(is_inner)},
1464                           {"input_inner_hadrons", *hadronsTensor_.at(is_inner)},
1465                       },
1466                       {"inner_all_dropout_4/Identity"},
1467                       &pred_vector);
1468     } else {
1469       tensorflow::run(&(cache_->getSession("outer")),
1470                       {
1471                           {"input_outer_egamma", *eGammaTensor_.at(is_inner)},
1472                           {"input_outer_muon", *muonTensor_.at(is_inner)},
1473                           {"input_outer_hadrons", *hadronsTensor_.at(is_inner)},
1474                       },
1475                       {"outer_all_dropout_4/Identity"},
1476                       &pred_vector);
1477     }
1478     return pred_vector.at(0);
1479   }
1480 
1481   template <typename CandidateCastType, typename TauCastType>
1482   void createConvFeatures(const TauCastType& tau,
1483                           const size_t tau_index,
1484                           const edm::RefToBase<reco::BaseTau> tau_ref,
1485                           const reco::Vertex& pv,
1486                           double rho,
1487                           const std::vector<pat::Electron>* electrons,
1488                           const std::vector<pat::Muon>* muons,
1489                           const edm::View<reco::Candidate>& pfCands,
1490                           const CellGrid& grid,
1491                           TauFunc tau_funcs,
1492                           bool is_inner) {
1493     if (debug_level >= 2) {
1494       std::cout << "<DeepTauId::createConvFeatures (is_inner = " << is_inner << ")>:" << std::endl;
1495     }
1496     tensorflow::Tensor& convTensor = *convTensor_.at(is_inner);
1497     eGammaTensor_[is_inner] = std::make_unique<tensorflow::Tensor>(
1498         tensorflow::DT_FLOAT,
1499         tensorflow::TensorShape{
1500             (long long int)grid.num_valid_cells(), 1, 1, dnn_inputs_v2::EgammaBlockInputs::NumberOfInputs});
1501     muonTensor_[is_inner] = std::make_unique<tensorflow::Tensor>(
1502         tensorflow::DT_FLOAT,
1503         tensorflow::TensorShape{
1504             (long long int)grid.num_valid_cells(), 1, 1, dnn_inputs_v2::MuonBlockInputs::NumberOfInputs});
1505     hadronsTensor_[is_inner] = std::make_unique<tensorflow::Tensor>(
1506         tensorflow::DT_FLOAT,
1507         tensorflow::TensorShape{
1508             (long long int)grid.num_valid_cells(), 1, 1, dnn_inputs_v2::HadronBlockInputs::NumberOfInputs});
1509 
1510     eGammaTensor_[is_inner]->flat<float>().setZero();
1511     muonTensor_[is_inner]->flat<float>().setZero();
1512     hadronsTensor_[is_inner]->flat<float>().setZero();
1513 
1514     unsigned idx = 0;
1515     for (int eta = -grid.maxEtaIndex(); eta <= grid.maxEtaIndex(); ++eta) {
1516       for (int phi = -grid.maxPhiIndex(); phi <= grid.maxPhiIndex(); ++phi) {
1517         if (debug_level >= 2) {
1518           std::cout << "processing ( eta = " << eta << ", phi = " << phi << " )" << std::endl;
1519         }
1520         const CellIndex cell_index{eta, phi};
1521         const auto cell_iter = grid.find(cell_index);
1522         if (cell_iter != grid.end()) {
1523           if (debug_level >= 2) {
1524             std::cout << " creating inputs for ( eta = " << eta << ", phi = " << phi << " ): idx = " << idx
1525                       << std::endl;
1526           }
1527           const Cell& cell = cell_iter->second;
1528           createEgammaBlockInputs<CandidateCastType>(
1529               idx, tau, tau_index, tau_ref, pv, rho, electrons, pfCands, cell, tau_funcs, is_inner);
1530           createMuonBlockInputs<CandidateCastType>(
1531               idx, tau, tau_index, tau_ref, pv, rho, muons, pfCands, cell, tau_funcs, is_inner);
1532           createHadronsBlockInputs<CandidateCastType>(
1533               idx, tau, tau_index, tau_ref, pv, rho, pfCands, cell, tau_funcs, is_inner);
1534           idx += 1;
1535         } else {
1536           if (debug_level >= 2) {
1537             std::cout << " skipping creation of inputs, because ( eta = " << eta << ", phi = " << phi
1538                       << " ) is not in the grid !!" << std::endl;
1539           }
1540         }
1541       }
1542     }
1543 
1544     const auto predTensor = getPartialPredictions(is_inner);
1545     idx = 0;
1546     for (int eta = -grid.maxEtaIndex(); eta <= grid.maxEtaIndex(); ++eta) {
1547       for (int phi = -grid.maxPhiIndex(); phi <= grid.maxPhiIndex(); ++phi) {
1548         const CellIndex cell_index{eta, phi};
1549         const int eta_index = grid.getEtaTensorIndex(cell_index);
1550         const int phi_index = grid.getPhiTensorIndex(cell_index);
1551 
1552         const auto cell_iter = grid.find(cell_index);
1553         if (cell_iter != grid.end()) {
1554           setCellConvFeatures(convTensor, predTensor, idx, eta_index, phi_index);
1555           idx += 1;
1556         } else {
1557           setCellConvFeatures(convTensor, *zeroOutputTensor_[is_inner], 0, eta_index, phi_index);
1558         }
1559       }
1560     }
1561   }
1562 
1563   void setCellConvFeatures(tensorflow::Tensor& convTensor,
1564                            const tensorflow::Tensor& features,
1565                            unsigned batch_idx,
1566                            int eta_index,
1567                            int phi_index) {
1568     for (int n = 0; n < dnn_inputs_v2::number_of_conv_features; ++n) {
1569       convTensor.tensor<float, 4>()(0, eta_index, phi_index, n) = features.tensor<float, 4>()(batch_idx, 0, 0, n);
1570     }
1571   }
1572 
/// Fills the tau-level input block (*tauBlockTensor_, row 0) of the v2 network.
///
/// The tensor is zeroed first, so every feature not explicitly set below stays 0. That covers
/// features whose validity check fails and features that do not exist for the current
/// sub_version_. Each feature `var` is written to column tauInputs_indices_.at(var). It is
/// standardised with the TauFlat scaling parameters at that same remapped index.
/// Several features differ between sub_version_ 1 (v2p1) and sub_version_ 5 (v2p5).
///
/// @param tau       tau whose features are computed.
/// @param tau_index index of the tau, forwarded to the TauFunc accessors.
/// @param tau_ref   reference to the tau, forwarded to the TauFunc accessors.
/// @param pv        primary vertex (not referenced in this function's body).
/// @param rho       event energy density.
/// @param tau_funcs accessor helpers for tau-level quantities.
template <typename CandidateCastType, typename TauCastType>
void createTauBlockInputs(const TauCastType& tau,
                          const size_t& tau_index,
                          const edm::RefToBase<reco::BaseTau> tau_ref,
                          const reco::Vertex& pv,
                          double rho,
                          TauFunc tau_funcs) {
  namespace dnn = dnn_inputs_v2::TauBlockInputs;
  namespace sc = deep_tau::Scaling;
  sc::FeatureT ft = sc::FeatureT::TauFlat;
  const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft, false));

  tensorflow::Tensor& inputs = *tauBlockTensor_;
  inputs.flat<float>().setZero();

  // Returns a writable reference to the tensor column that holds feature `var_index`.
  const auto& get = [&](int var_index) -> float& {
    return inputs.matrix<float>()(0, tauInputs_indices_.at(var_index));
  };

  // nullptr when the tau has no leading charged hadron of type CandidateCastType;
  // in that case the tau_dz* features below are left at 0.
  auto leadChargedHadrCand = dynamic_cast<const CandidateCastType*>(tau.leadChargedHadrCand().get());

  get(dnn::rho) = sp.scale(rho, tauInputs_indices_[dnn::rho]);
  get(dnn::tau_pt) = sp.scale(tau.polarP4().pt(), tauInputs_indices_[dnn::tau_pt]);
  get(dnn::tau_eta) = sp.scale(tau.polarP4().eta(), tauInputs_indices_[dnn::tau_eta]);
  // phi is an input only for sub_version_ 1; it uses getValueLinear over [-pi, pi]
  // instead of the ScalingParams standardisation.
  if (sub_version_ == 1) {
    get(dnn::tau_phi) = getValueLinear(tau.polarP4().phi(), -pi, pi, false);
  }
  get(dnn::tau_mass) = sp.scale(tau.polarP4().mass(), tauInputs_indices_[dnn::tau_mass]);
  get(dnn::tau_E_over_pt) = sp.scale(tau.p4().energy() / tau.p4().pt(), tauInputs_indices_[dnn::tau_E_over_pt]);
  get(dnn::tau_charge) = sp.scale(tau.charge(), tauInputs_indices_[dnn::tau_charge]);
  // Prong multiplicities are derived from the decay-mode code:
  // decayMode / 5 + 1 charged prongs and decayMode % 5 neutral prongs.
  get(dnn::tau_n_charged_prongs) = sp.scale(tau.decayMode() / 5 + 1, tauInputs_indices_[dnn::tau_n_charged_prongs]);
  get(dnn::tau_n_neutral_prongs) = sp.scale(tau.decayMode() % 5, tauInputs_indices_[dnn::tau_n_neutral_prongs]);
  get(dnn::chargedIsoPtSum) =
      sp.scale(tau_funcs.getChargedIsoPtSum(tau, tau_ref), tauInputs_indices_[dnn::chargedIsoPtSum]);
  // NOTE(review): this and the neutral-isolation ratios below divide by an isolation sum
  // without a zero guard. The caller (getPredictionsV2) runs checkInputs on this tensor
  // afterwards; presumably that is where non-finite values are caught — TODO confirm.
  get(dnn::chargedIsoPtSumdR03_over_dR05) =
      sp.scale(tau_funcs.getChargedIsoPtSumdR03(tau, tau_ref) / tau_funcs.getChargedIsoPtSum(tau, tau_ref),
               tauInputs_indices_[dnn::chargedIsoPtSumdR03_over_dR05]);
  // sub_version_ 1 uses getFootprintCorrectiondR03, sub_version_ 5 uses getFootprintCorrection;
  // for any other sub-version the feature stays 0.
  if (sub_version_ == 1)
    get(dnn::footprintCorrection) =
        sp.scale(tau_funcs.getFootprintCorrectiondR03(tau, tau_ref), tauInputs_indices_[dnn::footprintCorrection]);
  else if (sub_version_ == 5)
    get(dnn::footprintCorrection) =
        sp.scale(tau_funcs.getFootprintCorrection(tau, tau_ref), tauInputs_indices_[dnn::footprintCorrection]);

  get(dnn::neutralIsoPtSum) =
      sp.scale(tau_funcs.getNeutralIsoPtSum(tau, tau_ref), tauInputs_indices_[dnn::neutralIsoPtSum]);
  get(dnn::neutralIsoPtSumWeight_over_neutralIsoPtSum) =
      sp.scale(tau_funcs.getNeutralIsoPtSumWeight(tau, tau_ref) / tau_funcs.getNeutralIsoPtSum(tau, tau_ref),
               tauInputs_indices_[dnn::neutralIsoPtSumWeight_over_neutralIsoPtSum]);
  get(dnn::neutralIsoPtSumWeightdR03_over_neutralIsoPtSum) =
      sp.scale(tau_funcs.getNeutralIsoPtSumdR03Weight(tau, tau_ref) / tau_funcs.getNeutralIsoPtSum(tau, tau_ref),
               tauInputs_indices_[dnn::neutralIsoPtSumWeightdR03_over_neutralIsoPtSum]);
  get(dnn::neutralIsoPtSumdR03_over_dR05) =
      sp.scale(tau_funcs.getNeutralIsoPtSumdR03(tau, tau_ref) / tau_funcs.getNeutralIsoPtSum(tau, tau_ref),
               tauInputs_indices_[dnn::neutralIsoPtSumdR03_over_dR05]);
  get(dnn::photonPtSumOutsideSignalCone) = sp.scale(tau_funcs.getPhotonPtSumOutsideSignalCone(tau, tau_ref),
                                                    tauInputs_indices_[dnn::photonPtSumOutsideSignalCone]);
  get(dnn::puCorrPtSum) = sp.scale(tau_funcs.getPuCorrPtSum(tau, tau_ref), tauInputs_indices_[dnn::puCorrPtSum]);
  // The global PCA coordinates were used as inputs during the NN training, but it was decided to disable
  // them for inference: the MC modelling of dxy_PCA describes the data poorly, and the x and y coordinates
  // in data fall outside the expected 5 std. dev. input validity range. In addition,
  // these coordinates are strongly era-dependent. They are therefore filled only for sub_version_ 1 when
  // disable_dxy_pca_ is false; otherwise they are 0 (explicitly for sub_version_ 1, via the initial zeroing
  // for other sub-versions).
  if (sub_version_ == 1) {
    if (!disable_dxy_pca_) {
      auto const pca = tau_funcs.getdxyPCA(tau, tau_index);
      get(dnn::tau_dxy_pca_x) = sp.scale(pca.x(), tauInputs_indices_[dnn::tau_dxy_pca_x]);
      get(dnn::tau_dxy_pca_y) = sp.scale(pca.y(), tauInputs_indices_[dnn::tau_dxy_pca_y]);
      get(dnn::tau_dxy_pca_z) = sp.scale(pca.z(), tauInputs_indices_[dnn::tau_dxy_pca_z]);
    } else {
      get(dnn::tau_dxy_pca_x) = 0;
      get(dnn::tau_dxy_pca_y) = 0;
      get(dnn::tau_dxy_pca_z) = 0;
    }
  }

  // Impact-parameter features are filled only when isAbove(value, -10) and isAbove(error, 0) hold.
  // NOTE(review): when invalid, the *_valid flag stays at raw 0 rather than sp.scale(false), unlike
  // tau_e_ratio_valid below which is always scaled — confirm this matches the training convention.
  const bool tau_dxy_valid =
      isAbove(tau_funcs.getdxy(tau, tau_index), -10) && isAbove(tau_funcs.getdxyError(tau, tau_index), 0);
  if (tau_dxy_valid) {
    get(dnn::tau_dxy_valid) = sp.scale(tau_dxy_valid, tauInputs_indices_[dnn::tau_dxy_valid]);
    get(dnn::tau_dxy) = sp.scale(tau_funcs.getdxy(tau, tau_index), tauInputs_indices_[dnn::tau_dxy]);
    get(dnn::tau_dxy_sig) =
        sp.scale(std::abs(tau_funcs.getdxy(tau, tau_index)) / tau_funcs.getdxyError(tau, tau_index),
                 tauInputs_indices_[dnn::tau_dxy_sig]);
  }
  const bool tau_ip3d_valid =
      isAbove(tau_funcs.getip3d(tau, tau_index), -10) && isAbove(tau_funcs.getip3dError(tau, tau_index), 0);
  if (tau_ip3d_valid) {
    get(dnn::tau_ip3d_valid) = sp.scale(tau_ip3d_valid, tauInputs_indices_[dnn::tau_ip3d_valid]);
    get(dnn::tau_ip3d) = sp.scale(tau_funcs.getip3d(tau, tau_index), tauInputs_indices_[dnn::tau_ip3d]);
    get(dnn::tau_ip3d_sig) =
        sp.scale(std::abs(tau_funcs.getip3d(tau, tau_index)) / tau_funcs.getip3dError(tau, tau_index),
                 tauInputs_indices_[dnn::tau_ip3d_sig]);
  }
  if (leadChargedHadrCand) {
    const bool hasTrackDetails = candFunc::getHasTrackDetails(*leadChargedHadrCand);
    // Online (is_online_) candidates without track details get dz = 0.
    const float tau_dz = (is_online_ && !hasTrackDetails) ? 0 : candFunc::getTauDz(*leadChargedHadrCand);
    get(dnn::tau_dz) = sp.scale(tau_dz, tauInputs_indices_[dnn::tau_dz]);
    get(dnn::tau_dz_sig_valid) =
        sp.scale(candFunc::getTauDZSigValid(*leadChargedHadrCand), tauInputs_indices_[dnn::tau_dz_sig_valid]);
    // Without track details the error is set to -999, which makes the dz significance non-positive.
    const double dzError = hasTrackDetails ? leadChargedHadrCand->dzError() : -999.;
    get(dnn::tau_dz_sig) = sp.scale(std::abs(tau_dz) / dzError, tauInputs_indices_[dnn::tau_dz_sig]);
  }
  get(dnn::tau_flightLength_x) =
      sp.scale(tau_funcs.getFlightLength(tau, tau_index).x(), tauInputs_indices_[dnn::tau_flightLength_x]);
  get(dnn::tau_flightLength_y) =
      sp.scale(tau_funcs.getFlightLength(tau, tau_index).y(), tauInputs_indices_[dnn::tau_flightLength_y]);
  get(dnn::tau_flightLength_z) =
      sp.scale(tau_funcs.getFlightLength(tau, tau_index).z(), tauInputs_indices_[dnn::tau_flightLength_z]);
  // sub_version_ 1: fixed (already-scaled) constant; sub_version_ 5: the scaled actual significance.
  if (sub_version_ == 1)
    get(dnn::tau_flightLength_sig) = 0.55756444;  // This value is set due to a bug in the training
  else if (sub_version_ == 5)
    get(dnn::tau_flightLength_sig) =
        sp.scale(tau_funcs.getFlightLengthSig(tau, tau_index), tauInputs_indices_[dnn::tau_flightLength_sig]);

  get(dnn::tau_pt_weighted_deta_strip) = sp.scale(reco::tau::pt_weighted_deta_strip(tau, tau.decayMode()),
                                                  tauInputs_indices_[dnn::tau_pt_weighted_deta_strip]);

  get(dnn::tau_pt_weighted_dphi_strip) = sp.scale(reco::tau::pt_weighted_dphi_strip(tau, tau.decayMode()),
                                                  tauInputs_indices_[dnn::tau_pt_weighted_dphi_strip]);
  get(dnn::tau_pt_weighted_dr_signal) = sp.scale(reco::tau::pt_weighted_dr_signal(tau, tau.decayMode()),
                                                 tauInputs_indices_[dnn::tau_pt_weighted_dr_signal]);
  get(dnn::tau_pt_weighted_dr_iso) =
      sp.scale(reco::tau::pt_weighted_dr_iso(tau, tau.decayMode()), tauInputs_indices_[dnn::tau_pt_weighted_dr_iso]);
  get(dnn::tau_leadingTrackNormChi2) =
      sp.scale(tau_funcs.getLeadingTrackNormChi2(tau), tauInputs_indices_[dnn::tau_leadingTrackNormChi2]);
  // e-ratio is valid only if it is a normal (finite, non-zero, non-subnormal) positive number;
  // otherwise the value is 0 while the validity flag is still scaled.
  const auto eratio = reco::tau::eratio(tau);
  const bool tau_e_ratio_valid = std::isnormal(eratio) && eratio > 0.f;
  get(dnn::tau_e_ratio_valid) = sp.scale(tau_e_ratio_valid, tauInputs_indices_[dnn::tau_e_ratio_valid]);
  get(dnn::tau_e_ratio) = tau_e_ratio_valid ? sp.scale(eratio, tauInputs_indices_[dnn::tau_e_ratio]) : 0.f;
  // Gottfried-Jackson angle difference is valid when it is normal or exactly zero, and non-negative.
  const double gj_angle_diff = calculateGottfriedJacksonAngleDifference(tau, tau_index, tau_funcs);
  const bool tau_gj_angle_diff_valid = (std::isnormal(gj_angle_diff) || gj_angle_diff == 0) && gj_angle_diff >= 0;
  get(dnn::tau_gj_angle_diff_valid) =
      sp.scale(tau_gj_angle_diff_valid, tauInputs_indices_[dnn::tau_gj_angle_diff_valid]);
  get(dnn::tau_gj_angle_diff) =
      tau_gj_angle_diff_valid ? sp.scale(gj_angle_diff, tauInputs_indices_[dnn::tau_gj_angle_diff]) : 0;
  get(dnn::tau_n_photons) = sp.scale(reco::tau::n_photons_total(tau), tauInputs_indices_[dnn::tau_n_photons]);
  get(dnn::tau_emFraction) = sp.scale(tau_funcs.getEmFraction(tau), tauInputs_indices_[dnn::tau_emFraction]);

  get(dnn::tau_inside_ecal_crack) =
      sp.scale(isInEcalCrack(tau.p4().eta()), tauInputs_indices_[dnn::tau_inside_ecal_crack]);
  get(dnn::leadChargedCand_etaAtEcalEntrance_minus_tau_eta) =
      sp.scale(tau_funcs.getEtaAtEcalEntrance(tau) - tau.p4().eta(),
               tauInputs_indices_[dnn::leadChargedCand_etaAtEcalEntrance_minus_tau_eta]);
}
1717 
1718   template <typename CandidateCastType, typename TauCastType>
1719   void createEgammaBlockInputs(unsigned idx,
1720                                const TauCastType& tau,
1721                                const size_t tau_index,
1722                                const edm::RefToBase<reco::BaseTau> tau_ref,
1723                                const reco::Vertex& pv,
1724                                double rho,
1725                                const std::vector<pat::Electron>* electrons,
1726                                const edm::View<reco::Candidate>& pfCands,
1727                                const Cell& cell_map,
1728                                TauFunc tau_funcs,
1729                                bool is_inner) {
1730     namespace dnn = dnn_inputs_v2::EgammaBlockInputs;
1731     namespace sc = deep_tau::Scaling;
1732     sc::FeatureT ft_global = sc::FeatureT::GridGlobal;
1733     sc::FeatureT ft_PFe = sc::FeatureT::PfCand_electron;
1734     sc::FeatureT ft_PFg = sc::FeatureT::PfCand_gamma;
1735     sc::FeatureT ft_e = sc::FeatureT::Electron;
1736 
1737     // needed to remap indices from scaling vectors to those from dnn_inputs_v2::EgammaBlockInputs
1738     int PFe_index_offset = scalingParamsMap_->at(std::make_pair(ft_global, false)).mean_.size();
1739     int e_index_offset = PFe_index_offset + scalingParamsMap_->at(std::make_pair(ft_PFe, false)).mean_.size();
1740     int PFg_index_offset = e_index_offset + scalingParamsMap_->at(std::make_pair(ft_e, false)).mean_.size();
1741 
1742     // to account for swapped order of PfCand_gamma and Electron blocks for v2p5 training w.r.t. v2p1
1743     int fill_index_offset_e = 0;
1744     int fill_index_offset_PFg = 0;
1745     if (sub_version_ == 5) {
1746       fill_index_offset_e =
1747           scalingParamsMap_->at(std::make_pair(ft_PFg, false)).mean_.size();  // size of PF gamma features
1748       fill_index_offset_PFg =
1749           -scalingParamsMap_->at(std::make_pair(ft_e, false)).mean_.size();  // size of Electron features
1750     }
1751 
1752     tensorflow::Tensor& inputs = *eGammaTensor_.at(is_inner);
1753 
1754     const auto& get = [&](int var_index) -> float& { return inputs.tensor<float, 4>()(idx, 0, 0, var_index); };
1755 
1756     const bool valid_index_pf_ele = cell_map.count(CellObjectType::PfCand_electron);
1757     const bool valid_index_pf_gamma = cell_map.count(CellObjectType::PfCand_gamma);
1758     const bool valid_index_ele = cell_map.count(CellObjectType::Electron);
1759 
1760     if (!cell_map.empty()) {
1761       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_global, false));
1762       get(dnn::rho) = sp.scale(rho, dnn::rho);
1763       get(dnn::tau_pt) = sp.scale(tau.polarP4().pt(), dnn::tau_pt);
1764       get(dnn::tau_eta) = sp.scale(tau.polarP4().eta(), dnn::tau_eta);
1765       get(dnn::tau_inside_ecal_crack) = sp.scale(isInEcalCrack(tau.polarP4().eta()), dnn::tau_inside_ecal_crack);
1766     }
1767     if (valid_index_pf_ele) {
1768       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_PFe, is_inner));
1769       size_t index_pf_ele = cell_map.at(CellObjectType::PfCand_electron);
1770       const auto& ele_cand = dynamic_cast<const CandidateCastType&>(pfCands.at(index_pf_ele));
1771 
1772       get(dnn::pfCand_ele_valid) = sp.scale(valid_index_pf_ele, dnn::pfCand_ele_valid - PFe_index_offset);
1773       get(dnn::pfCand_ele_rel_pt) =
1774           sp.scale(ele_cand.polarP4().pt() / tau.polarP4().pt(), dnn::pfCand_ele_rel_pt - PFe_index_offset);
1775       get(dnn::pfCand_ele_deta) =
1776           sp.scale(ele_cand.polarP4().eta() - tau.polarP4().eta(), dnn::pfCand_ele_deta - PFe_index_offset);
1777       get(dnn::pfCand_ele_dphi) =
1778           sp.scale(dPhi(tau.polarP4(), ele_cand.polarP4()), dnn::pfCand_ele_dphi - PFe_index_offset);
1779       get(dnn::pfCand_ele_pvAssociationQuality) = sp.scale<int>(
1780           candFunc::getPvAssocationQuality(ele_cand), dnn::pfCand_ele_pvAssociationQuality - PFe_index_offset);
1781       get(dnn::pfCand_ele_puppiWeight) = is_inner ? sp.scale(candFunc::getPuppiWeight(ele_cand, 0.9906834f),
1782                                                              dnn::pfCand_ele_puppiWeight - PFe_index_offset)
1783                                                   : sp.scale(candFunc::getPuppiWeight(ele_cand, 0.9669586f),
1784                                                              dnn::pfCand_ele_puppiWeight - PFe_index_offset);
1785       get(dnn::pfCand_ele_charge) = sp.scale(ele_cand.charge(), dnn::pfCand_ele_charge - PFe_index_offset);
1786       get(dnn::pfCand_ele_lostInnerHits) =
1787           sp.scale<int>(candFunc::getLostInnerHits(ele_cand, 0), dnn::pfCand_ele_lostInnerHits - PFe_index_offset);
1788       get(dnn::pfCand_ele_numberOfPixelHits) =
1789           sp.scale(candFunc::getNumberOfPixelHits(ele_cand, 0), dnn::pfCand_ele_numberOfPixelHits - PFe_index_offset);
1790       get(dnn::pfCand_ele_vertex_dx) =
1791           sp.scale(ele_cand.vertex().x() - pv.position().x(), dnn::pfCand_ele_vertex_dx - PFe_index_offset);
1792       get(dnn::pfCand_ele_vertex_dy) =
1793           sp.scale(ele_cand.vertex().y() - pv.position().y(), dnn::pfCand_ele_vertex_dy - PFe_index_offset);
1794       get(dnn::pfCand_ele_vertex_dz) =
1795           sp.scale(ele_cand.vertex().z() - pv.position().z(), dnn::pfCand_ele_vertex_dz - PFe_index_offset);
1796       get(dnn::pfCand_ele_vertex_dx_tauFL) =
1797           sp.scale(ele_cand.vertex().x() - pv.position().x() - tau_funcs.getFlightLength(tau, tau_index).x(),
1798                    dnn::pfCand_ele_vertex_dx_tauFL - PFe_index_offset);
1799       get(dnn::pfCand_ele_vertex_dy_tauFL) =
1800           sp.scale(ele_cand.vertex().y() - pv.position().y() - tau_funcs.getFlightLength(tau, tau_index).y(),
1801                    dnn::pfCand_ele_vertex_dy_tauFL - PFe_index_offset);
1802       get(dnn::pfCand_ele_vertex_dz_tauFL) =
1803           sp.scale(ele_cand.vertex().z() - pv.position().z() - tau_funcs.getFlightLength(tau, tau_index).z(),
1804                    dnn::pfCand_ele_vertex_dz_tauFL - PFe_index_offset);
1805 
1806       const bool hasTrackDetails = candFunc::getHasTrackDetails(ele_cand);
1807       if (hasTrackDetails) {
1808         get(dnn::pfCand_ele_hasTrackDetails) =
1809             sp.scale(hasTrackDetails, dnn::pfCand_ele_hasTrackDetails - PFe_index_offset);
1810         get(dnn::pfCand_ele_dxy) = sp.scale(candFunc::getTauDxy(ele_cand), dnn::pfCand_ele_dxy - PFe_index_offset);
1811         get(dnn::pfCand_ele_dxy_sig) = sp.scale(std::abs(candFunc::getTauDxy(ele_cand)) / ele_cand.dxyError(),
1812                                                 dnn::pfCand_ele_dxy_sig - PFe_index_offset);
1813         get(dnn::pfCand_ele_dz) = sp.scale(candFunc::getTauDz(ele_cand), dnn::pfCand_ele_dz - PFe_index_offset);
1814         get(dnn::pfCand_ele_dz_sig) = sp.scale(std::abs(candFunc::getTauDz(ele_cand)) / ele_cand.dzError(),
1815                                                dnn::pfCand_ele_dz_sig - PFe_index_offset);
1816         get(dnn::pfCand_ele_track_chi2_ndof) =
1817             candFunc::getPseudoTrack(ele_cand).ndof() > 0
1818                 ? sp.scale(candFunc::getPseudoTrack(ele_cand).chi2() / candFunc::getPseudoTrack(ele_cand).ndof(),
1819                            dnn::pfCand_ele_track_chi2_ndof - PFe_index_offset)
1820                 : 0;
1821         get(dnn::pfCand_ele_track_ndof) =
1822             candFunc::getPseudoTrack(ele_cand).ndof() > 0
1823                 ? sp.scale(candFunc::getPseudoTrack(ele_cand).ndof(), dnn::pfCand_ele_track_ndof - PFe_index_offset)
1824                 : 0;
1825       }
1826     }
1827     if (valid_index_pf_gamma) {
1828       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_PFg, is_inner));
1829       size_t index_pf_gamma = cell_map.at(CellObjectType::PfCand_gamma);
1830       const auto& gamma_cand = dynamic_cast<const CandidateCastType&>(pfCands.at(index_pf_gamma));
1831 
1832       get(dnn::pfCand_gamma_valid + fill_index_offset_PFg) =
1833           sp.scale(valid_index_pf_gamma, dnn::pfCand_gamma_valid - PFg_index_offset);
1834       get(dnn::pfCand_gamma_rel_pt + fill_index_offset_PFg) =
1835           sp.scale(gamma_cand.polarP4().pt() / tau.polarP4().pt(), dnn::pfCand_gamma_rel_pt - PFg_index_offset);
1836       get(dnn::pfCand_gamma_deta + fill_index_offset_PFg) =
1837           sp.scale(gamma_cand.polarP4().eta() - tau.polarP4().eta(), dnn::pfCand_gamma_deta - PFg_index_offset);
1838       get(dnn::pfCand_gamma_dphi + fill_index_offset_PFg) =
1839           sp.scale(dPhi(tau.polarP4(), gamma_cand.polarP4()), dnn::pfCand_gamma_dphi - PFg_index_offset);
1840       get(dnn::pfCand_gamma_pvAssociationQuality + fill_index_offset_PFg) = sp.scale<int>(
1841           candFunc::getPvAssocationQuality(gamma_cand), dnn::pfCand_gamma_pvAssociationQuality - PFg_index_offset);
1842       get(dnn::pfCand_gamma_fromPV + fill_index_offset_PFg) =
1843           sp.scale<int>(candFunc::getFromPV(gamma_cand), dnn::pfCand_gamma_fromPV - PFg_index_offset);
1844       get(dnn::pfCand_gamma_puppiWeight + fill_index_offset_PFg) =
1845           is_inner ? sp.scale(candFunc::getPuppiWeight(gamma_cand, 0.9084110f),
1846                               dnn::pfCand_gamma_puppiWeight - PFg_index_offset)
1847                    : sp.scale(candFunc::getPuppiWeight(gamma_cand, 0.4211567f),
1848                               dnn::pfCand_gamma_puppiWeight - PFg_index_offset);
1849       get(dnn::pfCand_gamma_puppiWeightNoLep + fill_index_offset_PFg) =
1850           is_inner ? sp.scale(candFunc::getPuppiWeightNoLep(gamma_cand, 0.8857716f),
1851                               dnn::pfCand_gamma_puppiWeightNoLep - PFg_index_offset)
1852                    : sp.scale(candFunc::getPuppiWeightNoLep(gamma_cand, 0.3822604f),
1853                               dnn::pfCand_gamma_puppiWeightNoLep - PFg_index_offset);
1854       get(dnn::pfCand_gamma_lostInnerHits + fill_index_offset_PFg) =
1855           sp.scale<int>(candFunc::getLostInnerHits(gamma_cand, 0), dnn::pfCand_gamma_lostInnerHits - PFg_index_offset);
1856       get(dnn::pfCand_gamma_numberOfPixelHits + fill_index_offset_PFg) = sp.scale(
1857           candFunc::getNumberOfPixelHits(gamma_cand, 0), dnn::pfCand_gamma_numberOfPixelHits - PFg_index_offset);
1858       get(dnn::pfCand_gamma_vertex_dx + fill_index_offset_PFg) =
1859           sp.scale(gamma_cand.vertex().x() - pv.position().x(), dnn::pfCand_gamma_vertex_dx - PFg_index_offset);
1860       get(dnn::pfCand_gamma_vertex_dy + fill_index_offset_PFg) =
1861           sp.scale(gamma_cand.vertex().y() - pv.position().y(), dnn::pfCand_gamma_vertex_dy - PFg_index_offset);
1862       get(dnn::pfCand_gamma_vertex_dz + fill_index_offset_PFg) =
1863           sp.scale(gamma_cand.vertex().z() - pv.position().z(), dnn::pfCand_gamma_vertex_dz - PFg_index_offset);
1864       get(dnn::pfCand_gamma_vertex_dx_tauFL + fill_index_offset_PFg) =
1865           sp.scale(gamma_cand.vertex().x() - pv.position().x() - tau_funcs.getFlightLength(tau, tau_index).x(),
1866                    dnn::pfCand_gamma_vertex_dx_tauFL - PFg_index_offset);
1867       get(dnn::pfCand_gamma_vertex_dy_tauFL + fill_index_offset_PFg) =
1868           sp.scale(gamma_cand.vertex().y() - pv.position().y() - tau_funcs.getFlightLength(tau, tau_index).y(),
1869                    dnn::pfCand_gamma_vertex_dy_tauFL - PFg_index_offset);
1870       get(dnn::pfCand_gamma_vertex_dz_tauFL + fill_index_offset_PFg) =
1871           sp.scale(gamma_cand.vertex().z() - pv.position().z() - tau_funcs.getFlightLength(tau, tau_index).z(),
1872                    dnn::pfCand_gamma_vertex_dz_tauFL - PFg_index_offset);
1873       const bool hasTrackDetails = candFunc::getHasTrackDetails(gamma_cand);
1874       if (hasTrackDetails) {
1875         get(dnn::pfCand_gamma_hasTrackDetails + fill_index_offset_PFg) =
1876             sp.scale(hasTrackDetails, dnn::pfCand_gamma_hasTrackDetails - PFg_index_offset);
1877         get(dnn::pfCand_gamma_dxy + fill_index_offset_PFg) =
1878             sp.scale(candFunc::getTauDxy(gamma_cand), dnn::pfCand_gamma_dxy - PFg_index_offset);
1879         get(dnn::pfCand_gamma_dxy_sig + fill_index_offset_PFg) =
1880             sp.scale(std::abs(candFunc::getTauDxy(gamma_cand)) / gamma_cand.dxyError(),
1881                      dnn::pfCand_gamma_dxy_sig - PFg_index_offset);
1882         get(dnn::pfCand_gamma_dz + fill_index_offset_PFg) =
1883             sp.scale(candFunc::getTauDz(gamma_cand), dnn::pfCand_gamma_dz - PFg_index_offset);
1884         get(dnn::pfCand_gamma_dz_sig + fill_index_offset_PFg) =
1885             sp.scale(std::abs(candFunc::getTauDz(gamma_cand)) / gamma_cand.dzError(),
1886                      dnn::pfCand_gamma_dz_sig - PFg_index_offset);
1887         get(dnn::pfCand_gamma_track_chi2_ndof + fill_index_offset_PFg) =
1888             candFunc::getPseudoTrack(gamma_cand).ndof() > 0
1889                 ? sp.scale(candFunc::getPseudoTrack(gamma_cand).chi2() / candFunc::getPseudoTrack(gamma_cand).ndof(),
1890                            dnn::pfCand_gamma_track_chi2_ndof - PFg_index_offset)
1891                 : 0;
1892         get(dnn::pfCand_gamma_track_ndof + fill_index_offset_PFg) =
1893             candFunc::getPseudoTrack(gamma_cand).ndof() > 0
1894                 ? sp.scale(candFunc::getPseudoTrack(gamma_cand).ndof(), dnn::pfCand_gamma_track_ndof - PFg_index_offset)
1895                 : 0;
1896       }
1897     }
1898     if (valid_index_ele) {
1899       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_e, is_inner));
1900       size_t index_ele = cell_map.at(CellObjectType::Electron);
1901       const auto& ele = electrons->at(index_ele);
1902 
1903       get(dnn::ele_valid + fill_index_offset_e) = sp.scale(valid_index_ele, dnn::ele_valid - e_index_offset);
1904       get(dnn::ele_rel_pt + fill_index_offset_e) =
1905           sp.scale(ele.polarP4().pt() / tau.polarP4().pt(), dnn::ele_rel_pt - e_index_offset);
1906       get(dnn::ele_deta + fill_index_offset_e) =
1907           sp.scale(ele.polarP4().eta() - tau.polarP4().eta(), dnn::ele_deta - e_index_offset);
1908       get(dnn::ele_dphi + fill_index_offset_e) =
1909           sp.scale(dPhi(tau.polarP4(), ele.polarP4()), dnn::ele_dphi - e_index_offset);
1910 
1911       float cc_ele_energy, cc_gamma_energy;
1912       int cc_n_gamma;
1913       const bool cc_valid = calculateElectronClusterVarsV2(ele, cc_ele_energy, cc_gamma_energy, cc_n_gamma);
1914       if (cc_valid) {
1915         get(dnn::ele_cc_valid + fill_index_offset_e) = sp.scale(cc_valid, dnn::ele_cc_valid - e_index_offset);
1916         get(dnn::ele_cc_ele_rel_energy + fill_index_offset_e) =
1917             sp.scale(cc_ele_energy / ele.polarP4().pt(), dnn::ele_cc_ele_rel_energy - e_index_offset);
1918         get(dnn::ele_cc_gamma_rel_energy + fill_index_offset_e) =
1919             sp.scale(cc_gamma_energy / cc_ele_energy, dnn::ele_cc_gamma_rel_energy - e_index_offset);
1920         get(dnn::ele_cc_n_gamma + fill_index_offset_e) = sp.scale(cc_n_gamma, dnn::ele_cc_n_gamma - e_index_offset);
1921       }
1922       get(dnn::ele_rel_trackMomentumAtVtx + fill_index_offset_e) =
1923           sp.scale(ele.trackMomentumAtVtx().R() / ele.polarP4().pt(), dnn::ele_rel_trackMomentumAtVtx - e_index_offset);
1924       get(dnn::ele_rel_trackMomentumAtCalo + fill_index_offset_e) = sp.scale(
1925           ele.trackMomentumAtCalo().R() / ele.polarP4().pt(), dnn::ele_rel_trackMomentumAtCalo - e_index_offset);
1926       get(dnn::ele_rel_trackMomentumOut + fill_index_offset_e) =
1927           sp.scale(ele.trackMomentumOut().R() / ele.polarP4().pt(), dnn::ele_rel_trackMomentumOut - e_index_offset);
1928       get(dnn::ele_rel_trackMomentumAtEleClus + fill_index_offset_e) = sp.scale(
1929           ele.trackMomentumAtEleClus().R() / ele.polarP4().pt(), dnn::ele_rel_trackMomentumAtEleClus - e_index_offset);
1930       get(dnn::ele_rel_trackMomentumAtVtxWithConstraint + fill_index_offset_e) =
1931           sp.scale(ele.trackMomentumAtVtxWithConstraint().R() / ele.polarP4().pt(),
1932                    dnn::ele_rel_trackMomentumAtVtxWithConstraint - e_index_offset);
1933       get(dnn::ele_rel_ecalEnergy + fill_index_offset_e) =
1934           sp.scale(ele.ecalEnergy() / ele.polarP4().pt(), dnn::ele_rel_ecalEnergy - e_index_offset);
1935       get(dnn::ele_ecalEnergy_sig + fill_index_offset_e) =
1936           sp.scale(ele.ecalEnergy() / ele.ecalEnergyError(), dnn::ele_ecalEnergy_sig - e_index_offset);
1937       get(dnn::ele_eSuperClusterOverP + fill_index_offset_e) =
1938           sp.scale(ele.eSuperClusterOverP(), dnn::ele_eSuperClusterOverP - e_index_offset);
1939       get(dnn::ele_eSeedClusterOverP + fill_index_offset_e) =
1940           sp.scale(ele.eSeedClusterOverP(), dnn::ele_eSeedClusterOverP - e_index_offset);
1941       get(dnn::ele_eSeedClusterOverPout + fill_index_offset_e) =
1942           sp.scale(ele.eSeedClusterOverPout(), dnn::ele_eSeedClusterOverPout - e_index_offset);
1943       get(dnn::ele_eEleClusterOverPout + fill_index_offset_e) =
1944           sp.scale(ele.eEleClusterOverPout(), dnn::ele_eEleClusterOverPout - e_index_offset);
1945       get(dnn::ele_deltaEtaSuperClusterTrackAtVtx + fill_index_offset_e) =
1946           sp.scale(ele.deltaEtaSuperClusterTrackAtVtx(), dnn::ele_deltaEtaSuperClusterTrackAtVtx - e_index_offset);
1947       get(dnn::ele_deltaEtaSeedClusterTrackAtCalo + fill_index_offset_e) =
1948           sp.scale(ele.deltaEtaSeedClusterTrackAtCalo(), dnn::ele_deltaEtaSeedClusterTrackAtCalo - e_index_offset);
1949       get(dnn::ele_deltaEtaEleClusterTrackAtCalo + fill_index_offset_e) =
1950           sp.scale(ele.deltaEtaEleClusterTrackAtCalo(), dnn::ele_deltaEtaEleClusterTrackAtCalo - e_index_offset);
1951       get(dnn::ele_deltaPhiEleClusterTrackAtCalo + fill_index_offset_e) =
1952           sp.scale(ele.deltaPhiEleClusterTrackAtCalo(), dnn::ele_deltaPhiEleClusterTrackAtCalo - e_index_offset);
1953       get(dnn::ele_deltaPhiSuperClusterTrackAtVtx + fill_index_offset_e) =
1954           sp.scale(ele.deltaPhiSuperClusterTrackAtVtx(), dnn::ele_deltaPhiSuperClusterTrackAtVtx - e_index_offset);
1955       get(dnn::ele_deltaPhiSeedClusterTrackAtCalo + fill_index_offset_e) =
1956           sp.scale(ele.deltaPhiSeedClusterTrackAtCalo(), dnn::ele_deltaPhiSeedClusterTrackAtCalo - e_index_offset);
1957       const bool mva_valid =
1958           (ele.mvaInput().earlyBrem > -2) ||
1959           (year_ !=
1960            2026);  // Known issue that input can be invalid in Phase2 samples (early/lateBrem==-2, hadEnergy==0, sigmaEtaEta/deltaEta==3.40282e+38). Unknown if also in Run2/3, so don't change there
1961       if (mva_valid) {
1962         get(dnn::ele_mvaInput_earlyBrem + fill_index_offset_e) =
1963             sp.scale(ele.mvaInput().earlyBrem, dnn::ele_mvaInput_earlyBrem - e_index_offset);
1964         get(dnn::ele_mvaInput_lateBrem + fill_index_offset_e) =
1965             sp.scale(ele.mvaInput().lateBrem, dnn::ele_mvaInput_lateBrem - e_index_offset);
1966         get(dnn::ele_mvaInput_sigmaEtaEta + fill_index_offset_e) =
1967             sp.scale(ele.mvaInput().sigmaEtaEta, dnn::ele_mvaInput_sigmaEtaEta - e_index_offset);
1968         get(dnn::ele_mvaInput_hadEnergy + fill_index_offset_e) =
1969             sp.scale(ele.mvaInput().hadEnergy, dnn::ele_mvaInput_hadEnergy - e_index_offset);
1970         get(dnn::ele_mvaInput_deltaEta + fill_index_offset_e) =
1971             sp.scale(ele.mvaInput().deltaEta, dnn::ele_mvaInput_deltaEta - e_index_offset);
1972       }
1973       const auto& gsfTrack = ele.gsfTrack();
1974       if (gsfTrack.isNonnull()) {
1975         get(dnn::ele_gsfTrack_normalizedChi2 + fill_index_offset_e) =
1976             sp.scale(gsfTrack->normalizedChi2(), dnn::ele_gsfTrack_normalizedChi2 - e_index_offset);
1977         get(dnn::ele_gsfTrack_numberOfValidHits + fill_index_offset_e) =
1978             sp.scale(gsfTrack->numberOfValidHits(), dnn::ele_gsfTrack_numberOfValidHits - e_index_offset);
1979         get(dnn::ele_rel_gsfTrack_pt + fill_index_offset_e) =
1980             sp.scale(gsfTrack->pt() / ele.polarP4().pt(), dnn::ele_rel_gsfTrack_pt - e_index_offset);
1981         get(dnn::ele_gsfTrack_pt_sig + fill_index_offset_e) =
1982             sp.scale(gsfTrack->pt() / gsfTrack->ptError(), dnn::ele_gsfTrack_pt_sig - e_index_offset);
1983       }
1984       const auto& closestCtfTrack = ele.closestCtfTrackRef();
1985       const bool has_closestCtfTrack = closestCtfTrack.isNonnull();
1986       if (has_closestCtfTrack) {
1987         get(dnn::ele_has_closestCtfTrack + fill_index_offset_e) =
1988             sp.scale(has_closestCtfTrack, dnn::ele_has_closestCtfTrack - e_index_offset);
1989         get(dnn::ele_closestCtfTrack_normalizedChi2 + fill_index_offset_e) =
1990             sp.scale(closestCtfTrack->normalizedChi2(), dnn::ele_closestCtfTrack_normalizedChi2 - e_index_offset);
1991         get(dnn::ele_closestCtfTrack_numberOfValidHits + fill_index_offset_e) =
1992             sp.scale(closestCtfTrack->numberOfValidHits(), dnn::ele_closestCtfTrack_numberOfValidHits - e_index_offset);
1993       }
1994     }
1995   }
1996 
1997   template <typename CandidateCastType, typename TauCastType>
1998   void createMuonBlockInputs(unsigned idx,
1999                              const TauCastType& tau,
2000                              const size_t tau_index,
2001                              const edm::RefToBase<reco::BaseTau> tau_ref,
2002                              const reco::Vertex& pv,
2003                              double rho,
2004                              const std::vector<pat::Muon>* muons,
2005                              const edm::View<reco::Candidate>& pfCands,
2006                              const Cell& cell_map,
2007                              TauFunc tau_funcs,
2008                              bool is_inner) {
2009     namespace dnn = dnn_inputs_v2::MuonBlockInputs;
2010     namespace sc = deep_tau::Scaling;
2011     sc::FeatureT ft_global = sc::FeatureT::GridGlobal;
2012     sc::FeatureT ft_PFmu = sc::FeatureT::PfCand_muon;
2013     sc::FeatureT ft_mu = sc::FeatureT::Muon;
2014 
2015     // needed to remap indices from scaling vectors to those from dnn_inputs_v2::MuonBlockInputs
2016     int PFmu_index_offset = scalingParamsMap_->at(std::make_pair(ft_global, false)).mean_.size();
2017     int mu_index_offset = PFmu_index_offset + scalingParamsMap_->at(std::make_pair(ft_PFmu, false)).mean_.size();
2018 
2019     tensorflow::Tensor& inputs = *muonTensor_.at(is_inner);
2020 
2021     const auto& get = [&](int var_index) -> float& { return inputs.tensor<float, 4>()(idx, 0, 0, var_index); };
2022 
2023     const bool valid_index_pf_muon = cell_map.count(CellObjectType::PfCand_muon);
2024     const bool valid_index_muon = cell_map.count(CellObjectType::Muon);
2025 
2026     if (!cell_map.empty()) {
2027       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_global, false));
2028       get(dnn::rho) = sp.scale(rho, dnn::rho);
2029       get(dnn::tau_pt) = sp.scale(tau.polarP4().pt(), dnn::tau_pt);
2030       get(dnn::tau_eta) = sp.scale(tau.polarP4().eta(), dnn::tau_eta);
2031       get(dnn::tau_inside_ecal_crack) = sp.scale(isInEcalCrack(tau.polarP4().eta()), dnn::tau_inside_ecal_crack);
2032     }
2033     if (valid_index_pf_muon) {
2034       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_PFmu, is_inner));
2035       size_t index_pf_muon = cell_map.at(CellObjectType::PfCand_muon);
2036       const auto& muon_cand = dynamic_cast<const CandidateCastType&>(pfCands.at(index_pf_muon));
2037 
2038       get(dnn::pfCand_muon_valid) = sp.scale(valid_index_pf_muon, dnn::pfCand_muon_valid - PFmu_index_offset);
2039       get(dnn::pfCand_muon_rel_pt) =
2040           sp.scale(muon_cand.polarP4().pt() / tau.polarP4().pt(), dnn::pfCand_muon_rel_pt - PFmu_index_offset);
2041       get(dnn::pfCand_muon_deta) =
2042           sp.scale(muon_cand.polarP4().eta() - tau.polarP4().eta(), dnn::pfCand_muon_deta - PFmu_index_offset);
2043       get(dnn::pfCand_muon_dphi) =
2044           sp.scale(dPhi(tau.polarP4(), muon_cand.polarP4()), dnn::pfCand_muon_dphi - PFmu_index_offset);
2045       get(dnn::pfCand_muon_pvAssociationQuality) = sp.scale<int>(
2046           candFunc::getPvAssocationQuality(muon_cand), dnn::pfCand_muon_pvAssociationQuality - PFmu_index_offset);
2047       get(dnn::pfCand_muon_fromPV) =
2048           sp.scale<int>(candFunc::getFromPV(muon_cand), dnn::pfCand_muon_fromPV - PFmu_index_offset);
2049       get(dnn::pfCand_muon_puppiWeight) = is_inner ? sp.scale(candFunc::getPuppiWeight(muon_cand, 0.9786588f),
2050                                                               dnn::pfCand_muon_puppiWeight - PFmu_index_offset)
2051                                                    : sp.scale(candFunc::getPuppiWeight(muon_cand, 0.8132477f),
2052                                                               dnn::pfCand_muon_puppiWeight - PFmu_index_offset);
2053       get(dnn::pfCand_muon_charge) = sp.scale(muon_cand.charge(), dnn::pfCand_muon_charge - PFmu_index_offset);
2054       get(dnn::pfCand_muon_lostInnerHits) =
2055           sp.scale<int>(candFunc::getLostInnerHits(muon_cand, 0), dnn::pfCand_muon_lostInnerHits - PFmu_index_offset);
2056       get(dnn::pfCand_muon_numberOfPixelHits) = sp.scale(candFunc::getNumberOfPixelHits(muon_cand, 0),
2057                                                          dnn::pfCand_muon_numberOfPixelHits - PFmu_index_offset);
2058       get(dnn::pfCand_muon_vertex_dx) =
2059           sp.scale(muon_cand.vertex().x() - pv.position().x(), dnn::pfCand_muon_vertex_dx - PFmu_index_offset);
2060       get(dnn::pfCand_muon_vertex_dy) =
2061           sp.scale(muon_cand.vertex().y() - pv.position().y(), dnn::pfCand_muon_vertex_dy - PFmu_index_offset);
2062       get(dnn::pfCand_muon_vertex_dz) =
2063           sp.scale(muon_cand.vertex().z() - pv.position().z(), dnn::pfCand_muon_vertex_dz - PFmu_index_offset);
2064       get(dnn::pfCand_muon_vertex_dx_tauFL) =
2065           sp.scale(muon_cand.vertex().x() - pv.position().x() - tau_funcs.getFlightLength(tau, tau_index).x(),
2066                    dnn::pfCand_muon_vertex_dx_tauFL - PFmu_index_offset);
2067       get(dnn::pfCand_muon_vertex_dy_tauFL) =
2068           sp.scale(muon_cand.vertex().y() - pv.position().y() - tau_funcs.getFlightLength(tau, tau_index).y(),
2069                    dnn::pfCand_muon_vertex_dy_tauFL - PFmu_index_offset);
2070       get(dnn::pfCand_muon_vertex_dz_tauFL) =
2071           sp.scale(muon_cand.vertex().z() - pv.position().z() - tau_funcs.getFlightLength(tau, tau_index).z(),
2072                    dnn::pfCand_muon_vertex_dz_tauFL - PFmu_index_offset);
2073 
2074       const bool hasTrackDetails = candFunc::getHasTrackDetails(muon_cand);
2075       if (hasTrackDetails) {
2076         get(dnn::pfCand_muon_hasTrackDetails) =
2077             sp.scale(hasTrackDetails, dnn::pfCand_muon_hasTrackDetails - PFmu_index_offset);
2078         get(dnn::pfCand_muon_dxy) = sp.scale(candFunc::getTauDxy(muon_cand), dnn::pfCand_muon_dxy - PFmu_index_offset);
2079         get(dnn::pfCand_muon_dxy_sig) = sp.scale(std::abs(candFunc::getTauDxy(muon_cand)) / muon_cand.dxyError(),
2080                                                  dnn::pfCand_muon_dxy_sig - PFmu_index_offset);
2081         get(dnn::pfCand_muon_dz) = sp.scale(candFunc::getTauDz(muon_cand), dnn::pfCand_muon_dz - PFmu_index_offset);
2082         get(dnn::pfCand_muon_dz_sig) = sp.scale(std::abs(candFunc::getTauDz(muon_cand)) / muon_cand.dzError(),
2083                                                 dnn::pfCand_muon_dz_sig - PFmu_index_offset);
2084         get(dnn::pfCand_muon_track_chi2_ndof) =
2085             candFunc::getPseudoTrack(muon_cand).ndof() > 0
2086                 ? sp.scale(candFunc::getPseudoTrack(muon_cand).chi2() / candFunc::getPseudoTrack(muon_cand).ndof(),
2087                            dnn::pfCand_muon_track_chi2_ndof - PFmu_index_offset)
2088                 : 0;
2089         get(dnn::pfCand_muon_track_ndof) =
2090             candFunc::getPseudoTrack(muon_cand).ndof() > 0
2091                 ? sp.scale(candFunc::getPseudoTrack(muon_cand).ndof(), dnn::pfCand_muon_track_ndof - PFmu_index_offset)
2092                 : 0;
2093       }
2094     }
2095     if (valid_index_muon) {
2096       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_mu, is_inner));
2097       size_t index_muon = cell_map.at(CellObjectType::Muon);
2098       const auto& muon = muons->at(index_muon);
2099 
2100       get(dnn::muon_valid) = sp.scale(valid_index_muon, dnn::muon_valid - mu_index_offset);
2101       get(dnn::muon_rel_pt) = sp.scale(muon.polarP4().pt() / tau.polarP4().pt(), dnn::muon_rel_pt - mu_index_offset);
2102       get(dnn::muon_deta) = sp.scale(muon.polarP4().eta() - tau.polarP4().eta(), dnn::muon_deta - mu_index_offset);
2103       get(dnn::muon_dphi) = sp.scale(dPhi(tau.polarP4(), muon.polarP4()), dnn::muon_dphi - mu_index_offset);
2104       get(dnn::muon_dxy) = sp.scale(muon.dB(pat::Muon::PV2D), dnn::muon_dxy - mu_index_offset);
2105       get(dnn::muon_dxy_sig) =
2106           sp.scale(std::abs(muon.dB(pat::Muon::PV2D)) / muon.edB(pat::Muon::PV2D), dnn::muon_dxy_sig - mu_index_offset);
2107 
2108       const bool normalizedChi2_valid = muon.globalTrack().isNonnull() && muon.normChi2() >= 0;
2109       if (normalizedChi2_valid) {
2110         get(dnn::muon_normalizedChi2_valid) =
2111             sp.scale(normalizedChi2_valid, dnn::muon_normalizedChi2_valid - mu_index_offset);
2112         get(dnn::muon_normalizedChi2) = sp.scale(muon.normChi2(), dnn::muon_normalizedChi2 - mu_index_offset);
2113         if (muon.innerTrack().isNonnull())
2114           get(dnn::muon_numberOfValidHits) =
2115               sp.scale(muon.numberOfValidHits(), dnn::muon_numberOfValidHits - mu_index_offset);
2116       }
2117       get(dnn::muon_segmentCompatibility) =
2118           sp.scale(muon.segmentCompatibility(), dnn::muon_segmentCompatibility - mu_index_offset);
2119       get(dnn::muon_caloCompatibility) =
2120           sp.scale(muon.caloCompatibility(), dnn::muon_caloCompatibility - mu_index_offset);
2121 
2122       const bool pfEcalEnergy_valid = muon.pfEcalEnergy() >= 0;
2123       if (pfEcalEnergy_valid) {
2124         get(dnn::muon_pfEcalEnergy_valid) =
2125             sp.scale(pfEcalEnergy_valid, dnn::muon_pfEcalEnergy_valid - mu_index_offset);
2126         get(dnn::muon_rel_pfEcalEnergy) =
2127             sp.scale(muon.pfEcalEnergy() / muon.polarP4().pt(), dnn::muon_rel_pfEcalEnergy - mu_index_offset);
2128       }
2129 
2130       MuonHitMatchV2 hit_match(muon);
2131       static const std::map<int, std::pair<int, int>> muonMatchHitVars = {
2132           {MuonSubdetId::DT, {dnn::muon_n_matches_DT_1, dnn::muon_n_hits_DT_1}},
2133           {MuonSubdetId::CSC, {dnn::muon_n_matches_CSC_1, dnn::muon_n_hits_CSC_1}},
2134           {MuonSubdetId::RPC, {dnn::muon_n_matches_RPC_1, dnn::muon_n_hits_RPC_1}}};
2135 
2136       for (int subdet : hit_match.MuonHitMatchV2::consideredSubdets()) {
2137         const auto& matchHitVar = muonMatchHitVars.at(subdet);
2138         for (int station = MuonHitMatchV2::first_station_id; station <= MuonHitMatchV2::last_station_id; ++station) {
2139           const unsigned n_matches = hit_match.nMatches(subdet, station);
2140           const unsigned n_hits = hit_match.nHits(subdet, station);
2141           get(matchHitVar.first + station - 1) = sp.scale(n_matches, matchHitVar.first + station - 1 - mu_index_offset);
2142           get(matchHitVar.second + station - 1) = sp.scale(n_hits, matchHitVar.second + station - 1 - mu_index_offset);
2143         }
2144       }
2145     }
2146   }
2147 
2148   template <typename CandidateCastType, typename TauCastType>
2149   void createHadronsBlockInputs(unsigned idx,
2150                                 const TauCastType& tau,
2151                                 const size_t tau_index,
2152                                 const edm::RefToBase<reco::BaseTau> tau_ref,
2153                                 const reco::Vertex& pv,
2154                                 double rho,
2155                                 const edm::View<reco::Candidate>& pfCands,
2156                                 const Cell& cell_map,
2157                                 TauFunc tau_funcs,
2158                                 bool is_inner) {
2159     namespace dnn = dnn_inputs_v2::HadronBlockInputs;
2160     namespace sc = deep_tau::Scaling;
2161     sc::FeatureT ft_global = sc::FeatureT::GridGlobal;
2162     sc::FeatureT ft_PFchH = sc::FeatureT::PfCand_chHad;
2163     sc::FeatureT ft_PFnH = sc::FeatureT::PfCand_nHad;
2164 
2165     // needed to remap indices from scaling vectors to those from dnn_inputs_v2::HadronBlockInputs
2166     int PFchH_index_offset = scalingParamsMap_->at(std::make_pair(ft_global, false)).mean_.size();
2167     int PFnH_index_offset = PFchH_index_offset + scalingParamsMap_->at(std::make_pair(ft_PFchH, false)).mean_.size();
2168 
2169     tensorflow::Tensor& inputs = *hadronsTensor_.at(is_inner);
2170 
2171     const auto& get = [&](int var_index) -> float& { return inputs.tensor<float, 4>()(idx, 0, 0, var_index); };
2172 
2173     const bool valid_chH = cell_map.count(CellObjectType::PfCand_chargedHadron);
2174     const bool valid_nH = cell_map.count(CellObjectType::PfCand_neutralHadron);
2175 
2176     if (!cell_map.empty()) {
2177       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_global, false));
2178       get(dnn::rho) = sp.scale(rho, dnn::rho);
2179       get(dnn::tau_pt) = sp.scale(tau.polarP4().pt(), dnn::tau_pt);
2180       get(dnn::tau_eta) = sp.scale(tau.polarP4().eta(), dnn::tau_eta);
2181       get(dnn::tau_inside_ecal_crack) = sp.scale(isInEcalCrack(tau.polarP4().eta()), dnn::tau_inside_ecal_crack);
2182     }
2183     if (valid_chH) {
2184       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_PFchH, is_inner));
2185       size_t index_chH = cell_map.at(CellObjectType::PfCand_chargedHadron);
2186       const auto& chH_cand = dynamic_cast<const CandidateCastType&>(pfCands.at(index_chH));
2187 
2188       get(dnn::pfCand_chHad_valid) = sp.scale(valid_chH, dnn::pfCand_chHad_valid - PFchH_index_offset);
2189       get(dnn::pfCand_chHad_rel_pt) =
2190           sp.scale(chH_cand.polarP4().pt() / tau.polarP4().pt(), dnn::pfCand_chHad_rel_pt - PFchH_index_offset);
2191       get(dnn::pfCand_chHad_deta) =
2192           sp.scale(chH_cand.polarP4().eta() - tau.polarP4().eta(), dnn::pfCand_chHad_deta - PFchH_index_offset);
2193       get(dnn::pfCand_chHad_dphi) =
2194           sp.scale(dPhi(tau.polarP4(), chH_cand.polarP4()), dnn::pfCand_chHad_dphi - PFchH_index_offset);
2195       get(dnn::pfCand_chHad_leadChargedHadrCand) =
2196           sp.scale(&chH_cand == dynamic_cast<const CandidateCastType*>(tau.leadChargedHadrCand().get()),
2197                    dnn::pfCand_chHad_leadChargedHadrCand - PFchH_index_offset);
2198       get(dnn::pfCand_chHad_pvAssociationQuality) = sp.scale<int>(
2199           candFunc::getPvAssocationQuality(chH_cand), dnn::pfCand_chHad_pvAssociationQuality - PFchH_index_offset);
2200       get(dnn::pfCand_chHad_fromPV) =
2201           sp.scale<int>(candFunc::getFromPV(chH_cand), dnn::pfCand_chHad_fromPV - PFchH_index_offset);
2202       const float default_chH_pw_inner = 0.7614090f;
2203       const float default_chH_pw_outer = 0.1974930f;
2204       get(dnn::pfCand_chHad_puppiWeight) = is_inner ? sp.scale(candFunc::getPuppiWeight(chH_cand, default_chH_pw_inner),
2205                                                                dnn::pfCand_chHad_puppiWeight - PFchH_index_offset)
2206                                                     : sp.scale(candFunc::getPuppiWeight(chH_cand, default_chH_pw_outer),
2207                                                                dnn::pfCand_chHad_puppiWeight - PFchH_index_offset);
2208       get(dnn::pfCand_chHad_puppiWeightNoLep) =
2209           is_inner ? sp.scale(candFunc::getPuppiWeightNoLep(chH_cand, default_chH_pw_inner),
2210                               dnn::pfCand_chHad_puppiWeightNoLep - PFchH_index_offset)
2211                    : sp.scale(candFunc::getPuppiWeightNoLep(chH_cand, default_chH_pw_outer),
2212                               dnn::pfCand_chHad_puppiWeightNoLep - PFchH_index_offset);
2213       get(dnn::pfCand_chHad_charge) = sp.scale(chH_cand.charge(), dnn::pfCand_chHad_charge - PFchH_index_offset);
2214       get(dnn::pfCand_chHad_lostInnerHits) =
2215           sp.scale<int>(candFunc::getLostInnerHits(chH_cand, 0), dnn::pfCand_chHad_lostInnerHits - PFchH_index_offset);
2216       get(dnn::pfCand_chHad_numberOfPixelHits) = sp.scale(candFunc::getNumberOfPixelHits(chH_cand, 0),
2217                                                           dnn::pfCand_chHad_numberOfPixelHits - PFchH_index_offset);
2218       get(dnn::pfCand_chHad_vertex_dx) =
2219           sp.scale(chH_cand.vertex().x() - pv.position().x(), dnn::pfCand_chHad_vertex_dx - PFchH_index_offset);
2220       get(dnn::pfCand_chHad_vertex_dy) =
2221           sp.scale(chH_cand.vertex().y() - pv.position().y(), dnn::pfCand_chHad_vertex_dy - PFchH_index_offset);
2222       get(dnn::pfCand_chHad_vertex_dz) =
2223           sp.scale(chH_cand.vertex().z() - pv.position().z(), dnn::pfCand_chHad_vertex_dz - PFchH_index_offset);
2224       get(dnn::pfCand_chHad_vertex_dx_tauFL) =
2225           sp.scale(chH_cand.vertex().x() - pv.position().x() - tau_funcs.getFlightLength(tau, tau_index).x(),
2226                    dnn::pfCand_chHad_vertex_dx_tauFL - PFchH_index_offset);
2227       get(dnn::pfCand_chHad_vertex_dy_tauFL) =
2228           sp.scale(chH_cand.vertex().y() - pv.position().y() - tau_funcs.getFlightLength(tau, tau_index).y(),
2229                    dnn::pfCand_chHad_vertex_dy_tauFL - PFchH_index_offset);
2230       get(dnn::pfCand_chHad_vertex_dz_tauFL) =
2231           sp.scale(chH_cand.vertex().z() - pv.position().z() - tau_funcs.getFlightLength(tau, tau_index).z(),
2232                    dnn::pfCand_chHad_vertex_dz_tauFL - PFchH_index_offset);
2233 
2234       const bool hasTrackDetails = candFunc::getHasTrackDetails(chH_cand);
2235       if (hasTrackDetails) {
2236         get(dnn::pfCand_chHad_hasTrackDetails) =
2237             sp.scale(hasTrackDetails, dnn::pfCand_chHad_hasTrackDetails - PFchH_index_offset);
2238         get(dnn::pfCand_chHad_dxy) =
2239             sp.scale(candFunc::getTauDxy(chH_cand), dnn::pfCand_chHad_dxy - PFchH_index_offset);
2240         get(dnn::pfCand_chHad_dxy_sig) = sp.scale(std::abs(candFunc::getTauDxy(chH_cand)) / chH_cand.dxyError(),
2241                                                   dnn::pfCand_chHad_dxy_sig - PFchH_index_offset);
2242         get(dnn::pfCand_chHad_dz) = sp.scale(candFunc::getTauDz(chH_cand), dnn::pfCand_chHad_dz - PFchH_index_offset);
2243         get(dnn::pfCand_chHad_dz_sig) = sp.scale(std::abs(candFunc::getTauDz(chH_cand)) / chH_cand.dzError(),
2244                                                  dnn::pfCand_chHad_dz_sig - PFchH_index_offset);
2245         get(dnn::pfCand_chHad_track_chi2_ndof) =
2246             candFunc::getPseudoTrack(chH_cand).ndof() > 0
2247                 ? sp.scale(candFunc::getPseudoTrack(chH_cand).chi2() / candFunc::getPseudoTrack(chH_cand).ndof(),
2248                            dnn::pfCand_chHad_track_chi2_ndof - PFchH_index_offset)
2249                 : 0;
2250         get(dnn::pfCand_chHad_track_ndof) =
2251             candFunc::getPseudoTrack(chH_cand).ndof() > 0
2252                 ? sp.scale(candFunc::getPseudoTrack(chH_cand).ndof(), dnn::pfCand_chHad_track_ndof - PFchH_index_offset)
2253                 : 0;
2254       }
2255       float hcal_fraction = candFunc::getHCalFraction(chH_cand, disable_hcalFraction_workaround_);
2256       get(dnn::pfCand_chHad_hcalFraction) =
2257           sp.scale(hcal_fraction, dnn::pfCand_chHad_hcalFraction - PFchH_index_offset);
2258       get(dnn::pfCand_chHad_rawCaloFraction) =
2259           sp.scale(candFunc::getRawCaloFraction(chH_cand), dnn::pfCand_chHad_rawCaloFraction - PFchH_index_offset);
2260     }
2261     if (valid_nH) {
2262       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_PFnH, is_inner));
2263       size_t index_nH = cell_map.at(CellObjectType::PfCand_neutralHadron);
2264       const auto& nH_cand = dynamic_cast<const CandidateCastType&>(pfCands.at(index_nH));
2265 
2266       get(dnn::pfCand_nHad_valid) = sp.scale(valid_nH, dnn::pfCand_nHad_valid - PFnH_index_offset);
2267       get(dnn::pfCand_nHad_rel_pt) =
2268           sp.scale(nH_cand.polarP4().pt() / tau.polarP4().pt(), dnn::pfCand_nHad_rel_pt - PFnH_index_offset);
2269       get(dnn::pfCand_nHad_deta) =
2270           sp.scale(nH_cand.polarP4().eta() - tau.polarP4().eta(), dnn::pfCand_nHad_deta - PFnH_index_offset);
2271       get(dnn::pfCand_nHad_dphi) =
2272           sp.scale(dPhi(tau.polarP4(), nH_cand.polarP4()), dnn::pfCand_nHad_dphi - PFnH_index_offset);
2273       get(dnn::pfCand_nHad_puppiWeight) = is_inner ? sp.scale(candFunc::getPuppiWeight(nH_cand, 0.9798355f),
2274                                                               dnn::pfCand_nHad_puppiWeight - PFnH_index_offset)
2275                                                    : sp.scale(candFunc::getPuppiWeight(nH_cand, 0.7813260f),
2276                                                               dnn::pfCand_nHad_puppiWeight - PFnH_index_offset);
2277       get(dnn::pfCand_nHad_puppiWeightNoLep) = is_inner
2278                                                    ? sp.scale(candFunc::getPuppiWeightNoLep(nH_cand, 0.9046796f),
2279                                                               dnn::pfCand_nHad_puppiWeightNoLep - PFnH_index_offset)
2280                                                    : sp.scale(candFunc::getPuppiWeightNoLep(nH_cand, 0.6554860f),
2281                                                               dnn::pfCand_nHad_puppiWeightNoLep - PFnH_index_offset);
2282       float hcal_fraction = candFunc::getHCalFraction(nH_cand, disable_hcalFraction_workaround_);
2283       get(dnn::pfCand_nHad_hcalFraction) = sp.scale(hcal_fraction, dnn::pfCand_nHad_hcalFraction - PFnH_index_offset);
2284     }
2285   }
2286 
2287   static void calculateElectronClusterVars(const pat::Electron* ele, float& elecEe, float& elecEgamma) {
2288     if (ele) {
2289       elecEe = elecEgamma = 0;
2290       auto superCluster = ele->superCluster();
2291       if (superCluster.isNonnull() && superCluster.isAvailable() && superCluster->clusters().isNonnull() &&
2292           superCluster->clusters().isAvailable()) {
2293         for (auto iter = superCluster->clustersBegin(); iter != superCluster->clustersEnd(); ++iter) {
2294           const double energy = (*iter)->energy();
2295           if (iter == superCluster->clustersBegin())
2296             elecEe += energy;
2297           else
2298             elecEgamma += energy;
2299         }
2300       }
2301     } else {
2302       elecEe = elecEgamma = default_value;
2303     }
2304   }
2305 
2306   template <typename CandidateCollection, typename TauCastType>
2307   static void processSignalPFComponents(const TauCastType& tau,
2308                                         const CandidateCollection& candidates,
2309                                         LorentzVectorXYZ& p4_inner,
2310                                         LorentzVectorXYZ& p4_outer,
2311                                         float& pt_inner,
2312                                         float& dEta_inner,
2313                                         float& dPhi_inner,
2314                                         float& m_inner,
2315                                         float& pt_outer,
2316                                         float& dEta_outer,
2317                                         float& dPhi_outer,
2318                                         float& m_outer,
2319                                         float& n_inner,
2320                                         float& n_outer) {
2321     p4_inner = LorentzVectorXYZ(0, 0, 0, 0);
2322     p4_outer = LorentzVectorXYZ(0, 0, 0, 0);
2323     n_inner = 0;
2324     n_outer = 0;
2325 
2326     const double innerSigCone_radius = getInnerSignalConeRadius(tau.pt());
2327     for (const auto& cand : candidates) {
2328       const double dR = reco::deltaR(cand->p4(), tau.leadChargedHadrCand()->p4());
2329       const bool isInside_innerSigCone = dR < innerSigCone_radius;
2330       if (isInside_innerSigCone) {
2331         p4_inner += cand->p4();
2332         ++n_inner;
2333       } else {
2334         p4_outer += cand->p4();
2335         ++n_outer;
2336       }
2337     }
2338 
2339     pt_inner = n_inner != 0 ? p4_inner.Pt() : default_value;
2340     dEta_inner = n_inner != 0 ? dEta(p4_inner, tau.p4()) : default_value;
2341     dPhi_inner = n_inner != 0 ? dPhi(p4_inner, tau.p4()) : default_value;
2342     m_inner = n_inner != 0 ? p4_inner.mass() : default_value;
2343 
2344     pt_outer = n_outer != 0 ? p4_outer.Pt() : default_value;
2345     dEta_outer = n_outer != 0 ? dEta(p4_outer, tau.p4()) : default_value;
2346     dPhi_outer = n_outer != 0 ? dPhi(p4_outer, tau.p4()) : default_value;
2347     m_outer = n_outer != 0 ? p4_outer.mass() : default_value;
2348   }
2349 
2350   template <typename CandidateCollection, typename TauCastType>
2351   static void processIsolationPFComponents(const TauCastType& tau,
2352                                            const CandidateCollection& candidates,
2353                                            LorentzVectorXYZ& p4,
2354                                            float& pt,
2355                                            float& d_eta,
2356                                            float& d_phi,
2357                                            float& m,
2358                                            float& n) {
2359     p4 = LorentzVectorXYZ(0, 0, 0, 0);
2360     n = 0;
2361 
2362     for (const auto& cand : candidates) {
2363       p4 += cand->p4();
2364       ++n;
2365     }
2366 
2367     pt = n != 0 ? p4.Pt() : default_value;
2368     d_eta = n != 0 ? dEta(p4, tau.p4()) : default_value;
2369     d_phi = n != 0 ? dPhi(p4, tau.p4()) : default_value;
2370     m = n != 0 ? p4.mass() : default_value;
2371   }
2372 
2373   static double getInnerSignalConeRadius(double pt) {
2374     static constexpr double min_pt = 30., min_radius = 0.05, cone_opening_coef = 3.;
2375     // This is equivalent of the original formula (std::max(std::min(0.1, 3.0/pt), 0.05)
2376     return std::max(cone_opening_coef / std::max(pt, min_pt), min_radius);
2377   }
2378 
2379   // Copied from https://github.com/cms-sw/cmssw/blob/CMSSW_9_4_X/RecoTauTag/RecoTau/plugins/PATTauDiscriminationByMVAIsolationRun2.cc#L218
2380   template <typename TauCastType>
2381   static bool calculateGottfriedJacksonAngleDifference(const TauCastType& tau,
2382                                                        const size_t tau_index,
2383                                                        double& gj_diff,
2384                                                        TauFunc tau_funcs) {
2385     if (tau_funcs.getHasSecondaryVertex(tau, tau_index)) {
2386       static constexpr double mTau = 1.77682;
2387       const double mAOne = tau.p4().M();
2388       const double pAOneMag = tau.p();
2389       const double argumentThetaGJmax = (std::pow(mTau, 2) - std::pow(mAOne, 2)) / (2 * mTau * pAOneMag);
2390       const double argumentThetaGJmeasured = tau.p4().Vect().Dot(tau_funcs.getFlightLength(tau, tau_index)) /
2391                                              (pAOneMag * tau_funcs.getFlightLength(tau, tau_index).R());
2392       if (std::abs(argumentThetaGJmax) <= 1. && std::abs(argumentThetaGJmeasured) <= 1.) {
2393         double thetaGJmax = std::asin(argumentThetaGJmax);
2394         double thetaGJmeasured = std::acos(argumentThetaGJmeasured);
2395         gj_diff = thetaGJmeasured - thetaGJmax;
2396         return true;
2397       }
2398     }
2399     return false;
2400   }
2401 
2402   template <typename TauCastType>
2403   static float calculateGottfriedJacksonAngleDifference(const TauCastType& tau,
2404                                                         const size_t tau_index,
2405                                                         TauFunc tau_funcs) {
2406     double gj_diff;
2407     if (calculateGottfriedJacksonAngleDifference(tau, tau_index, gj_diff, tau_funcs))
2408       return static_cast<float>(gj_diff);
2409     return default_value;
2410   }
2411 
2412   static bool isInEcalCrack(double eta) {
2413     const double abs_eta = std::abs(eta);
2414     return abs_eta > 1.46 && abs_eta < 1.558;
2415   }
2416 
2417   template <typename TauCastType>
2418   static const pat::Electron* findMatchedElectron(const TauCastType& tau,
2419                                                   const std::vector<pat::Electron>* electrons,
2420                                                   double deltaR) {
2421     const double dR2 = deltaR * deltaR;
2422     const pat::Electron* matched_ele = nullptr;
2423     for (const auto& ele : *electrons) {
2424       if (reco::deltaR2(tau.p4(), ele.p4()) < dR2 && (!matched_ele || matched_ele->pt() < ele.pt())) {
2425         matched_ele = &ele;
2426       }
2427     }
2428     return matched_ele;
2429   }
2430 
2431 private:
2432   edm::EDGetTokenT<std::vector<pat::Electron>> electrons_token_;
2433   edm::EDGetTokenT<std::vector<pat::Muon>> muons_token_;
2434   edm::EDGetTokenT<double> rho_token_;
2435   edm::EDGetTokenT<reco::TauDiscriminatorContainer> basicTauDiscriminators_inputToken_;
2436   edm::EDGetTokenT<reco::TauDiscriminatorContainer> basicTauDiscriminatorsdR03_inputToken_;
2437   edm::EDGetTokenT<edm::AssociationVector<reco::PFTauRefProd, std::vector<reco::PFTauTransverseImpactParameterRef>>>
2438       pfTauTransverseImpactParameters_token_;
2439   std::string input_layer_, output_layer_;
2440   const unsigned year_;
2441   const unsigned version_;
2442   const unsigned sub_version_;
2443   const int debug_level;
2444   const bool disable_dxy_pca_;
2445   const bool disable_hcalFraction_workaround_;
2446   const bool disable_CellIndex_workaround_;
2447   std::unique_ptr<tensorflow::Tensor> tauBlockTensor_;
2448   std::array<std::unique_ptr<tensorflow::Tensor>, 2> eGammaTensor_, muonTensor_, hadronsTensor_, convTensor_,
2449       zeroOutputTensor_;
2450   const std::map<std::pair<deep_tau::Scaling::FeatureT, bool>, deep_tau::Scaling::ScalingParams>* scalingParamsMap_;
2451   const bool save_inputs_;
2452   std::ofstream* json_file_;
2453   bool is_first_block_;
2454   int file_counter_;
2455   std::vector<int> tauInputs_indices_;
2456 
2457   //boolean to check if discriminator indices are already mapped
2458   bool discrIndicesMapped_ = false;
2459   std::map<BasicDiscriminator, size_t> basicDiscrIndexMap_;
2460   std::map<BasicDiscriminator, size_t> basicDiscrdR03IndexMap_;
2461 };
2462 
2463 #include "FWCore/Framework/interface/MakerMacros.h"
2464 DEFINE_FWK_MODULE(DeepTauId);