Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2022-05-31 22:26:06

0001 /*
0002  * \class DeepTauId
0003  *
0004  * Tau identification using Deep NN.
0005  *
0006  * \author Konstantin Androsov, INFN Pisa
0007  *         Christian Veelken, Tallinn
0008  */
0009 
0010 #include "RecoTauTag/RecoTau/interface/DeepTauBase.h"
0011 #include "RecoTauTag/RecoTau/interface/DeepTauScaling.h"
0012 #include "FWCore/Utilities/interface/isFinite.h"
0013 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0014 #include "DataFormats/TauReco/interface/PFTauTransverseImpactParameterAssociation.h"
0015 
0016 #include <fstream>
0017 #include "oneapi/tbb/concurrent_unordered_set.h"
0018 
namespace deep_tau {
  // Number of output nodes of the DeepTau network.
  // NOTE(review): presumably one score per output class — confirm against the model.
  constexpr int NumberOfOutputs{4};
}  // namespace deep_tau
0022 
0023 namespace {
0024 
  // Input-feature layouts of the v2 DeepTau network.
  // Each *BlockInputs namespace holds one unscoped enum. Enumerators take
  // consecutive values starting at 0, so an enumerator's value is the
  // feature's position inside its block. The final NumberOfInputs enumerator
  // equals the number of features in that block.
  // NOTE(review): the order presumably has to match the one used when the
  // network was trained — do not reorder or insert entries without confirming.
  namespace dnn_inputs_v2 {
    // NOTE(review): presumably the number of cells per side of the inner and
    // outer cell grids, and the number of features per cell produced by the
    // convolutional layers — confirm against the model definition.
    constexpr int number_of_inner_cell = 11;
    constexpr int number_of_outer_cell = 21;
    constexpr int number_of_conv_features = 64;

    // Tau-level (global) features.
    namespace TauBlockInputs {
      enum vars {
        rho = 0,
        tau_pt,
        tau_eta,
        tau_phi,
        tau_mass,
        tau_E_over_pt,
        tau_charge,
        tau_n_charged_prongs,
        tau_n_neutral_prongs,
        chargedIsoPtSum,
        chargedIsoPtSumdR03_over_dR05,
        footprintCorrection,
        neutralIsoPtSum,
        neutralIsoPtSumWeight_over_neutralIsoPtSum,
        neutralIsoPtSumWeightdR03_over_neutralIsoPtSum,
        neutralIsoPtSumdR03_over_dR05,
        photonPtSumOutsideSignalCone,
        puCorrPtSum,
        tau_dxy_pca_x,
        tau_dxy_pca_y,
        tau_dxy_pca_z,
        tau_dxy_valid,
        tau_dxy,
        tau_dxy_sig,
        tau_ip3d_valid,
        tau_ip3d,
        tau_ip3d_sig,
        tau_dz,
        tau_dz_sig_valid,
        tau_dz_sig,
        tau_flightLength_x,
        tau_flightLength_y,
        tau_flightLength_z,
        tau_flightLength_sig,
        tau_pt_weighted_deta_strip,
        tau_pt_weighted_dphi_strip,
        tau_pt_weighted_dr_signal,
        tau_pt_weighted_dr_iso,
        tau_leadingTrackNormChi2,
        tau_e_ratio_valid,
        tau_e_ratio,
        tau_gj_angle_diff_valid,
        tau_gj_angle_diff,
        tau_n_photons,
        tau_emFraction,
        tau_inside_ecal_crack,
        leadChargedCand_etaAtEcalEntrance_minus_tau_eta,
        NumberOfInputs
      };
      // Features of the full enum above that are to be dropped from the
      // network input; the values are indices into the full var enum.
      // NOTE(review): declared non-const; check whether other code modifies
      // it (e.g. sorts it) before making it const.
      std::vector<int> varsToDrop = {
          tau_phi, tau_dxy_pca_x, tau_dxy_pca_y, tau_dxy_pca_z};  // indices of vars to be dropped in the full var enum
    }  // namespace TauBlockInputs

    // Common tau-level inputs, followed by PF-electron (pfCand_ele_*),
    // reconstructed-electron (ele_*) and PF-photon (pfCand_gamma_*) features.
    namespace EgammaBlockInputs {
      enum vars {
        rho = 0,
        tau_pt,
        tau_eta,
        tau_inside_ecal_crack,
        pfCand_ele_valid,
        pfCand_ele_rel_pt,
        pfCand_ele_deta,
        pfCand_ele_dphi,
        pfCand_ele_pvAssociationQuality,
        pfCand_ele_puppiWeight,
        pfCand_ele_charge,
        pfCand_ele_lostInnerHits,
        pfCand_ele_numberOfPixelHits,
        pfCand_ele_vertex_dx,
        pfCand_ele_vertex_dy,
        pfCand_ele_vertex_dz,
        pfCand_ele_vertex_dx_tauFL,
        pfCand_ele_vertex_dy_tauFL,
        pfCand_ele_vertex_dz_tauFL,
        pfCand_ele_hasTrackDetails,
        pfCand_ele_dxy,
        pfCand_ele_dxy_sig,
        pfCand_ele_dz,
        pfCand_ele_dz_sig,
        pfCand_ele_track_chi2_ndof,
        pfCand_ele_track_ndof,
        ele_valid,
        ele_rel_pt,
        ele_deta,
        ele_dphi,
        ele_cc_valid,
        ele_cc_ele_rel_energy,
        ele_cc_gamma_rel_energy,
        ele_cc_n_gamma,
        ele_rel_trackMomentumAtVtx,
        ele_rel_trackMomentumAtCalo,
        ele_rel_trackMomentumOut,
        ele_rel_trackMomentumAtEleClus,
        ele_rel_trackMomentumAtVtxWithConstraint,
        ele_rel_ecalEnergy,
        ele_ecalEnergy_sig,
        ele_eSuperClusterOverP,
        ele_eSeedClusterOverP,
        ele_eSeedClusterOverPout,
        ele_eEleClusterOverPout,
        ele_deltaEtaSuperClusterTrackAtVtx,
        ele_deltaEtaSeedClusterTrackAtCalo,
        ele_deltaEtaEleClusterTrackAtCalo,
        ele_deltaPhiEleClusterTrackAtCalo,
        ele_deltaPhiSuperClusterTrackAtVtx,
        ele_deltaPhiSeedClusterTrackAtCalo,
        ele_mvaInput_earlyBrem,
        ele_mvaInput_lateBrem,
        ele_mvaInput_sigmaEtaEta,
        ele_mvaInput_hadEnergy,
        ele_mvaInput_deltaEta,
        ele_gsfTrack_normalizedChi2,
        ele_gsfTrack_numberOfValidHits,
        ele_rel_gsfTrack_pt,
        ele_gsfTrack_pt_sig,
        ele_has_closestCtfTrack,
        ele_closestCtfTrack_normalizedChi2,
        ele_closestCtfTrack_numberOfValidHits,
        pfCand_gamma_valid,
        pfCand_gamma_rel_pt,
        pfCand_gamma_deta,
        pfCand_gamma_dphi,
        pfCand_gamma_pvAssociationQuality,
        pfCand_gamma_fromPV,
        pfCand_gamma_puppiWeight,
        pfCand_gamma_puppiWeightNoLep,
        pfCand_gamma_lostInnerHits,
        pfCand_gamma_numberOfPixelHits,
        pfCand_gamma_vertex_dx,
        pfCand_gamma_vertex_dy,
        pfCand_gamma_vertex_dz,
        pfCand_gamma_vertex_dx_tauFL,
        pfCand_gamma_vertex_dy_tauFL,
        pfCand_gamma_vertex_dz_tauFL,
        pfCand_gamma_hasTrackDetails,
        pfCand_gamma_dxy,
        pfCand_gamma_dxy_sig,
        pfCand_gamma_dz,
        pfCand_gamma_dz_sig,
        pfCand_gamma_track_chi2_ndof,
        pfCand_gamma_track_ndof,
        NumberOfInputs
      };
    }  // namespace EgammaBlockInputs

    // Common tau-level inputs, followed by PF-muon (pfCand_muon_*) and
    // reconstructed-muon (muon_*) features. The reconstructed-muon features
    // include per-station match and hit counts for DT, CSC and RPC stations 1-4.
    namespace MuonBlockInputs {
      enum vars {
        rho = 0,
        tau_pt,
        tau_eta,
        tau_inside_ecal_crack,
        pfCand_muon_valid,
        pfCand_muon_rel_pt,
        pfCand_muon_deta,
        pfCand_muon_dphi,
        pfCand_muon_pvAssociationQuality,
        pfCand_muon_fromPV,
        pfCand_muon_puppiWeight,
        pfCand_muon_charge,
        pfCand_muon_lostInnerHits,
        pfCand_muon_numberOfPixelHits,
        pfCand_muon_vertex_dx,
        pfCand_muon_vertex_dy,
        pfCand_muon_vertex_dz,
        pfCand_muon_vertex_dx_tauFL,
        pfCand_muon_vertex_dy_tauFL,
        pfCand_muon_vertex_dz_tauFL,
        pfCand_muon_hasTrackDetails,
        pfCand_muon_dxy,
        pfCand_muon_dxy_sig,
        pfCand_muon_dz,
        pfCand_muon_dz_sig,
        pfCand_muon_track_chi2_ndof,
        pfCand_muon_track_ndof,
        muon_valid,
        muon_rel_pt,
        muon_deta,
        muon_dphi,
        muon_dxy,
        muon_dxy_sig,
        muon_normalizedChi2_valid,
        muon_normalizedChi2,
        muon_numberOfValidHits,
        muon_segmentCompatibility,
        muon_caloCompatibility,
        muon_pfEcalEnergy_valid,
        muon_rel_pfEcalEnergy,
        muon_n_matches_DT_1,
        muon_n_matches_DT_2,
        muon_n_matches_DT_3,
        muon_n_matches_DT_4,
        muon_n_matches_CSC_1,
        muon_n_matches_CSC_2,
        muon_n_matches_CSC_3,
        muon_n_matches_CSC_4,
        muon_n_matches_RPC_1,
        muon_n_matches_RPC_2,
        muon_n_matches_RPC_3,
        muon_n_matches_RPC_4,
        muon_n_hits_DT_1,
        muon_n_hits_DT_2,
        muon_n_hits_DT_3,
        muon_n_hits_DT_4,
        muon_n_hits_CSC_1,
        muon_n_hits_CSC_2,
        muon_n_hits_CSC_3,
        muon_n_hits_CSC_4,
        muon_n_hits_RPC_1,
        muon_n_hits_RPC_2,
        muon_n_hits_RPC_3,
        muon_n_hits_RPC_4,
        NumberOfInputs
      };
    }  // namespace MuonBlockInputs

    // Common tau-level inputs, followed by PF charged-hadron (pfCand_chHad_*)
    // and PF neutral-hadron (pfCand_nHad_*) features.
    namespace HadronBlockInputs {
      enum vars {
        rho = 0,
        tau_pt,
        tau_eta,
        tau_inside_ecal_crack,
        pfCand_chHad_valid,
        pfCand_chHad_rel_pt,
        pfCand_chHad_deta,
        pfCand_chHad_dphi,
        pfCand_chHad_leadChargedHadrCand,
        pfCand_chHad_pvAssociationQuality,
        pfCand_chHad_fromPV,
        pfCand_chHad_puppiWeight,
        pfCand_chHad_puppiWeightNoLep,
        pfCand_chHad_charge,
        pfCand_chHad_lostInnerHits,
        pfCand_chHad_numberOfPixelHits,
        pfCand_chHad_vertex_dx,
        pfCand_chHad_vertex_dy,
        pfCand_chHad_vertex_dz,
        pfCand_chHad_vertex_dx_tauFL,
        pfCand_chHad_vertex_dy_tauFL,
        pfCand_chHad_vertex_dz_tauFL,
        pfCand_chHad_hasTrackDetails,
        pfCand_chHad_dxy,
        pfCand_chHad_dxy_sig,
        pfCand_chHad_dz,
        pfCand_chHad_dz_sig,
        pfCand_chHad_track_chi2_ndof,
        pfCand_chHad_track_ndof,
        pfCand_chHad_hcalFraction,
        pfCand_chHad_rawCaloFraction,
        pfCand_nHad_valid,
        pfCand_nHad_rel_pt,
        pfCand_nHad_deta,
        pfCand_nHad_dphi,
        pfCand_nHad_puppiWeight,
        pfCand_nHad_puppiWeightNoLep,
        pfCand_nHad_hcalFraction,
        NumberOfInputs
      };
    }  // namespace HadronBlockInputs
  }  // namespace dnn_inputs_v2
0290 
0291   float getTauID(const pat::Tau& tau, const std::string& tauID, float default_value = -999., bool assert_input = true) {
0292     static tbb::concurrent_unordered_set<std::string> isFirstWarning;
0293     if (tau.isTauIDAvailable(tauID)) {
0294       return tau.tauID(tauID);
0295     } else {
0296       if (assert_input) {
0297         throw cms::Exception("DeepTauId")
0298             << "Exception in <getTauID>: No tauID '" << tauID << "' available in pat::Tau given as function argument.";
0299       }
0300       if (isFirstWarning.insert(tauID).second) {
0301         edm::LogWarning("DeepTauID") << "Warning in <getTauID>: No tauID '" << tauID
0302                                      << "' available in pat::Tau given as function argument."
0303                                      << " Using default_value = " << default_value << " instead." << std::endl;
0304       }
0305       return default_value;
0306     }
0307   }
0308 
0309   struct TauFunc {
0310     const reco::TauDiscriminatorContainer* basicTauDiscriminatorCollection;
0311     const reco::TauDiscriminatorContainer* basicTauDiscriminatordR03Collection;
0312     const edm::AssociationVector<reco::PFTauRefProd, std::vector<reco::PFTauTransverseImpactParameterRef>>*
0313         pfTauTransverseImpactParameters;
0314 
0315     using BasicDiscr = deep_tau::DeepTauBase::BasicDiscriminator;
0316     std::map<BasicDiscr, size_t> indexMap;
0317     std::map<BasicDiscr, size_t> indexMapdR03;
0318 
0319     const float getChargedIsoPtSum(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0320       return (*basicTauDiscriminatorCollection)[tau_ref].rawValues.at(indexMap.at(BasicDiscr::ChargedIsoPtSum));
0321     }
0322     const float getChargedIsoPtSum(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0323       return getTauID(tau, "chargedIsoPtSum");
0324     }
0325     const float getChargedIsoPtSumdR03(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0326       return (*basicTauDiscriminatordR03Collection)[tau_ref].rawValues.at(indexMapdR03.at(BasicDiscr::ChargedIsoPtSum));
0327     }
0328     const float getChargedIsoPtSumdR03(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0329       return getTauID(tau, "chargedIsoPtSumdR03");
0330     }
0331     const float getFootprintCorrectiondR03(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0332       return (*basicTauDiscriminatordR03Collection)[tau_ref].rawValues.at(
0333           indexMapdR03.at(BasicDiscr::FootprintCorrection));
0334     }
0335     const float getFootprintCorrectiondR03(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0336       return getTauID(tau, "footprintCorrectiondR03");
0337     }
0338     const float getFootprintCorrection(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0339       return getTauID(tau, "footprintCorrection");
0340     }
0341     const float getNeutralIsoPtSum(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0342       return (*basicTauDiscriminatorCollection)[tau_ref].rawValues.at(indexMap.at(BasicDiscr::NeutralIsoPtSum));
0343     }
0344     const float getNeutralIsoPtSum(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0345       return getTauID(tau, "neutralIsoPtSum");
0346     }
0347     const float getNeutralIsoPtSumdR03(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0348       return (*basicTauDiscriminatordR03Collection)[tau_ref].rawValues.at(indexMapdR03.at(BasicDiscr::NeutralIsoPtSum));
0349     }
0350     const float getNeutralIsoPtSumdR03(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0351       return getTauID(tau, "neutralIsoPtSumdR03");
0352     }
0353     const float getNeutralIsoPtSumWeight(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0354       return (*basicTauDiscriminatorCollection)[tau_ref].rawValues.at(indexMap.at(BasicDiscr::NeutralIsoPtSumWeight));
0355     }
0356     const float getNeutralIsoPtSumWeight(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0357       return getTauID(tau, "neutralIsoPtSumWeight");
0358     }
0359     const float getNeutralIsoPtSumdR03Weight(const reco::PFTau& tau,
0360                                              const edm::RefToBase<reco::BaseTau> tau_ref) const {
0361       return (*basicTauDiscriminatordR03Collection)[tau_ref].rawValues.at(
0362           indexMapdR03.at(BasicDiscr::NeutralIsoPtSumWeight));
0363     }
0364     const float getNeutralIsoPtSumdR03Weight(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0365       return getTauID(tau, "neutralIsoPtSumWeightdR03");
0366     }
0367     const float getPhotonPtSumOutsideSignalCone(const reco::PFTau& tau,
0368                                                 const edm::RefToBase<reco::BaseTau> tau_ref) const {
0369       return (*basicTauDiscriminatorCollection)[tau_ref].rawValues.at(
0370           indexMap.at(BasicDiscr::PhotonPtSumOutsideSignalCone));
0371     }
0372     const float getPhotonPtSumOutsideSignalCone(const pat::Tau& tau,
0373                                                 const edm::RefToBase<reco::BaseTau> tau_ref) const {
0374       return getTauID(tau, "photonPtSumOutsideSignalCone");
0375     }
0376     const float getPhotonPtSumOutsideSignalConedR03(const reco::PFTau& tau,
0377                                                     const edm::RefToBase<reco::BaseTau> tau_ref) const {
0378       return (*basicTauDiscriminatordR03Collection)[tau_ref].rawValues.at(
0379           indexMapdR03.at(BasicDiscr::PhotonPtSumOutsideSignalCone));
0380     }
0381     const float getPhotonPtSumOutsideSignalConedR03(const pat::Tau& tau,
0382                                                     const edm::RefToBase<reco::BaseTau> tau_ref) const {
0383       return getTauID(tau, "photonPtSumOutsideSignalConedR03");
0384     }
0385     const float getPuCorrPtSum(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0386       return (*basicTauDiscriminatorCollection)[tau_ref].rawValues.at(indexMap.at(BasicDiscr::PUcorrPtSum));
0387     }
0388     const float getPuCorrPtSum(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0389       return getTauID(tau, "puCorrPtSum");
0390     }
0391 
0392     auto getdxyPCA(const reco::PFTau& tau, const size_t tau_index) const {
0393       return pfTauTransverseImpactParameters->value(tau_index)->dxy_PCA();
0394     }
0395     auto getdxyPCA(const pat::Tau& tau, const size_t tau_index) const { return tau.dxy_PCA(); }
0396     auto getdxy(const reco::PFTau& tau, const size_t tau_index) const {
0397       return pfTauTransverseImpactParameters->value(tau_index)->dxy();
0398     }
0399     auto getdxy(const pat::Tau& tau, const size_t tau_index) const { return tau.dxy(); }
0400     auto getdxyError(const reco::PFTau& tau, const size_t tau_index) const {
0401       return pfTauTransverseImpactParameters->value(tau_index)->dxy_error();
0402     }
0403     auto getdxyError(const pat::Tau& tau, const size_t tau_index) const { return tau.dxy_error(); }
0404     auto getdxySig(const reco::PFTau& tau, const size_t tau_index) const {
0405       return pfTauTransverseImpactParameters->value(tau_index)->dxy_Sig();
0406     }
0407     auto getdxySig(const pat::Tau& tau, const size_t tau_index) const { return tau.dxy_Sig(); }
0408     auto getip3d(const reco::PFTau& tau, const size_t tau_index) const {
0409       return pfTauTransverseImpactParameters->value(tau_index)->ip3d();
0410     }
0411     auto getip3d(const pat::Tau& tau, const size_t tau_index) const { return tau.ip3d(); }
0412     auto getip3dError(const reco::PFTau& tau, const size_t tau_index) const {
0413       return pfTauTransverseImpactParameters->value(tau_index)->ip3d_error();
0414     }
0415     auto getip3dError(const pat::Tau& tau, const size_t tau_index) const { return tau.ip3d_error(); }
0416     auto getip3dSig(const reco::PFTau& tau, const size_t tau_index) const {
0417       return pfTauTransverseImpactParameters->value(tau_index)->ip3d_Sig();
0418     }
0419     auto getip3dSig(const pat::Tau& tau, const size_t tau_index) const { return tau.ip3d_Sig(); }
0420     auto getHasSecondaryVertex(const reco::PFTau& tau, const size_t tau_index) const {
0421       return pfTauTransverseImpactParameters->value(tau_index)->hasSecondaryVertex();
0422     }
0423     auto getHasSecondaryVertex(const pat::Tau& tau, const size_t tau_index) const { return tau.hasSecondaryVertex(); }
0424     auto getFlightLength(const reco::PFTau& tau, const size_t tau_index) const {
0425       return pfTauTransverseImpactParameters->value(tau_index)->flightLength();
0426     }
0427     auto getFlightLength(const pat::Tau& tau, const size_t tau_index) const { return tau.flightLength(); }
0428     auto getFlightLengthSig(const reco::PFTau& tau, const size_t tau_index) const {
0429       return pfTauTransverseImpactParameters->value(tau_index)->flightLengthSig();
0430     }
0431     auto getFlightLengthSig(const pat::Tau& tau, const size_t tau_index) const { return tau.flightLengthSig(); }
0432 
0433     auto getLeadingTrackNormChi2(const reco::PFTau& tau) { return reco::tau::lead_track_chi2(tau); }
0434     auto getLeadingTrackNormChi2(const pat::Tau& tau) { return tau.leadingTrackNormChi2(); }
0435     auto getEmFraction(const pat::Tau& tau) { return tau.emFraction_MVA(); }
0436     auto getEmFraction(const reco::PFTau& tau) { return tau.emFraction(); }
0437     auto getEtaAtEcalEntrance(const pat::Tau& tau) { return tau.etaAtEcalEntranceLeadChargedCand(); }
0438     auto getEtaAtEcalEntrance(const reco::PFTau& tau) {
0439       return tau.leadPFChargedHadrCand()->positionAtECALEntrance().eta();
0440     }
0441     auto getEcalEnergyLeadingChargedHadr(const reco::PFTau& tau) { return tau.leadPFChargedHadrCand()->ecalEnergy(); }
0442     auto getEcalEnergyLeadingChargedHadr(const pat::Tau& tau) { return tau.ecalEnergyLeadChargedHadrCand(); }
0443     auto getHcalEnergyLeadingChargedHadr(const reco::PFTau& tau) { return tau.leadPFChargedHadrCand()->hcalEnergy(); }
0444     auto getHcalEnergyLeadingChargedHadr(const pat::Tau& tau) { return tau.hcalEnergyLeadChargedHadrCand(); }
0445 
0446     template <typename PreDiscrType>
0447     bool passPrediscriminants(const PreDiscrType prediscriminants,
0448                               const size_t andPrediscriminants,
0449                               const edm::RefToBase<reco::BaseTau> tau_ref) {
0450       bool passesPrediscriminants = (andPrediscriminants ? 1 : 0);
0451       // check tau passes prediscriminants
0452       size_t nPrediscriminants = prediscriminants.size();
0453       for (size_t iDisc = 0; iDisc < nPrediscriminants; ++iDisc) {
0454         // current discriminant result for this tau
0455         double discResult = (*prediscriminants[iDisc].handle)[tau_ref];
0456         uint8_t thisPasses = (discResult > prediscriminants[iDisc].cut) ? 1 : 0;
0457 
0458         // if we are using the AND option, as soon as one fails,
0459         // the result is FAIL and we can quit looping.
0460         // if we are using the OR option as soon as one passes,
0461         // the result is pass and we can quit looping
0462 
0463         // truth table
0464         //        |   result (thisPasses)
0465         //        |     F     |     T
0466         //-----------------------------------
0467         // AND(T) | res=fails |  continue
0468         //        |  break    |
0469         //-----------------------------------
0470         // OR (F) |  continue | res=passes
0471         //        |           |  break
0472 
0473         if (thisPasses ^ andPrediscriminants)  //XOR
0474         {
0475           passesPrediscriminants = (andPrediscriminants ? 0 : 1);  //NOR
0476           break;
0477         }
0478       }
0479       return passesPrediscriminants;
0480     }
0481   };
0482 
0483   namespace candFunc {
0484     auto getTauDz(const reco::PFCandidate& cand) { return cand.bestTrack()->dz(); }
0485     auto getTauDz(const pat::PackedCandidate& cand) { return cand.dz(); }
0486     auto getTauDZSigValid(const reco::PFCandidate& cand) {
0487       return cand.bestTrack() != nullptr && std::isnormal(cand.bestTrack()->dz()) && std::isnormal(cand.dzError()) &&
0488              cand.dzError() > 0;
0489     }
0490     auto getTauDZSigValid(const pat::PackedCandidate& cand) {
0491       return cand.hasTrackDetails() && std::isnormal(cand.dz()) && std::isnormal(cand.dzError()) && cand.dzError() > 0;
0492     }
0493     auto getTauDxy(const reco::PFCandidate& cand) { return cand.bestTrack()->dxy(); }
0494     auto getTauDxy(const pat::PackedCandidate& cand) { return cand.dxy(); }
0495     auto getPvAssocationQuality(const reco::PFCandidate& cand) { return 0.7013f; }
0496     auto getPvAssocationQuality(const pat::PackedCandidate& cand) { return cand.pvAssociationQuality(); }
0497     auto getPuppiWeight(const reco::PFCandidate& cand, const float aod_value) { return aod_value; }
0498     auto getPuppiWeight(const pat::PackedCandidate& cand, const float aod_value) { return cand.puppiWeight(); }
0499     auto getPuppiWeightNoLep(const reco::PFCandidate& cand, const float aod_value) { return aod_value; }
0500     auto getPuppiWeightNoLep(const pat::PackedCandidate& cand, const float aod_value) {
0501       return cand.puppiWeightNoLep();
0502     }
0503     auto getLostInnerHits(const reco::PFCandidate& cand, float default_value) {
0504       return cand.bestTrack() != nullptr
0505                  ? cand.bestTrack()->hitPattern().numberOfLostHits(reco::HitPattern::MISSING_INNER_HITS)
0506                  : default_value;
0507     }
0508     auto getLostInnerHits(const pat::PackedCandidate& cand, float default_value) { return cand.lostInnerHits(); }
0509     auto getNumberOfPixelHits(const reco::PFCandidate& cand, float default_value) {
0510       return cand.bestTrack() != nullptr
0511                  ? cand.bestTrack()->hitPattern().numberOfLostHits(reco::HitPattern::MISSING_INNER_HITS)
0512                  : default_value;
0513     }
0514     auto getNumberOfPixelHits(const pat::PackedCandidate& cand, float default_value) {
0515       return cand.numberOfPixelHits();
0516     }
0517     auto getHasTrackDetails(const reco::PFCandidate& cand) { return cand.bestTrack() != nullptr; }
0518     auto getHasTrackDetails(const pat::PackedCandidate& cand) { return cand.hasTrackDetails(); }
0519     auto getPseudoTrack(const reco::PFCandidate& cand) { return *cand.bestTrack(); }
0520     auto getPseudoTrack(const pat::PackedCandidate& cand) { return cand.pseudoTrack(); }
0521     auto getFromPV(const reco::PFCandidate& cand) { return 0.9994f; }
0522     auto getFromPV(const pat::PackedCandidate& cand) { return cand.fromPV(); }
0523     auto getHCalFraction(const reco::PFCandidate& cand, bool disable_hcalFraction_workaround) {
0524       return cand.rawHcalEnergy() / (cand.rawHcalEnergy() + cand.rawEcalEnergy());
0525     }
0526     auto getHCalFraction(const pat::PackedCandidate& cand, bool disable_hcalFraction_workaround) {
0527       float hcal_fraction = 0.;
0528       if (disable_hcalFraction_workaround) {
0529         // CV: use consistent definition for pfCand_chHad_hcalFraction
0530         //     in DeepTauId.cc code and in TauMLTools/Production/plugins/TauTupleProducer.cc
0531         hcal_fraction = cand.hcalFraction();
0532       } else {
0533         // CV: backwards compatibility with DeepTau training v2p1 used during Run 2
0534         if (cand.pdgId() == 1 || cand.pdgId() == 130) {
0535           hcal_fraction = cand.hcalFraction();
0536         } else if (cand.isIsolatedChargedHadron()) {
0537           hcal_fraction = cand.rawHcalFraction();
0538         }
0539       }
0540       return hcal_fraction;
0541     }
0542     auto getRawCaloFraction(const reco::PFCandidate& cand) {
0543       return (cand.rawEcalEnergy() + cand.rawHcalEnergy()) / cand.energy();
0544     }
0545     auto getRawCaloFraction(const pat::PackedCandidate& cand) { return cand.rawCaloFraction(); }
0546   };  // namespace candFunc
0547 
0548   template <typename LVector1, typename LVector2>
0549   float dEta(const LVector1& p4, const LVector2& tau_p4) {
0550     return static_cast<float>(p4.eta() - tau_p4.eta());
0551   }
0552 
0553   template <typename LVector1, typename LVector2>
0554   float dPhi(const LVector1& p4_1, const LVector2& p4_2) {
0555     return static_cast<float>(reco::deltaPhi(p4_2.phi(), p4_1.phi()));
0556   }
0557 
0558   struct MuonHitMatchV2 {
0559     static constexpr size_t n_muon_stations = 4;
0560     static constexpr int first_station_id = 1;
0561     static constexpr int last_station_id = first_station_id + n_muon_stations - 1;
0562     using CountArray = std::array<unsigned, n_muon_stations>;
0563     using CountMap = std::map<int, CountArray>;
0564 
0565     const std::vector<int>& consideredSubdets() {
0566       static const std::vector<int> subdets = {MuonSubdetId::DT, MuonSubdetId::CSC, MuonSubdetId::RPC};
0567       return subdets;
0568     }
0569 
0570     const std::string& subdetName(int subdet) {
0571       static const std::map<int, std::string> subdet_names = {
0572           {MuonSubdetId::DT, "DT"}, {MuonSubdetId::CSC, "CSC"}, {MuonSubdetId::RPC, "RPC"}};
0573       if (!subdet_names.count(subdet))
0574         throw cms::Exception("MuonHitMatch") << "Subdet name for subdet id " << subdet << " not found.";
0575       return subdet_names.at(subdet);
0576     }
0577 
0578     size_t getStationIndex(int station, bool throw_exception) const {
0579       if (station < first_station_id || station > last_station_id) {
0580         if (throw_exception)
0581           throw cms::Exception("MuonHitMatch") << "Station id is out of range";
0582         return std::numeric_limits<size_t>::max();
0583       }
0584       return static_cast<size_t>(station - 1);
0585     }
0586 
0587     MuonHitMatchV2(const pat::Muon& muon) {
0588       for (int subdet : consideredSubdets()) {
0589         n_matches[subdet].fill(0);
0590         n_hits[subdet].fill(0);
0591       }
0592 
0593       countMatches(muon, n_matches);
0594       countHits(muon, n_hits);
0595     }
0596 
0597     void countMatches(const pat::Muon& muon, CountMap& n_matches) {
0598       for (const auto& segment : muon.matches()) {
0599         if (segment.segmentMatches.empty() && segment.rpcMatches.empty())
0600           continue;
0601         if (n_matches.count(segment.detector())) {
0602           const size_t station_index = getStationIndex(segment.station(), true);
0603           ++n_matches.at(segment.detector()).at(station_index);
0604         }
0605       }
0606     }
0607 
0608     void countHits(const pat::Muon& muon, CountMap& n_hits) {
0609       if (muon.outerTrack().isNonnull()) {
0610         const auto& hit_pattern = muon.outerTrack()->hitPattern();
0611         for (int hit_index = 0; hit_index < hit_pattern.numberOfAllHits(reco::HitPattern::TRACK_HITS); ++hit_index) {
0612           auto hit_id = hit_pattern.getHitPattern(reco::HitPattern::TRACK_HITS, hit_index);
0613           if (hit_id == 0)
0614             break;
0615           if (hit_pattern.muonHitFilter(hit_id) && (hit_pattern.getHitType(hit_id) == TrackingRecHit::valid ||
0616                                                     hit_pattern.getHitType(hit_id) == TrackingRecHit::bad)) {
0617             const size_t station_index = getStationIndex(hit_pattern.getMuonStation(hit_id), false);
0618             if (station_index < n_muon_stations) {
0619               CountArray* muon_n_hits = nullptr;
0620               if (hit_pattern.muonDTHitFilter(hit_id))
0621                 muon_n_hits = &n_hits.at(MuonSubdetId::DT);
0622               else if (hit_pattern.muonCSCHitFilter(hit_id))
0623                 muon_n_hits = &n_hits.at(MuonSubdetId::CSC);
0624               else if (hit_pattern.muonRPCHitFilter(hit_id))
0625                 muon_n_hits = &n_hits.at(MuonSubdetId::RPC);
0626 
0627               if (muon_n_hits)
0628                 ++muon_n_hits->at(station_index);
0629             }
0630           }
0631         }
0632       }
0633     }
0634 
0635     unsigned nMatches(int subdet, int station) const {
0636       if (!n_matches.count(subdet))
0637         throw cms::Exception("MuonHitMatch") << "Subdet " << subdet << " not found.";
0638       const size_t station_index = getStationIndex(station, true);
0639       return n_matches.at(subdet).at(station_index);
0640     }
0641 
0642     unsigned nHits(int subdet, int station) const {
0643       if (!n_hits.count(subdet))
0644         throw cms::Exception("MuonHitMatch") << "Subdet " << subdet << " not found.";
0645       const size_t station_index = getStationIndex(station, true);
0646       return n_hits.at(subdet).at(station_index);
0647     }
0648 
0649     unsigned countMuonStationsWithMatches(int first_station, int last_station) const {
0650       static const std::map<int, std::vector<bool>> masks = {
0651           {MuonSubdetId::DT, {false, false, false, false}},
0652           {MuonSubdetId::CSC, {true, false, false, false}},
0653           {MuonSubdetId::RPC, {false, false, false, false}},
0654       };
0655       const size_t first_station_index = getStationIndex(first_station, true);
0656       const size_t last_station_index = getStationIndex(last_station, true);
0657       unsigned cnt = 0;
0658       for (size_t n = first_station_index; n <= last_station_index; ++n) {
0659         for (const auto& match : n_matches) {
0660           if (!masks.at(match.first).at(n) && match.second.at(n) > 0)
0661             ++cnt;
0662         }
0663       }
0664       return cnt;
0665     }
0666 
0667     unsigned countMuonStationsWithHits(int first_station, int last_station) const {
0668       static const std::map<int, std::vector<bool>> masks = {
0669           {MuonSubdetId::DT, {false, false, false, false}},
0670           {MuonSubdetId::CSC, {false, false, false, false}},
0671           {MuonSubdetId::RPC, {false, false, false, false}},
0672       };
0673 
0674       const size_t first_station_index = getStationIndex(first_station, true);
0675       const size_t last_station_index = getStationIndex(last_station, true);
0676       unsigned cnt = 0;
0677       for (size_t n = first_station_index; n <= last_station_index; ++n) {
0678         for (const auto& hit : n_hits) {
0679           if (!masks.at(hit.first).at(n) && hit.second.at(n) > 0)
0680             ++cnt;
0681         }
0682       }
0683       return cnt;
0684     }
0685 
0686   private:
0687     CountMap n_matches, n_hits;
0688   };
0689 
0690   enum class CellObjectType {
0691     PfCand_electron,
0692     PfCand_muon,
0693     PfCand_chargedHadron,
0694     PfCand_neutralHadron,
0695     PfCand_gamma,
0696     Electron,
0697     Muon,
0698     Other
0699   };
0700 
0701   template <typename Object>
0702   CellObjectType GetCellObjectType(const Object&);
0703   template <>
0704   CellObjectType GetCellObjectType(const pat::Electron&) {
0705     return CellObjectType::Electron;
0706   }
0707   template <>
0708   CellObjectType GetCellObjectType(const pat::Muon&) {
0709     return CellObjectType::Muon;
0710   }
0711 
0712   template <>
0713   CellObjectType GetCellObjectType(reco::Candidate const& cand) {
0714     static const std::map<int, CellObjectType> obj_types = {{11, CellObjectType::PfCand_electron},
0715                                                             {13, CellObjectType::PfCand_muon},
0716                                                             {22, CellObjectType::PfCand_gamma},
0717                                                             {130, CellObjectType::PfCand_neutralHadron},
0718                                                             {211, CellObjectType::PfCand_chargedHadron}};
0719 
0720     auto iter = obj_types.find(std::abs(cand.pdgId()));
0721     if (iter == obj_types.end())
0722       return CellObjectType::Other;
0723     return iter->second;
0724   }
0725 
  // Content of one grid cell: for each object category present in the cell, one size_t value
  // (presumably the index of the selected object in its source collection -- TODO confirm in the grid-filling code).
  using Cell = std::map<CellObjectType, size_t>;
0727   struct CellIndex {
0728     int eta, phi;
0729 
0730     bool operator<(const CellIndex& other) const {
0731       if (eta != other.eta)
0732         return eta < other.eta;
0733       return phi < other.phi;
0734     }
0735   };
0736 
0737   class CellGrid {
0738   public:
0739     using Map = std::map<CellIndex, Cell>;
0740     using const_iterator = Map::const_iterator;
0741 
0742     CellGrid(unsigned n_cells_eta,
0743              unsigned n_cells_phi,
0744              double cell_size_eta,
0745              double cell_size_phi,
0746              bool disable_CellIndex_workaround)
0747         : nCellsEta(n_cells_eta),
0748           nCellsPhi(n_cells_phi),
0749           nTotal(nCellsEta * nCellsPhi),
0750           cellSizeEta(cell_size_eta),
0751           cellSizePhi(cell_size_phi),
0752           disable_CellIndex_workaround_(disable_CellIndex_workaround) {
0753       if (nCellsEta % 2 != 1 || nCellsEta < 1)
0754         throw cms::Exception("DeepTauId") << "Invalid number of eta cells.";
0755       if (nCellsPhi % 2 != 1 || nCellsPhi < 1)
0756         throw cms::Exception("DeepTauId") << "Invalid number of phi cells.";
0757       if (cellSizeEta <= 0 || cellSizePhi <= 0)
0758         throw cms::Exception("DeepTauId") << "Invalid cell size.";
0759     }
0760 
0761     int maxEtaIndex() const { return static_cast<int>((nCellsEta - 1) / 2); }
0762     int maxPhiIndex() const { return static_cast<int>((nCellsPhi - 1) / 2); }
0763     double maxDeltaEta() const { return cellSizeEta * (0.5 + maxEtaIndex()); }
0764     double maxDeltaPhi() const { return cellSizePhi * (0.5 + maxPhiIndex()); }
0765     int getEtaTensorIndex(const CellIndex& cellIndex) const { return cellIndex.eta + maxEtaIndex(); }
0766     int getPhiTensorIndex(const CellIndex& cellIndex) const { return cellIndex.phi + maxPhiIndex(); }
0767 
0768     bool tryGetCellIndex(double deltaEta, double deltaPhi, CellIndex& cellIndex) const {
0769       const auto getCellIndex = [this](double x, double maxX, double size, int& index) {
0770         const double absX = std::abs(x);
0771         if (absX > maxX)
0772           return false;
0773         double absIndex;
0774         if (disable_CellIndex_workaround_) {
0775           // CV: use consistent definition for CellIndex
0776           //     in DeepTauId.cc code and new DeepTau trainings
0777           absIndex = std::floor(absX / size + 0.5);
0778         } else {
0779           // CV: backwards compatibility with DeepTau training v2p1 used during Run 2
0780           absIndex = std::floor(std::abs(absX / size - 0.5));
0781         }
0782         index = static_cast<int>(std::copysign(absIndex, x));
0783         return true;
0784       };
0785 
0786       return getCellIndex(deltaEta, maxDeltaEta(), cellSizeEta, cellIndex.eta) &&
0787              getCellIndex(deltaPhi, maxDeltaPhi(), cellSizePhi, cellIndex.phi);
0788     }
0789 
0790     size_t num_valid_cells() const { return cells.size(); }
0791     Cell& operator[](const CellIndex& cellIndex) { return cells[cellIndex]; }
0792     const Cell& at(const CellIndex& cellIndex) const { return cells.at(cellIndex); }
0793     size_t count(const CellIndex& cellIndex) const { return cells.count(cellIndex); }
0794     const_iterator find(const CellIndex& cellIndex) const { return cells.find(cellIndex); }
0795     const_iterator begin() const { return cells.begin(); }
0796     const_iterator end() const { return cells.end(); }
0797 
0798   public:
0799     const unsigned nCellsEta, nCellsPhi, nTotal;
0800     const double cellSizeEta, cellSizePhi;
0801 
0802   private:
0803     std::map<CellIndex, Cell> cells;
0804     const bool disable_CellIndex_workaround_;
0805   };
0806 }  // anonymous namespace
0807 
0808 using bd = deep_tau::DeepTauBase::BasicDiscriminator;
0809 const std::map<bd, std::string> deep_tau::DeepTauBase::stringFromDiscriminator_{
0810     {bd::ChargedIsoPtSum, "ChargedIsoPtSum"},
0811     {bd::NeutralIsoPtSum, "NeutralIsoPtSum"},
0812     {bd::NeutralIsoPtSumWeight, "NeutralIsoPtSumWeight"},
0813     {bd::FootprintCorrection, "TauFootprintCorrection"},
0814     {bd::PhotonPtSumOutsideSignalCone, "PhotonPtSumOutsideSignalCone"},
0815     {bd::PUcorrPtSum, "PUcorrPtSum"}};
0816 const std::vector<bd> deep_tau::DeepTauBase::requiredBasicDiscriminators_ = {bd::ChargedIsoPtSum,
0817                                                                              bd::NeutralIsoPtSum,
0818                                                                              bd::NeutralIsoPtSumWeight,
0819                                                                              bd::PhotonPtSumOutsideSignalCone,
0820                                                                              bd::PUcorrPtSum};
0821 const std::vector<bd> deep_tau::DeepTauBase::requiredBasicDiscriminatorsdR03_ = {bd::ChargedIsoPtSum,
0822                                                                                  bd::NeutralIsoPtSum,
0823                                                                                  bd::NeutralIsoPtSumWeight,
0824                                                                                  bd::PhotonPtSumOutsideSignalCone,
0825                                                                                  bd::FootprintCorrection};
0826 
0827 class DeepTauId : public deep_tau::DeepTauBase {
0828 public:
0829   static constexpr float default_value = -999.;
0830 
0831   static const OutputCollection& GetOutputs() {
0832     static constexpr size_t e_index = 0, mu_index = 1, tau_index = 2, jet_index = 3;
0833     static const OutputCollection outputs_ = {
0834         {"VSe", Output({tau_index}, {e_index, tau_index})},
0835         {"VSmu", Output({tau_index}, {mu_index, tau_index})},
0836         {"VSjet", Output({tau_index}, {jet_index, tau_index})},
0837     };
0838     return outputs_;
0839   }
0840 
0841   const std::map<BasicDiscriminator, size_t> matchDiscriminatorIndices(
0842       edm::Event& event,
0843       edm::EDGetTokenT<reco::TauDiscriminatorContainer> discriminatorContainerToken,
0844       std::vector<BasicDiscriminator> requiredDiscr) {
0845     std::map<std::string, size_t> discrIndexMapStr;
0846     auto const aHandle = event.getHandle(discriminatorContainerToken);
0847     auto const aProv = aHandle.provenance();
0848     if (aProv == nullptr)
0849       aHandle.whyFailed()->raise();
0850     const auto& psetsFromProvenance = edm::parameterSet(aProv->stable(), event.processHistory());
0851     auto const idlist = psetsFromProvenance.getParameter<std::vector<edm::ParameterSet>>("IDdefinitions");
0852     for (size_t j = 0; j < idlist.size(); ++j) {
0853       std::string idname = idlist[j].getParameter<std::string>("IDname");
0854       if (discrIndexMapStr.count(idname)) {
0855         throw cms::Exception("DeepTauId")
0856             << "basic discriminator " << idname << " appears more than once in the input.";
0857       }
0858       discrIndexMapStr[idname] = j;
0859     }
0860 
0861     //translate to a map of <BasicDiscriminator, index> and check if all discriminators are present
0862     std::map<BasicDiscriminator, size_t> discrIndexMap;
0863     for (size_t i = 0; i < requiredDiscr.size(); i++) {
0864       if (discrIndexMapStr.find(stringFromDiscriminator_.at(requiredDiscr[i])) == discrIndexMapStr.end())
0865         throw cms::Exception("DeepTauId") << "Basic Discriminator " << stringFromDiscriminator_.at(requiredDiscr[i])
0866                                           << " was not provided in the config file.";
0867       else
0868         discrIndexMap[requiredDiscr[i]] = discrIndexMapStr[stringFromDiscriminator_.at(requiredDiscr[i])];
0869     }
0870     return discrIndexMap;
0871   }
0872 
  /// Declare the module's configuration parameters and their defaults, registered under the label "DeepTau".
  /// The order of the add() calls below is kept as is (it is the order in which parameters are described).
  static void fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
    edm::ParameterSetDescription desc;
    // Input collections. electrons/muons are read only offline; see getPredictions.
    desc.add<edm::InputTag>("electrons", edm::InputTag("slimmedElectrons"));
    desc.add<edm::InputTag>("muons", edm::InputTag("slimmedMuons"));
    desc.add<edm::InputTag>("taus", edm::InputTag("slimmedTaus"));
    desc.add<edm::InputTag>("pfcands", edm::InputTag("packedPFCandidates"));
    desc.add<edm::InputTag>("vertices", edm::InputTag("offlineSlimmedPrimaryVertices"));
    desc.add<edm::InputTag>("rho", edm::InputTag("fixedGridRhoAll"));
    // Network graph file(s) and model version; the constructor supports version 2 with sub_version 1 or 5.
    desc.add<std::vector<std::string>>("graph_file",
                                       {"RecoTauTag/TrainingFiles/data/DeepTauId/deepTau_2017v2p6_e6.pb"});
    desc.add<bool>("mem_mapped", false);
    desc.add<unsigned>("version", 2);
    desc.add<unsigned>("sub_version", 1);
    desc.add<int>("debug_level", 0);
    // Switches for the backward-compatibility workarounds and for dumping the inputs.
    desc.add<bool>("disable_dxy_pca", false);
    desc.add<bool>("disable_hcalFraction_workaround", false);
    desc.add<bool>("disable_CellIndex_workaround", false);
    desc.add<bool>("save_inputs", false);
    // NOTE(review): not read in this constructor; presumably consumed by DeepTauBase -- confirm.
    desc.add<bool>("is_online", false);

    // Working points per output ("-1." by default).
    desc.add<std::vector<std::string>>("VSeWP", {"-1."});
    desc.add<std::vector<std::string>>("VSmuWP", {"-1."});
    desc.add<std::vector<std::string>>("VSjetWP", {"-1."});

    // Inputs read only in online mode (see getPredictions); the two discriminator tags are untracked.
    desc.addUntracked<edm::InputTag>("basicTauDiscriminators", edm::InputTag("basicTauDiscriminators"));
    desc.addUntracked<edm::InputTag>("basicTauDiscriminatorsdR03", edm::InputTag("basicTauDiscriminatorsdR03"));
    desc.add<edm::InputTag>("pfTauTransverseImpactParameters", edm::InputTag("hpsPFTauTransverseImpactParameters"));

    {
      // Tau preselection: boolean combination ("and" by default) of optional prediscriminants.
      edm::ParameterSetDescription pset_Prediscriminants;
      pset_Prediscriminants.add<std::string>("BooleanOperator", "and");
      {
        edm::ParameterSetDescription psd1;
        psd1.add<double>("cut");
        psd1.add<edm::InputTag>("Producer");
        pset_Prediscriminants.addOptional<edm::ParameterSetDescription>("decayMode", psd1);
      }
      desc.add<edm::ParameterSetDescription>("Prediscriminants", pset_Prediscriminants);
    }

    descriptions.add("DeepTau", desc);
  }
0915 
0916 public:
  /// Read the configuration, set up the tau-block input column mapping and the scaling parameters
  /// for the requested (sub)version, and pre-allocate the per-grid input/output tensors.
  ///
  /// @param cfg    module configuration (parameters declared in fillDescriptions)
  /// @param cache  global cache, forwarded to DeepTauBase
  /// @throws cms::Exception for an unsupported version or sub_version, or inconsistent scaling-parameter sizes
  explicit DeepTauId(const edm::ParameterSet& cfg, const deep_tau::DeepTauCache* cache)
      : DeepTauBase(cfg, GetOutputs(), cache),
        electrons_token_(consumes<std::vector<pat::Electron>>(cfg.getParameter<edm::InputTag>("electrons"))),
        muons_token_(consumes<std::vector<pat::Muon>>(cfg.getParameter<edm::InputTag>("muons"))),
        rho_token_(consumes<double>(cfg.getParameter<edm::InputTag>("rho"))),
        basicTauDiscriminators_inputToken_(consumes<reco::TauDiscriminatorContainer>(
            cfg.getUntrackedParameter<edm::InputTag>("basicTauDiscriminators"))),
        basicTauDiscriminatorsdR03_inputToken_(consumes<reco::TauDiscriminatorContainer>(
            cfg.getUntrackedParameter<edm::InputTag>("basicTauDiscriminatorsdR03"))),
        pfTauTransverseImpactParameters_token_(
            consumes<edm::AssociationVector<reco::PFTauRefProd, std::vector<reco::PFTauTransverseImpactParameterRef>>>(
                cfg.getParameter<edm::InputTag>("pfTauTransverseImpactParameters"))),
        version_(cfg.getParameter<unsigned>("version")),
        sub_version_(cfg.getParameter<unsigned>("sub_version")),
        debug_level(cfg.getParameter<int>("debug_level")),
        disable_dxy_pca_(cfg.getParameter<bool>("disable_dxy_pca")),
        disable_hcalFraction_workaround_(cfg.getParameter<bool>("disable_hcalFraction_workaround")),
        disable_CellIndex_workaround_(cfg.getParameter<bool>("disable_CellIndex_workaround")),
        save_inputs_(cfg.getParameter<bool>("save_inputs")),
        json_file_(nullptr),
        file_counter_(0) {
    if (version_ == 2) {
      using namespace dnn_inputs_v2;
      namespace sc = deep_tau::Scaling;
      // Start from the identity mapping: tau-block input i goes to tensor column i.
      tauInputs_indices_.resize(TauBlockInputs::NumberOfInputs);
      std::iota(std::begin(tauInputs_indices_), std::end(tauInputs_indices_), 0);

      if (sub_version_ == 1) {
        // v2p1: all tau-block inputs are fed to the network.
        tauBlockTensor_ = std::make_unique<tensorflow::Tensor>(
            tensorflow::DT_FLOAT, tensorflow::TensorShape{1, TauBlockInputs::NumberOfInputs});
        scalingParamsMap_ = &sc::scalingParamsMap_v2p1;
      } else if (sub_version_ == 5) {
        // v2p5: the inputs listed in varsToDrop are not fed to the network. Each dropped input gets
        // index -1, and every later input is shifted down one column so the kept ones stay contiguous.
        // NOTE(review): this sorts the namespace-scope TauBlockInputs::varsToDrop in place on every
        // construction -- idempotent, but it mutates shared state from a constructor.
        std::sort(TauBlockInputs::varsToDrop.begin(), TauBlockInputs::varsToDrop.end());
        for (auto v : TauBlockInputs::varsToDrop) {
          tauInputs_indices_.at(v) = -1;  // set index to -1
          for (std::size_t i = v + 1; i < TauBlockInputs::NumberOfInputs; ++i)
            tauInputs_indices_.at(i) -= 1;  // shift all the following indices by 1
        }
        // Tensor width = total inputs minus dropped inputs.
        tauBlockTensor_ = std::make_unique<tensorflow::Tensor>(
            tensorflow::DT_FLOAT,
            tensorflow::TensorShape{1,
                                    static_cast<int>(TauBlockInputs::NumberOfInputs) -
                                        static_cast<int>(TauBlockInputs::varsToDrop.size())});
        scalingParamsMap_ = &sc::scalingParamsMap_v2p5;
      } else
        throw cms::Exception("DeepTauId") << "subversion " << sub_version_ << " is not supported.";

      // Key: the is_inner flags for which scaling parameters must exist. Value: the feature types
      // to check for those flags.
      std::map<std::vector<bool>, std::vector<sc::FeatureT>> GridFeatureTypes_map = {
          {{false}, {sc::FeatureT::TauFlat, sc::FeatureT::GridGlobal}},  // feature types without inner/outer grid split
          {{false, true},
           {sc::FeatureT::PfCand_electron,
            sc::FeatureT::PfCand_muon,  // feature types with inner/outer grid split
            sc::FeatureT::PfCand_chHad,
            sc::FeatureT::PfCand_nHad,
            sc::FeatureT::PfCand_gamma,
            sc::FeatureT::Electron,
            sc::FeatureT::Muon}}};

      // check that sizes of mean/std/lim_min/lim_max vectors are equal between each other
      // (scalingParamsMap_->at also throws if a required (feature type, is_inner) entry is missing)
      for (const auto& p : GridFeatureTypes_map) {
        for (auto is_inner : p.first) {
          for (auto featureType : p.second) {
            const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(featureType, is_inner));
            if (!(sp.mean_.size() == sp.std_.size() && sp.mean_.size() == sp.lim_min_.size() &&
                  sp.mean_.size() == sp.lim_max_.size()))
              throw cms::Exception("DeepTauId") << "sizes of scaling parameter vectors do not match between each other";
          }
        }
      }

      // Allocate the per-cell input tensors and the conv-feature tensors, first for the outer grid
      // (n == 1 -> is_inner == false), then for the inner grid.
      for (size_t n = 0; n < 2; ++n) {
        const bool is_inner = n == 0;
        const auto n_cells = is_inner ? number_of_inner_cell : number_of_outer_cell;
        eGammaTensor_[is_inner] = std::make_unique<tensorflow::Tensor>(
            tensorflow::DT_FLOAT, tensorflow::TensorShape{1, 1, 1, EgammaBlockInputs::NumberOfInputs});
        muonTensor_[is_inner] = std::make_unique<tensorflow::Tensor>(
            tensorflow::DT_FLOAT, tensorflow::TensorShape{1, 1, 1, MuonBlockInputs::NumberOfInputs});
        hadronsTensor_[is_inner] = std::make_unique<tensorflow::Tensor>(
            tensorflow::DT_FLOAT, tensorflow::TensorShape{1, 1, 1, HadronBlockInputs::NumberOfInputs});
        convTensor_[is_inner] = std::make_unique<tensorflow::Tensor>(
            tensorflow::DT_FLOAT, tensorflow::TensorShape{1, n_cells, n_cells, number_of_conv_features});
        zeroOutputTensor_[is_inner] = std::make_unique<tensorflow::Tensor>(
            tensorflow::DT_FLOAT, tensorflow::TensorShape{1, 1, 1, number_of_conv_features});

        eGammaTensor_[is_inner]->flat<float>().setZero();
        muonTensor_[is_inner]->flat<float>().setZero();
        hadronsTensor_[is_inner]->flat<float>().setZero();

        // With the all-zero inputs set above, presumably evaluates the grid sub-network once and stores
        // its output in zeroOutputTensor_ as the features of an empty cell -- TODO confirm in
        // getPartialPredictions / setCellConvFeatures.
        setCellConvFeatures(*zeroOutputTensor_[is_inner], getPartialPredictions(is_inner), 0, 0, 0);
      }
    } else {
      throw cms::Exception("DeepTauId") << "version " << version_ << " is not supported.";
    }
  }
1011 
1012   static std::unique_ptr<deep_tau::DeepTauCache> initializeGlobalCache(const edm::ParameterSet& cfg) {
1013     return DeepTauBase::initializeGlobalCache(cfg);
1014   }
1015 
1016   static void globalEndJob(const deep_tau::DeepTauCache* cache_) { return DeepTauBase::globalEndJob(cache_); }
1017 
1018 private:
1019   static constexpr float pi = M_PI;
1020 
1021   template <typename T>
1022   static float getValue(T value) {
1023     return std::isnormal(value) ? static_cast<float>(value) : 0.f;
1024   }
1025 
1026   template <typename T>
1027   static float getValueLinear(T value, float min_value, float max_value, bool positive) {
1028     const float fixed_value = getValue(value);
1029     const float clamped_value = std::clamp(fixed_value, min_value, max_value);
1030     float transformed_value = (clamped_value - min_value) / (max_value - min_value);
1031     if (!positive)
1032       transformed_value = transformed_value * 2 - 1;
1033     return transformed_value;
1034   }
1035 
1036   template <typename T>
1037   static float getValueNorm(T value, float mean, float sigma, float n_sigmas_max = 5) {
1038     const float fixed_value = getValue(value);
1039     const float norm_value = (fixed_value - mean) / sigma;
1040     return std::clamp(norm_value, -n_sigmas_max, n_sigmas_max);
1041   }
1042 
1043   static bool isAbove(double value, double min) { return std::isnormal(value) && value > min; }
1044 
1045   static bool calculateElectronClusterVarsV2(const pat::Electron& ele,
1046                                              float& cc_ele_energy,
1047                                              float& cc_gamma_energy,
1048                                              int& cc_n_gamma) {
1049     cc_ele_energy = cc_gamma_energy = 0;
1050     cc_n_gamma = 0;
1051     const auto& superCluster = ele.superCluster();
1052     if (superCluster.isNonnull() && superCluster.isAvailable() && superCluster->clusters().isNonnull() &&
1053         superCluster->clusters().isAvailable()) {
1054       for (auto iter = superCluster->clustersBegin(); iter != superCluster->clustersEnd(); ++iter) {
1055         const float energy = static_cast<float>((*iter)->energy());
1056         if (iter == superCluster->clustersBegin())
1057           cc_ele_energy += energy;
1058         else {
1059           cc_gamma_energy += energy;
1060           ++cc_n_gamma;
1061         }
1062       }
1063       return true;
1064     } else
1065       return false;
1066   }
1067 
1068   inline void checkInputs(const tensorflow::Tensor& inputs,
1069                           const std::string& block_name,
1070                           int n_inputs,
1071                           const CellGrid* grid = nullptr) const {
1072     if (debug_level >= 1) {
1073       std::cout << "<checkInputs>: block_name = " << block_name << std::endl;
1074       if (block_name == "input_tau") {
1075         for (int input_index = 0; input_index < n_inputs; ++input_index) {
1076           float input = inputs.matrix<float>()(0, input_index);
1077           if (edm::isNotFinite(input)) {
1078             throw cms::Exception("DeepTauId")
1079                 << "in the " << block_name
1080                 << ", input is not finite, i.e. infinite or NaN, for input_index = " << input_index;
1081           }
1082           if (debug_level >= 2) {
1083             std::cout << block_name << "[var = " << input_index << "] = " << std::setprecision(5) << std::fixed << input
1084                       << std::endl;
1085           }
1086         }
1087       } else {
1088         assert(grid);
1089         int n_eta, n_phi;
1090         if (block_name.find("input_inner") != std::string::npos) {
1091           n_eta = 5;
1092           n_phi = 5;
1093         } else if (block_name.find("input_outer") != std::string::npos) {
1094           n_eta = 10;
1095           n_phi = 10;
1096         } else
1097           assert(0);
1098         int eta_phi_index = 0;
1099         for (int eta = -n_eta; eta <= n_eta; ++eta) {
1100           for (int phi = -n_phi; phi <= n_phi; ++phi) {
1101             const CellIndex cell_index{eta, phi};
1102             const auto cell_iter = grid->find(cell_index);
1103             if (cell_iter != grid->end()) {
1104               for (int input_index = 0; input_index < n_inputs; ++input_index) {
1105                 float input = inputs.tensor<float, 4>()(eta_phi_index, 0, 0, input_index);
1106                 if (edm::isNotFinite(input)) {
1107                   throw cms::Exception("DeepTauId")
1108                       << "in the " << block_name << ", input is not finite, i.e. infinite or NaN, for eta = " << eta
1109                       << ", phi = " << phi << ", input_index = " << input_index;
1110                 }
1111                 if (debug_level >= 2) {
1112                   std::cout << block_name << "[eta = " << eta << "][phi = " << phi << "][var = " << input_index
1113                             << "] = " << std::setprecision(5) << std::fixed << input << std::endl;
1114                 }
1115               }
1116               eta_phi_index += 1;
1117             }
1118           }
1119         }
1120       }
1121     }
1122   }
1123 
1124   inline void saveInputs(const tensorflow::Tensor& inputs,
1125                          const std::string& block_name,
1126                          int n_inputs,
1127                          const CellGrid* grid = nullptr) {
1128     if (debug_level >= 1) {
1129       std::cout << "<saveInputs>: block_name = " << block_name << std::endl;
1130     }
1131     if (!is_first_block_)
1132       (*json_file_) << ", ";
1133     (*json_file_) << "\"" << block_name << "\": [";
1134     if (block_name == "input_tau") {
1135       for (int input_index = 0; input_index < n_inputs; ++input_index) {
1136         float input = inputs.matrix<float>()(0, input_index);
1137         if (input_index != 0)
1138           (*json_file_) << ", ";
1139         (*json_file_) << input;
1140       }
1141     } else {
1142       assert(grid);
1143       int n_eta, n_phi;
1144       if (block_name.find("input_inner") != std::string::npos) {
1145         n_eta = 5;
1146         n_phi = 5;
1147       } else if (block_name.find("input_outer") != std::string::npos) {
1148         n_eta = 10;
1149         n_phi = 10;
1150       } else
1151         assert(0);
1152       int eta_phi_index = 0;
1153       for (int eta = -n_eta; eta <= n_eta; ++eta) {
1154         if (eta != -n_eta)
1155           (*json_file_) << ", ";
1156         (*json_file_) << "[";
1157         for (int phi = -n_phi; phi <= n_phi; ++phi) {
1158           if (phi != -n_phi)
1159             (*json_file_) << ", ";
1160           (*json_file_) << "[";
1161           const CellIndex cell_index{eta, phi};
1162           const auto cell_iter = grid->find(cell_index);
1163           for (int input_index = 0; input_index < n_inputs; ++input_index) {
1164             float input = 0.;
1165             if (cell_iter != grid->end()) {
1166               input = inputs.tensor<float, 4>()(eta_phi_index, 0, 0, input_index);
1167             }
1168             if (input_index != 0)
1169               (*json_file_) << ", ";
1170             (*json_file_) << input;
1171           }
1172           if (cell_iter != grid->end()) {
1173             eta_phi_index += 1;
1174           }
1175           (*json_file_) << "]";
1176         }
1177         (*json_file_) << "]";
1178       }
1179     }
1180     (*json_file_) << "]";
1181     is_first_block_ = false;
1182   }
1183 
1184 private:
  /// Evaluate DeepTau for every tau in the event and return an (n_taus x NumberOfOutputs) tensor of scores.
  ///
  /// Offline (is_online_ == false): the electron and muon collections are read from the event, and the
  /// online-only inputs (basic discriminators, impact parameters) are replaced by empty placeholders.
  /// Online it is the other way round, and the basic-discriminator indices are resolved from the
  /// provenance once, on the first event processed by this module instance.
  ///
  /// @throws cms::Exception for an unsupported version, or a network score outside [0, 1]
  tensorflow::Tensor getPredictions(edm::Event& event, edm::Handle<TauCollection> taus) override {
    // Empty dummy vectors
    // (stand-ins for the inputs that are not read in the current online/offline mode)
    const std::vector<pat::Electron> electron_collection_default;
    const std::vector<pat::Muon> muon_collection_default;
    const reco::TauDiscriminatorContainer basicTauDiscriminators_default;
    const reco::TauDiscriminatorContainer basicTauDiscriminatorsdR03_default;
    const edm::AssociationVector<reco::PFTauRefProd, std::vector<reco::PFTauTransverseImpactParameterRef>>
        pfTauTransverseImpactParameters_default;

    // Each pointer refers either to the event product or to the matching placeholder above.
    const std::vector<pat::Electron>* electron_collection;
    const std::vector<pat::Muon>* muon_collection;
    const reco::TauDiscriminatorContainer* basicTauDiscriminators;
    const reco::TauDiscriminatorContainer* basicTauDiscriminatorsdR03;
    const edm::AssociationVector<reco::PFTauRefProd, std::vector<reco::PFTauTransverseImpactParameterRef>>*
        pfTauTransverseImpactParameters;

    if (!is_online_) {
      electron_collection = &event.get(electrons_token_);
      muon_collection = &event.get(muons_token_);
      pfTauTransverseImpactParameters = &pfTauTransverseImpactParameters_default;
      basicTauDiscriminators = &basicTauDiscriminators_default;
      basicTauDiscriminatorsdR03 = &basicTauDiscriminatorsdR03_default;
    } else {
      electron_collection = &electron_collection_default;
      muon_collection = &muon_collection_default;
      pfTauTransverseImpactParameters = &event.get(pfTauTransverseImpactParameters_token_);
      basicTauDiscriminators = &event.get(basicTauDiscriminators_inputToken_);
      basicTauDiscriminatorsdR03 = &event.get(basicTauDiscriminatorsdR03_inputToken_);

      // Get indices for discriminators
      // (only once: the result is cached in basicDiscrIndexMap_ / basicDiscrdR03IndexMap_)
      if (!discrIndicesMapped_) {
        basicDiscrIndexMap_ =
            matchDiscriminatorIndices(event, basicTauDiscriminators_inputToken_, requiredBasicDiscriminators_);
        basicDiscrdR03IndexMap_ =
            matchDiscriminatorIndices(event, basicTauDiscriminatorsdR03_inputToken_, requiredBasicDiscriminatorsdR03_);
        discrIndicesMapped_ = true;
      }
    }

    TauFunc tauIDs = {basicTauDiscriminators,
                      basicTauDiscriminatorsdR03,
                      pfTauTransverseImpactParameters,
                      basicDiscrIndexMap_,
                      basicDiscrdR03IndexMap_};

    edm::Handle<edm::View<reco::Candidate>> pfCands;
    event.getByToken(pfcandToken_, pfCands);

    edm::Handle<reco::VertexCollection> vertices;
    event.getByToken(vtxToken_, vertices);

    edm::Handle<double> rho;
    event.getByToken(rho_token_, rho);

    tensorflow::Tensor predictions(tensorflow::DT_FLOAT, {static_cast<int>(taus->size()), deep_tau::NumberOfOutputs});

    for (size_t tau_index = 0; tau_index < taus->size(); ++tau_index) {
      const edm::RefToBase<reco::BaseTau> tauRef = taus->refAt(tau_index);

      std::vector<tensorflow::Tensor> pred_vector;

      // Preselection: online uses reco (PFTau) prediscriminants, offline uses pat ones.
      bool passesPrediscriminants;
      if (is_online_) {
        passesPrediscriminants = tauIDs.passPrediscriminants<std::vector<TauDiscInfo<reco::PFTauDiscriminator>>>(
            recoPrediscriminants_, andPrediscriminants_, tauRef);
      } else {
        passesPrediscriminants = tauIDs.passPrediscriminants<std::vector<TauDiscInfo<pat::PATTauDiscriminator>>>(
            patPrediscriminants_, andPrediscriminants_, tauRef);
      }

      if (passesPrediscriminants) {
        if (version_ == 2) {
          // The candidate/tau cast types differ between online (reco) and offline (pat) inputs.
          // NOTE(review): vertices->at(0) throws if the event has no vertex.
          if (is_online_) {
            getPredictionsV2<reco::PFCandidate, reco::PFTau>(taus->at(tau_index),
                                                             tau_index,
                                                             tauRef,
                                                             electron_collection,
                                                             muon_collection,
                                                             *pfCands,
                                                             vertices->at(0),
                                                             *rho,
                                                             pred_vector,
                                                             tauIDs);
          } else
            getPredictionsV2<pat::PackedCandidate, pat::Tau>(taus->at(tau_index),
                                                             tau_index,
                                                             tauRef,
                                                             electron_collection,
                                                             muon_collection,
                                                             *pfCands,
                                                             vertices->at(0),
                                                             *rho,
                                                             pred_vector,
                                                             tauIDs);
        } else {
          throw cms::Exception("DeepTauId") << "version " << version_ << " is not supported.";
        }

        // Copy the scores into row tau_index, rejecting anything outside [0, 1] (this also catches NaN).
        for (int k = 0; k < deep_tau::NumberOfOutputs; ++k) {
          const float pred = pred_vector[0].flat<float>()(k);
          if (!(pred >= 0 && pred <= 1))
            throw cms::Exception("DeepTauId")
                << "invalid prediction = " << pred << " for tau_index = " << tau_index << ", pred_index = " << k;
          predictions.matrix<float>()(tau_index, k) = pred;
        }
      } else {
        // This branch was added during the DeepTau@HLT development. It does not affect offline DeepTauId,
        // where no preselection is applied. Taus that fail the preselection get fixed default scores;
        // previously their row was left uninitialised. k == 2 is the tau score and the other k values
        // are the e, mu and jet scores; with tau = -1 and the others = 2, the final score is -1.
        for (int k = 0; k < deep_tau::NumberOfOutputs; ++k) {
          predictions.matrix<float>()(tau_index, k) = (k == 2) ? -1.f : 2.f;
        }
      }
    }
    return predictions;
  }
1303 
1304   template <typename CandidateCastType, typename TauCastType>
1305   void getPredictionsV2(TauCollection::const_reference& tau,
1306                         const size_t tau_index,
1307                         const edm::RefToBase<reco::BaseTau> tau_ref,
1308                         const std::vector<pat::Electron>* electrons,
1309                         const std::vector<pat::Muon>* muons,
1310                         const edm::View<reco::Candidate>& pfCands,
1311                         const reco::Vertex& pv,
1312                         double rho,
1313                         std::vector<tensorflow::Tensor>& pred_vector,
1314                         TauFunc tau_funcs) {
1315     using namespace dnn_inputs_v2;
1316     if (debug_level >= 2) {
1317       std::cout << "<DeepTauId::getPredictionsV2 (moduleLabel = " << moduleDescription().moduleLabel()
1318                 << ")>:" << std::endl;
1319       std::cout << " tau: pT = " << tau.pt() << ", eta = " << tau.eta() << ", phi = " << tau.phi() << std::endl;
1320     }
1321     CellGrid inner_grid(number_of_inner_cell, number_of_inner_cell, 0.02, 0.02, disable_CellIndex_workaround_);
1322     CellGrid outer_grid(number_of_outer_cell, number_of_outer_cell, 0.05, 0.05, disable_CellIndex_workaround_);
1323     fillGrids(dynamic_cast<const TauCastType&>(tau), *electrons, inner_grid, outer_grid);
1324     fillGrids(dynamic_cast<const TauCastType&>(tau), *muons, inner_grid, outer_grid);
1325     fillGrids(dynamic_cast<const TauCastType&>(tau), pfCands, inner_grid, outer_grid);
1326 
1327     createTauBlockInputs<CandidateCastType>(
1328         dynamic_cast<const TauCastType&>(tau), tau_index, tau_ref, pv, rho, tau_funcs);
1329     checkInputs(*tauBlockTensor_, "input_tau", static_cast<int>(tauBlockTensor_->shape().dim_size(1)));
1330     createConvFeatures<CandidateCastType>(dynamic_cast<const TauCastType&>(tau),
1331                                           tau_index,
1332                                           tau_ref,
1333                                           pv,
1334                                           rho,
1335                                           electrons,
1336                                           muons,
1337                                           pfCands,
1338                                           inner_grid,
1339                                           tau_funcs,
1340                                           true);
1341     checkInputs(*eGammaTensor_[true], "input_inner_egamma", EgammaBlockInputs::NumberOfInputs, &inner_grid);
1342     checkInputs(*muonTensor_[true], "input_inner_muon", MuonBlockInputs::NumberOfInputs, &inner_grid);
1343     checkInputs(*hadronsTensor_[true], "input_inner_hadrons", HadronBlockInputs::NumberOfInputs, &inner_grid);
1344     createConvFeatures<CandidateCastType>(dynamic_cast<const TauCastType&>(tau),
1345                                           tau_index,
1346                                           tau_ref,
1347                                           pv,
1348                                           rho,
1349                                           electrons,
1350                                           muons,
1351                                           pfCands,
1352                                           outer_grid,
1353                                           tau_funcs,
1354                                           false);
1355     checkInputs(*eGammaTensor_[false], "input_outer_egamma", EgammaBlockInputs::NumberOfInputs, &outer_grid);
1356     checkInputs(*muonTensor_[false], "input_outer_muon", MuonBlockInputs::NumberOfInputs, &outer_grid);
1357     checkInputs(*hadronsTensor_[false], "input_outer_hadrons", HadronBlockInputs::NumberOfInputs, &outer_grid);
1358 
1359     if (save_inputs_) {
1360       std::string json_file_name = "DeepTauId_" + std::to_string(file_counter_) + ".json";
1361       json_file_ = new std::ofstream(json_file_name.data());
1362       is_first_block_ = true;
1363       (*json_file_) << "{";
1364       saveInputs(*tauBlockTensor_, "input_tau", static_cast<int>(tauBlockTensor_->shape().dim_size(1)));
1365       saveInputs(
1366           *eGammaTensor_[true], "input_inner_egamma", dnn_inputs_v2::EgammaBlockInputs::NumberOfInputs, &inner_grid);
1367       saveInputs(*muonTensor_[true], "input_inner_muon", dnn_inputs_v2::MuonBlockInputs::NumberOfInputs, &inner_grid);
1368       saveInputs(
1369           *hadronsTensor_[true], "input_inner_hadrons", dnn_inputs_v2::HadronBlockInputs::NumberOfInputs, &inner_grid);
1370       saveInputs(
1371           *eGammaTensor_[false], "input_outer_egamma", dnn_inputs_v2::EgammaBlockInputs::NumberOfInputs, &outer_grid);
1372       saveInputs(*muonTensor_[false], "input_outer_muon", dnn_inputs_v2::MuonBlockInputs::NumberOfInputs, &outer_grid);
1373       saveInputs(
1374           *hadronsTensor_[false], "input_outer_hadrons", dnn_inputs_v2::HadronBlockInputs::NumberOfInputs, &outer_grid);
1375       (*json_file_) << "}";
1376       delete json_file_;
1377       ++file_counter_;
1378     }
1379 
1380     tensorflow::run(&(cache_->getSession("core")),
1381                     {{"input_tau", *tauBlockTensor_},
1382                      {"input_inner", *convTensor_.at(true)},
1383                      {"input_outer", *convTensor_.at(false)}},
1384                     {"main_output/Softmax"},
1385                     &pred_vector);
1386     if (debug_level >= 1) {
1387       std::cout << "output = { ";
1388       for (int idx = 0; idx < deep_tau::NumberOfOutputs; ++idx) {
1389         if (idx > 0)
1390           std::cout << ", ";
1391         std::string label;
1392         if (idx == 0)
1393           label = "e";
1394         else if (idx == 1)
1395           label = "mu";
1396         else if (idx == 2)
1397           label = "tau";
1398         else if (idx == 3)
1399           label = "jet";
1400         else
1401           assert(0);
1402         std::cout << label << " = " << pred_vector[0].flat<float>()(idx);
1403       }
1404       std::cout << " }" << std::endl;
1405     }
1406   }
1407 
1408   template <typename Collection, typename TauCastType>
1409   void fillGrids(const TauCastType& tau, const Collection& objects, CellGrid& inner_grid, CellGrid& outer_grid) {
1410     static constexpr double outer_dR2 = 0.25;  //0.5^2
1411     const double inner_radius = getInnerSignalConeRadius(tau.polarP4().pt());
1412     const double inner_dR2 = std::pow(inner_radius, 2);
1413 
1414     const auto addObject = [&](size_t n, double deta, double dphi, CellGrid& grid) {
1415       const auto& obj = objects.at(n);
1416       const CellObjectType obj_type = GetCellObjectType(obj);
1417       if (obj_type == CellObjectType::Other)
1418         return;
1419       CellIndex cell_index;
1420       if (grid.tryGetCellIndex(deta, dphi, cell_index)) {
1421         Cell& cell = grid[cell_index];
1422         auto iter = cell.find(obj_type);
1423         if (iter != cell.end()) {
1424           const auto& prev_obj = objects.at(iter->second);
1425           if (obj.polarP4().pt() > prev_obj.polarP4().pt())
1426             iter->second = n;
1427         } else {
1428           cell[obj_type] = n;
1429         }
1430       }
1431     };
1432 
1433     for (size_t n = 0; n < objects.size(); ++n) {
1434       const auto& obj = objects.at(n);
1435       const double deta = obj.polarP4().eta() - tau.polarP4().eta();
1436       const double dphi = reco::deltaPhi(obj.polarP4().phi(), tau.polarP4().phi());
1437       const double dR2 = std::pow(deta, 2) + std::pow(dphi, 2);
1438       if (dR2 < inner_dR2)
1439         addObject(n, deta, dphi, inner_grid);
1440       if (dR2 < outer_dR2)
1441         addObject(n, deta, dphi, outer_grid);
1442     }
1443   }
1444 
1445   tensorflow::Tensor getPartialPredictions(bool is_inner) {
1446     std::vector<tensorflow::Tensor> pred_vector;
1447     if (is_inner) {
1448       tensorflow::run(&(cache_->getSession("inner")),
1449                       {
1450                           {"input_inner_egamma", *eGammaTensor_.at(is_inner)},
1451                           {"input_inner_muon", *muonTensor_.at(is_inner)},
1452                           {"input_inner_hadrons", *hadronsTensor_.at(is_inner)},
1453                       },
1454                       {"inner_all_dropout_4/Identity"},
1455                       &pred_vector);
1456     } else {
1457       tensorflow::run(&(cache_->getSession("outer")),
1458                       {
1459                           {"input_outer_egamma", *eGammaTensor_.at(is_inner)},
1460                           {"input_outer_muon", *muonTensor_.at(is_inner)},
1461                           {"input_outer_hadrons", *hadronsTensor_.at(is_inner)},
1462                       },
1463                       {"outer_all_dropout_4/Identity"},
1464                       &pred_vector);
1465     }
1466     return pred_vector.at(0);
1467   }
1468 
1469   template <typename CandidateCastType, typename TauCastType>
1470   void createConvFeatures(const TauCastType& tau,
1471                           const size_t tau_index,
1472                           const edm::RefToBase<reco::BaseTau> tau_ref,
1473                           const reco::Vertex& pv,
1474                           double rho,
1475                           const std::vector<pat::Electron>* electrons,
1476                           const std::vector<pat::Muon>* muons,
1477                           const edm::View<reco::Candidate>& pfCands,
1478                           const CellGrid& grid,
1479                           TauFunc tau_funcs,
1480                           bool is_inner) {
1481     if (debug_level >= 2) {
1482       std::cout << "<DeepTauId::createConvFeatures (is_inner = " << is_inner << ")>:" << std::endl;
1483     }
1484     tensorflow::Tensor& convTensor = *convTensor_.at(is_inner);
1485     eGammaTensor_[is_inner] = std::make_unique<tensorflow::Tensor>(
1486         tensorflow::DT_FLOAT,
1487         tensorflow::TensorShape{
1488             (long long int)grid.num_valid_cells(), 1, 1, dnn_inputs_v2::EgammaBlockInputs::NumberOfInputs});
1489     muonTensor_[is_inner] = std::make_unique<tensorflow::Tensor>(
1490         tensorflow::DT_FLOAT,
1491         tensorflow::TensorShape{
1492             (long long int)grid.num_valid_cells(), 1, 1, dnn_inputs_v2::MuonBlockInputs::NumberOfInputs});
1493     hadronsTensor_[is_inner] = std::make_unique<tensorflow::Tensor>(
1494         tensorflow::DT_FLOAT,
1495         tensorflow::TensorShape{
1496             (long long int)grid.num_valid_cells(), 1, 1, dnn_inputs_v2::HadronBlockInputs::NumberOfInputs});
1497 
1498     eGammaTensor_[is_inner]->flat<float>().setZero();
1499     muonTensor_[is_inner]->flat<float>().setZero();
1500     hadronsTensor_[is_inner]->flat<float>().setZero();
1501 
1502     unsigned idx = 0;
1503     for (int eta = -grid.maxEtaIndex(); eta <= grid.maxEtaIndex(); ++eta) {
1504       for (int phi = -grid.maxPhiIndex(); phi <= grid.maxPhiIndex(); ++phi) {
1505         if (debug_level >= 2) {
1506           std::cout << "processing ( eta = " << eta << ", phi = " << phi << " )" << std::endl;
1507         }
1508         const CellIndex cell_index{eta, phi};
1509         const auto cell_iter = grid.find(cell_index);
1510         if (cell_iter != grid.end()) {
1511           if (debug_level >= 2) {
1512             std::cout << " creating inputs for ( eta = " << eta << ", phi = " << phi << " ): idx = " << idx
1513                       << std::endl;
1514           }
1515           const Cell& cell = cell_iter->second;
1516           createEgammaBlockInputs<CandidateCastType>(
1517               idx, tau, tau_index, tau_ref, pv, rho, electrons, pfCands, cell, tau_funcs, is_inner);
1518           createMuonBlockInputs<CandidateCastType>(
1519               idx, tau, tau_index, tau_ref, pv, rho, muons, pfCands, cell, tau_funcs, is_inner);
1520           createHadronsBlockInputs<CandidateCastType>(
1521               idx, tau, tau_index, tau_ref, pv, rho, pfCands, cell, tau_funcs, is_inner);
1522           idx += 1;
1523         } else {
1524           if (debug_level >= 2) {
1525             std::cout << " skipping creation of inputs, because ( eta = " << eta << ", phi = " << phi
1526                       << " ) is not in the grid !!" << std::endl;
1527           }
1528         }
1529       }
1530     }
1531 
1532     const auto predTensor = getPartialPredictions(is_inner);
1533     idx = 0;
1534     for (int eta = -grid.maxEtaIndex(); eta <= grid.maxEtaIndex(); ++eta) {
1535       for (int phi = -grid.maxPhiIndex(); phi <= grid.maxPhiIndex(); ++phi) {
1536         const CellIndex cell_index{eta, phi};
1537         const int eta_index = grid.getEtaTensorIndex(cell_index);
1538         const int phi_index = grid.getPhiTensorIndex(cell_index);
1539 
1540         const auto cell_iter = grid.find(cell_index);
1541         if (cell_iter != grid.end()) {
1542           setCellConvFeatures(convTensor, predTensor, idx, eta_index, phi_index);
1543           idx += 1;
1544         } else {
1545           setCellConvFeatures(convTensor, *zeroOutputTensor_[is_inner], 0, eta_index, phi_index);
1546         }
1547       }
1548     }
1549   }
1550 
1551   void setCellConvFeatures(tensorflow::Tensor& convTensor,
1552                            const tensorflow::Tensor& features,
1553                            unsigned batch_idx,
1554                            int eta_index,
1555                            int phi_index) {
1556     for (int n = 0; n < dnn_inputs_v2::number_of_conv_features; ++n) {
1557       convTensor.tensor<float, 4>()(0, eta_index, phi_index, n) = features.tensor<float, 4>()(batch_idx, 0, 0, n);
1558     }
1559   }
1560 
  // Fills the tau-level input block (*tauBlockTensor_, row 0) for one tau.
  //
  // The tensor is zeroed first. Any feature not assigned below therefore stays 0. This covers the dxy/ip3d/dz
  // features when their validity conditions fail, and features that are only filled for a particular sub_version_.
  // Each dnn_inputs_v2::TauBlockInputs variable is written to column tauInputs_indices_.at(var). Most features are
  // scaled with the TauFlat scaling parameters (second key false) at the same remapped index. Exceptions: tau_phi
  // and the sub-version-1 tau_flightLength_sig are written without sp.scale.
  //
  // Parameters:
  //   tau        tau being processed
  //   tau_index  index of the tau, forwarded to the tau_funcs accessors that take it
  //   tau_ref    reference to the tau, forwarded to the tau_funcs accessors that take it
  //   pv         primary vertex (not read in this function)
  //   rho        rho value, written as the dnn::rho input
  //   tau_funcs  accessors for tau quantities (isolation sums, impact parameters, flight length, ...)
  template <typename CandidateCastType, typename TauCastType>
  void createTauBlockInputs(const TauCastType& tau,
                            const size_t& tau_index,
                            const edm::RefToBase<reco::BaseTau> tau_ref,
                            const reco::Vertex& pv,
                            double rho,
                            TauFunc tau_funcs) {
    namespace dnn = dnn_inputs_v2::TauBlockInputs;
    namespace sc = deep_tau::Scaling;
    sc::FeatureT ft = sc::FeatureT::TauFlat;
    const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft, false));

    tensorflow::Tensor& inputs = *tauBlockTensor_;
    inputs.flat<float>().setZero();

    // Returns a writable reference to the tensor column that tauInputs_indices_ maps this variable to.
    const auto& get = [&](int var_index) -> float& {
      return inputs.matrix<float>()(0, tauInputs_indices_.at(var_index));
    };

    // nullptr if the reference is empty or the candidate is not a CandidateCastType.
    auto leadChargedHadrCand = dynamic_cast<const CandidateCastType*>(tau.leadChargedHadrCand().get());

    get(dnn::rho) = sp.scale(rho, tauInputs_indices_[dnn::rho]);
    get(dnn::tau_pt) = sp.scale(tau.polarP4().pt(), tauInputs_indices_[dnn::tau_pt]);
    get(dnn::tau_eta) = sp.scale(tau.polarP4().eta(), tauInputs_indices_[dnn::tau_eta]);
    // tau_phi is only an input for sub_version_ 1 and uses getValueLinear over [-pi, pi] instead of sp.scale.
    if (sub_version_ == 1) {
      get(dnn::tau_phi) = getValueLinear(tau.polarP4().phi(), -pi, pi, false);
    }
    get(dnn::tau_mass) = sp.scale(tau.polarP4().mass(), tauInputs_indices_[dnn::tau_mass]);
    get(dnn::tau_E_over_pt) = sp.scale(tau.p4().energy() / tau.p4().pt(), tauInputs_indices_[dnn::tau_E_over_pt]);
    get(dnn::tau_charge) = sp.scale(tau.charge(), tauInputs_indices_[dnn::tau_charge]);
    // Prong counts are decoded from the decay mode assuming decayMode = 5 * (nCharged - 1) + nNeutral.
    get(dnn::tau_n_charged_prongs) = sp.scale(tau.decayMode() / 5 + 1, tauInputs_indices_[dnn::tau_n_charged_prongs]);
    get(dnn::tau_n_neutral_prongs) = sp.scale(tau.decayMode() % 5, tauInputs_indices_[dnn::tau_n_neutral_prongs]);
    get(dnn::chargedIsoPtSum) =
        sp.scale(tau_funcs.getChargedIsoPtSum(tau, tau_ref), tauInputs_indices_[dnn::chargedIsoPtSum]);
    // NOTE(review): this and the neutral-isolation ratios below can divide by a zero sum; presumably
    // sp.scale maps non-finite values to a safe default - confirm in DeepTauScaling.
    get(dnn::chargedIsoPtSumdR03_over_dR05) =
        sp.scale(tau_funcs.getChargedIsoPtSumdR03(tau, tau_ref) / tau_funcs.getChargedIsoPtSum(tau, tau_ref),
                 tauInputs_indices_[dnn::chargedIsoPtSumdR03_over_dR05]);
    // v2p1 (sub_version_ 1) uses the dR<0.3 footprint correction, v2p5 (sub_version_ 5) the default one;
    // for any other sub-version the input stays 0.
    if (sub_version_ == 1)
      get(dnn::footprintCorrection) =
          sp.scale(tau_funcs.getFootprintCorrectiondR03(tau, tau_ref), tauInputs_indices_[dnn::footprintCorrection]);
    else if (sub_version_ == 5)
      get(dnn::footprintCorrection) =
          sp.scale(tau_funcs.getFootprintCorrection(tau, tau_ref), tauInputs_indices_[dnn::footprintCorrection]);

    get(dnn::neutralIsoPtSum) =
        sp.scale(tau_funcs.getNeutralIsoPtSum(tau, tau_ref), tauInputs_indices_[dnn::neutralIsoPtSum]);
    get(dnn::neutralIsoPtSumWeight_over_neutralIsoPtSum) =
        sp.scale(tau_funcs.getNeutralIsoPtSumWeight(tau, tau_ref) / tau_funcs.getNeutralIsoPtSum(tau, tau_ref),
                 tauInputs_indices_[dnn::neutralIsoPtSumWeight_over_neutralIsoPtSum]);
    get(dnn::neutralIsoPtSumWeightdR03_over_neutralIsoPtSum) =
        sp.scale(tau_funcs.getNeutralIsoPtSumdR03Weight(tau, tau_ref) / tau_funcs.getNeutralIsoPtSum(tau, tau_ref),
                 tauInputs_indices_[dnn::neutralIsoPtSumWeightdR03_over_neutralIsoPtSum]);
    get(dnn::neutralIsoPtSumdR03_over_dR05) =
        sp.scale(tau_funcs.getNeutralIsoPtSumdR03(tau, tau_ref) / tau_funcs.getNeutralIsoPtSum(tau, tau_ref),
                 tauInputs_indices_[dnn::neutralIsoPtSumdR03_over_dR05]);
    get(dnn::photonPtSumOutsideSignalCone) = sp.scale(tau_funcs.getPhotonPtSumOutsideSignalCone(tau, tau_ref),
                                                      tauInputs_indices_[dnn::photonPtSumOutsideSignalCone]);
    get(dnn::puCorrPtSum) = sp.scale(tau_funcs.getPuCorrPtSum(tau, tau_ref), tauInputs_indices_[dnn::puCorrPtSum]);
    // The global PCA coordinates were used as inputs during the NN training. It was decided to disable them for
    // inference because modeling of dxy_PCA in MC poorly describes the data: the x and y coordinates in data fall
    // outside the expected 5 std. dev. input validity range. These coordinates are also strongly era-dependent.
    // The computation is kept below (sub_version_ 1 only), but the inputs are set to 0 when disable_dxy_pca_ is set.
    if (sub_version_ == 1) {
      if (!disable_dxy_pca_) {
        auto const pca = tau_funcs.getdxyPCA(tau, tau_index);
        get(dnn::tau_dxy_pca_x) = sp.scale(pca.x(), tauInputs_indices_[dnn::tau_dxy_pca_x]);
        get(dnn::tau_dxy_pca_y) = sp.scale(pca.y(), tauInputs_indices_[dnn::tau_dxy_pca_y]);
        get(dnn::tau_dxy_pca_z) = sp.scale(pca.z(), tauInputs_indices_[dnn::tau_dxy_pca_z]);
      } else {
        get(dnn::tau_dxy_pca_x) = 0;
        get(dnn::tau_dxy_pca_y) = 0;
        get(dnn::tau_dxy_pca_z) = 0;
      }
    }

    // dxy/ip3d features are filled only when the value is above -10 (values at or below -10 presumably act as an
    // "unavailable" sentinel - confirm) and the error is above 0; otherwise they stay 0.
    const bool tau_dxy_valid =
        isAbove(tau_funcs.getdxy(tau, tau_index), -10) && isAbove(tau_funcs.getdxyError(tau, tau_index), 0);
    if (tau_dxy_valid) {
      get(dnn::tau_dxy_valid) = sp.scale(tau_dxy_valid, tauInputs_indices_[dnn::tau_dxy_valid]);
      get(dnn::tau_dxy) = sp.scale(tau_funcs.getdxy(tau, tau_index), tauInputs_indices_[dnn::tau_dxy]);
      get(dnn::tau_dxy_sig) =
          sp.scale(std::abs(tau_funcs.getdxy(tau, tau_index)) / tau_funcs.getdxyError(tau, tau_index),
                   tauInputs_indices_[dnn::tau_dxy_sig]);
    }
    const bool tau_ip3d_valid =
        isAbove(tau_funcs.getip3d(tau, tau_index), -10) && isAbove(tau_funcs.getip3dError(tau, tau_index), 0);
    if (tau_ip3d_valid) {
      get(dnn::tau_ip3d_valid) = sp.scale(tau_ip3d_valid, tauInputs_indices_[dnn::tau_ip3d_valid]);
      get(dnn::tau_ip3d) = sp.scale(tau_funcs.getip3d(tau, tau_index), tauInputs_indices_[dnn::tau_ip3d]);
      get(dnn::tau_ip3d_sig) =
          sp.scale(std::abs(tau_funcs.getip3d(tau, tau_index)) / tau_funcs.getip3dError(tau, tau_index),
                   tauInputs_indices_[dnn::tau_ip3d_sig]);
    }
    if (leadChargedHadrCand) {
      const bool hasTrackDetails = candFunc::getHasTrackDetails(*leadChargedHadrCand);
      // Online candidates without track details get dz = 0.
      const float tau_dz = (is_online_ && !hasTrackDetails) ? 0 : candFunc::getTauDz(*leadChargedHadrCand);
      get(dnn::tau_dz) = sp.scale(tau_dz, tauInputs_indices_[dnn::tau_dz]);
      get(dnn::tau_dz_sig_valid) =
          sp.scale(candFunc::getTauDZSigValid(*leadChargedHadrCand), tauInputs_indices_[dnn::tau_dz_sig_valid]);
      // Without track details a -999 error is used, giving a non-positive significance; presumably this matches
      // the convention used in training - confirm.
      const double dzError = hasTrackDetails ? leadChargedHadrCand->dzError() : -999.;
      get(dnn::tau_dz_sig) = sp.scale(std::abs(tau_dz) / dzError, tauInputs_indices_[dnn::tau_dz_sig]);
    }
    get(dnn::tau_flightLength_x) =
        sp.scale(tau_funcs.getFlightLength(tau, tau_index).x(), tauInputs_indices_[dnn::tau_flightLength_x]);
    get(dnn::tau_flightLength_y) =
        sp.scale(tau_funcs.getFlightLength(tau, tau_index).y(), tauInputs_indices_[dnn::tau_flightLength_y]);
    get(dnn::tau_flightLength_z) =
        sp.scale(tau_funcs.getFlightLength(tau, tau_index).z(), tauInputs_indices_[dnn::tau_flightLength_z]);
    if (sub_version_ == 1)
      get(dnn::tau_flightLength_sig) = 0.55756444;  //This value is set due to a bug in the training
    else if (sub_version_ == 5)
      get(dnn::tau_flightLength_sig) =
          sp.scale(tau_funcs.getFlightLengthSig(tau, tau_index), tauInputs_indices_[dnn::tau_flightLength_sig]);

    get(dnn::tau_pt_weighted_deta_strip) = sp.scale(reco::tau::pt_weighted_deta_strip(tau, tau.decayMode()),
                                                    tauInputs_indices_[dnn::tau_pt_weighted_deta_strip]);

    get(dnn::tau_pt_weighted_dphi_strip) = sp.scale(reco::tau::pt_weighted_dphi_strip(tau, tau.decayMode()),
                                                    tauInputs_indices_[dnn::tau_pt_weighted_dphi_strip]);
    get(dnn::tau_pt_weighted_dr_signal) = sp.scale(reco::tau::pt_weighted_dr_signal(tau, tau.decayMode()),
                                                   tauInputs_indices_[dnn::tau_pt_weighted_dr_signal]);
    get(dnn::tau_pt_weighted_dr_iso) =
        sp.scale(reco::tau::pt_weighted_dr_iso(tau, tau.decayMode()), tauInputs_indices_[dnn::tau_pt_weighted_dr_iso]);
    get(dnn::tau_leadingTrackNormChi2) =
        sp.scale(tau_funcs.getLeadingTrackNormChi2(tau), tauInputs_indices_[dnn::tau_leadingTrackNormChi2]);
    // The energy ratio is valid only when it is a normal (non-zero, non-subnormal, finite) positive number.
    const auto eratio = reco::tau::eratio(tau);
    const bool tau_e_ratio_valid = std::isnormal(eratio) && eratio > 0.f;
    get(dnn::tau_e_ratio_valid) = sp.scale(tau_e_ratio_valid, tauInputs_indices_[dnn::tau_e_ratio_valid]);
    get(dnn::tau_e_ratio) = tau_e_ratio_valid ? sp.scale(eratio, tauInputs_indices_[dnn::tau_e_ratio]) : 0.f;
    // The GJ angle difference is valid when it is zero or a normal non-negative number (NaN/inf/subnormal rejected).
    const double gj_angle_diff = calculateGottfriedJacksonAngleDifference(tau, tau_index, tau_funcs);
    const bool tau_gj_angle_diff_valid = (std::isnormal(gj_angle_diff) || gj_angle_diff == 0) && gj_angle_diff >= 0;
    get(dnn::tau_gj_angle_diff_valid) =
        sp.scale(tau_gj_angle_diff_valid, tauInputs_indices_[dnn::tau_gj_angle_diff_valid]);
    get(dnn::tau_gj_angle_diff) =
        tau_gj_angle_diff_valid ? sp.scale(gj_angle_diff, tauInputs_indices_[dnn::tau_gj_angle_diff]) : 0;
    get(dnn::tau_n_photons) = sp.scale(reco::tau::n_photons_total(tau), tauInputs_indices_[dnn::tau_n_photons]);
    get(dnn::tau_emFraction) = sp.scale(tau_funcs.getEmFraction(tau), tauInputs_indices_[dnn::tau_emFraction]);

    get(dnn::tau_inside_ecal_crack) =
        sp.scale(isInEcalCrack(tau.p4().eta()), tauInputs_indices_[dnn::tau_inside_ecal_crack]);
    get(dnn::leadChargedCand_etaAtEcalEntrance_minus_tau_eta) =
        sp.scale(tau_funcs.getEtaAtEcalEntrance(tau) - tau.p4().eta(),
                 tauInputs_indices_[dnn::leadChargedCand_etaAtEcalEntrance_minus_tau_eta]);
  }
1705 
1706   template <typename CandidateCastType, typename TauCastType>
1707   void createEgammaBlockInputs(unsigned idx,
1708                                const TauCastType& tau,
1709                                const size_t tau_index,
1710                                const edm::RefToBase<reco::BaseTau> tau_ref,
1711                                const reco::Vertex& pv,
1712                                double rho,
1713                                const std::vector<pat::Electron>* electrons,
1714                                const edm::View<reco::Candidate>& pfCands,
1715                                const Cell& cell_map,
1716                                TauFunc tau_funcs,
1717                                bool is_inner) {
1718     namespace dnn = dnn_inputs_v2::EgammaBlockInputs;
1719     namespace sc = deep_tau::Scaling;
1720     sc::FeatureT ft_global = sc::FeatureT::GridGlobal;
1721     sc::FeatureT ft_PFe = sc::FeatureT::PfCand_electron;
1722     sc::FeatureT ft_PFg = sc::FeatureT::PfCand_gamma;
1723     sc::FeatureT ft_e = sc::FeatureT::Electron;
1724 
1725     // needed to remap indices from scaling vectors to those from dnn_inputs_v2::EgammaBlockInputs
1726     int PFe_index_offset = scalingParamsMap_->at(std::make_pair(ft_global, false)).mean_.size();
1727     int e_index_offset = PFe_index_offset + scalingParamsMap_->at(std::make_pair(ft_PFe, false)).mean_.size();
1728     int PFg_index_offset = e_index_offset + scalingParamsMap_->at(std::make_pair(ft_e, false)).mean_.size();
1729 
1730     // to account for swapped order of PfCand_gamma and Electron blocks for v2p5 training w.r.t. v2p1
1731     int fill_index_offset_e = 0;
1732     int fill_index_offset_PFg = 0;
1733     if (sub_version_ == 5) {
1734       fill_index_offset_e =
1735           scalingParamsMap_->at(std::make_pair(ft_PFg, false)).mean_.size();  // size of PF gamma features
1736       fill_index_offset_PFg =
1737           -scalingParamsMap_->at(std::make_pair(ft_e, false)).mean_.size();  // size of Electron features
1738     }
1739 
1740     tensorflow::Tensor& inputs = *eGammaTensor_.at(is_inner);
1741 
1742     const auto& get = [&](int var_index) -> float& { return inputs.tensor<float, 4>()(idx, 0, 0, var_index); };
1743 
1744     const bool valid_index_pf_ele = cell_map.count(CellObjectType::PfCand_electron);
1745     const bool valid_index_pf_gamma = cell_map.count(CellObjectType::PfCand_gamma);
1746     const bool valid_index_ele = cell_map.count(CellObjectType::Electron);
1747 
1748     if (!cell_map.empty()) {
1749       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_global, false));
1750       get(dnn::rho) = sp.scale(rho, dnn::rho);
1751       get(dnn::tau_pt) = sp.scale(tau.polarP4().pt(), dnn::tau_pt);
1752       get(dnn::tau_eta) = sp.scale(tau.polarP4().eta(), dnn::tau_eta);
1753       get(dnn::tau_inside_ecal_crack) = sp.scale(isInEcalCrack(tau.polarP4().eta()), dnn::tau_inside_ecal_crack);
1754     }
1755     if (valid_index_pf_ele) {
1756       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_PFe, is_inner));
1757       size_t index_pf_ele = cell_map.at(CellObjectType::PfCand_electron);
1758       const auto& ele_cand = dynamic_cast<const CandidateCastType&>(pfCands.at(index_pf_ele));
1759 
1760       get(dnn::pfCand_ele_valid) = sp.scale(valid_index_pf_ele, dnn::pfCand_ele_valid - PFe_index_offset);
1761       get(dnn::pfCand_ele_rel_pt) =
1762           sp.scale(ele_cand.polarP4().pt() / tau.polarP4().pt(), dnn::pfCand_ele_rel_pt - PFe_index_offset);
1763       get(dnn::pfCand_ele_deta) =
1764           sp.scale(ele_cand.polarP4().eta() - tau.polarP4().eta(), dnn::pfCand_ele_deta - PFe_index_offset);
1765       get(dnn::pfCand_ele_dphi) =
1766           sp.scale(dPhi(tau.polarP4(), ele_cand.polarP4()), dnn::pfCand_ele_dphi - PFe_index_offset);
1767       get(dnn::pfCand_ele_pvAssociationQuality) = sp.scale<int>(
1768           candFunc::getPvAssocationQuality(ele_cand), dnn::pfCand_ele_pvAssociationQuality - PFe_index_offset);
1769       get(dnn::pfCand_ele_puppiWeight) = is_inner ? sp.scale(candFunc::getPuppiWeight(ele_cand, 0.9906834f),
1770                                                              dnn::pfCand_ele_puppiWeight - PFe_index_offset)
1771                                                   : sp.scale(candFunc::getPuppiWeight(ele_cand, 0.9669586f),
1772                                                              dnn::pfCand_ele_puppiWeight - PFe_index_offset);
1773       get(dnn::pfCand_ele_charge) = sp.scale(ele_cand.charge(), dnn::pfCand_ele_charge - PFe_index_offset);
1774       get(dnn::pfCand_ele_lostInnerHits) =
1775           sp.scale<int>(candFunc::getLostInnerHits(ele_cand, 0), dnn::pfCand_ele_lostInnerHits - PFe_index_offset);
1776       get(dnn::pfCand_ele_numberOfPixelHits) =
1777           sp.scale(candFunc::getNumberOfPixelHits(ele_cand, 0), dnn::pfCand_ele_numberOfPixelHits - PFe_index_offset);
1778       get(dnn::pfCand_ele_vertex_dx) =
1779           sp.scale(ele_cand.vertex().x() - pv.position().x(), dnn::pfCand_ele_vertex_dx - PFe_index_offset);
1780       get(dnn::pfCand_ele_vertex_dy) =
1781           sp.scale(ele_cand.vertex().y() - pv.position().y(), dnn::pfCand_ele_vertex_dy - PFe_index_offset);
1782       get(dnn::pfCand_ele_vertex_dz) =
1783           sp.scale(ele_cand.vertex().z() - pv.position().z(), dnn::pfCand_ele_vertex_dz - PFe_index_offset);
1784       get(dnn::pfCand_ele_vertex_dx_tauFL) =
1785           sp.scale(ele_cand.vertex().x() - pv.position().x() - tau_funcs.getFlightLength(tau, tau_index).x(),
1786                    dnn::pfCand_ele_vertex_dx_tauFL - PFe_index_offset);
1787       get(dnn::pfCand_ele_vertex_dy_tauFL) =
1788           sp.scale(ele_cand.vertex().y() - pv.position().y() - tau_funcs.getFlightLength(tau, tau_index).y(),
1789                    dnn::pfCand_ele_vertex_dy_tauFL - PFe_index_offset);
1790       get(dnn::pfCand_ele_vertex_dz_tauFL) =
1791           sp.scale(ele_cand.vertex().z() - pv.position().z() - tau_funcs.getFlightLength(tau, tau_index).z(),
1792                    dnn::pfCand_ele_vertex_dz_tauFL - PFe_index_offset);
1793 
1794       const bool hasTrackDetails = candFunc::getHasTrackDetails(ele_cand);
1795       if (hasTrackDetails) {
1796         get(dnn::pfCand_ele_hasTrackDetails) =
1797             sp.scale(hasTrackDetails, dnn::pfCand_ele_hasTrackDetails - PFe_index_offset);
1798         get(dnn::pfCand_ele_dxy) = sp.scale(candFunc::getTauDxy(ele_cand), dnn::pfCand_ele_dxy - PFe_index_offset);
1799         get(dnn::pfCand_ele_dxy_sig) = sp.scale(std::abs(candFunc::getTauDxy(ele_cand)) / ele_cand.dxyError(),
1800                                                 dnn::pfCand_ele_dxy_sig - PFe_index_offset);
1801         get(dnn::pfCand_ele_dz) = sp.scale(candFunc::getTauDz(ele_cand), dnn::pfCand_ele_dz - PFe_index_offset);
1802         get(dnn::pfCand_ele_dz_sig) = sp.scale(std::abs(candFunc::getTauDz(ele_cand)) / ele_cand.dzError(),
1803                                                dnn::pfCand_ele_dz_sig - PFe_index_offset);
1804         get(dnn::pfCand_ele_track_chi2_ndof) =
1805             candFunc::getPseudoTrack(ele_cand).ndof() > 0
1806                 ? sp.scale(candFunc::getPseudoTrack(ele_cand).chi2() / candFunc::getPseudoTrack(ele_cand).ndof(),
1807                            dnn::pfCand_ele_track_chi2_ndof - PFe_index_offset)
1808                 : 0;
1809         get(dnn::pfCand_ele_track_ndof) =
1810             candFunc::getPseudoTrack(ele_cand).ndof() > 0
1811                 ? sp.scale(candFunc::getPseudoTrack(ele_cand).ndof(), dnn::pfCand_ele_track_ndof - PFe_index_offset)
1812                 : 0;
1813       }
1814     }
1815     if (valid_index_pf_gamma) {
1816       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_PFg, is_inner));
1817       size_t index_pf_gamma = cell_map.at(CellObjectType::PfCand_gamma);
1818       const auto& gamma_cand = dynamic_cast<const CandidateCastType&>(pfCands.at(index_pf_gamma));
1819 
1820       get(dnn::pfCand_gamma_valid + fill_index_offset_PFg) =
1821           sp.scale(valid_index_pf_gamma, dnn::pfCand_gamma_valid - PFg_index_offset);
1822       get(dnn::pfCand_gamma_rel_pt + fill_index_offset_PFg) =
1823           sp.scale(gamma_cand.polarP4().pt() / tau.polarP4().pt(), dnn::pfCand_gamma_rel_pt - PFg_index_offset);
1824       get(dnn::pfCand_gamma_deta + fill_index_offset_PFg) =
1825           sp.scale(gamma_cand.polarP4().eta() - tau.polarP4().eta(), dnn::pfCand_gamma_deta - PFg_index_offset);
1826       get(dnn::pfCand_gamma_dphi + fill_index_offset_PFg) =
1827           sp.scale(dPhi(tau.polarP4(), gamma_cand.polarP4()), dnn::pfCand_gamma_dphi - PFg_index_offset);
1828       get(dnn::pfCand_gamma_pvAssociationQuality + fill_index_offset_PFg) = sp.scale<int>(
1829           candFunc::getPvAssocationQuality(gamma_cand), dnn::pfCand_gamma_pvAssociationQuality - PFg_index_offset);
1830       get(dnn::pfCand_gamma_fromPV + fill_index_offset_PFg) =
1831           sp.scale<int>(candFunc::getFromPV(gamma_cand), dnn::pfCand_gamma_fromPV - PFg_index_offset);
1832       get(dnn::pfCand_gamma_puppiWeight + fill_index_offset_PFg) =
1833           is_inner ? sp.scale(candFunc::getPuppiWeight(gamma_cand, 0.9084110f),
1834                               dnn::pfCand_gamma_puppiWeight - PFg_index_offset)
1835                    : sp.scale(candFunc::getPuppiWeight(gamma_cand, 0.4211567f),
1836                               dnn::pfCand_gamma_puppiWeight - PFg_index_offset);
1837       get(dnn::pfCand_gamma_puppiWeightNoLep + fill_index_offset_PFg) =
1838           is_inner ? sp.scale(candFunc::getPuppiWeightNoLep(gamma_cand, 0.8857716f),
1839                               dnn::pfCand_gamma_puppiWeightNoLep - PFg_index_offset)
1840                    : sp.scale(candFunc::getPuppiWeightNoLep(gamma_cand, 0.3822604f),
1841                               dnn::pfCand_gamma_puppiWeightNoLep - PFg_index_offset);
1842       get(dnn::pfCand_gamma_lostInnerHits + fill_index_offset_PFg) =
1843           sp.scale<int>(candFunc::getLostInnerHits(gamma_cand, 0), dnn::pfCand_gamma_lostInnerHits - PFg_index_offset);
1844       get(dnn::pfCand_gamma_numberOfPixelHits + fill_index_offset_PFg) = sp.scale(
1845           candFunc::getNumberOfPixelHits(gamma_cand, 0), dnn::pfCand_gamma_numberOfPixelHits - PFg_index_offset);
1846       get(dnn::pfCand_gamma_vertex_dx + fill_index_offset_PFg) =
1847           sp.scale(gamma_cand.vertex().x() - pv.position().x(), dnn::pfCand_gamma_vertex_dx - PFg_index_offset);
1848       get(dnn::pfCand_gamma_vertex_dy + fill_index_offset_PFg) =
1849           sp.scale(gamma_cand.vertex().y() - pv.position().y(), dnn::pfCand_gamma_vertex_dy - PFg_index_offset);
1850       get(dnn::pfCand_gamma_vertex_dz + fill_index_offset_PFg) =
1851           sp.scale(gamma_cand.vertex().z() - pv.position().z(), dnn::pfCand_gamma_vertex_dz - PFg_index_offset);
1852       get(dnn::pfCand_gamma_vertex_dx_tauFL + fill_index_offset_PFg) =
1853           sp.scale(gamma_cand.vertex().x() - pv.position().x() - tau_funcs.getFlightLength(tau, tau_index).x(),
1854                    dnn::pfCand_gamma_vertex_dx_tauFL - PFg_index_offset);
1855       get(dnn::pfCand_gamma_vertex_dy_tauFL + fill_index_offset_PFg) =
1856           sp.scale(gamma_cand.vertex().y() - pv.position().y() - tau_funcs.getFlightLength(tau, tau_index).y(),
1857                    dnn::pfCand_gamma_vertex_dy_tauFL - PFg_index_offset);
1858       get(dnn::pfCand_gamma_vertex_dz_tauFL + fill_index_offset_PFg) =
1859           sp.scale(gamma_cand.vertex().z() - pv.position().z() - tau_funcs.getFlightLength(tau, tau_index).z(),
1860                    dnn::pfCand_gamma_vertex_dz_tauFL - PFg_index_offset);
1861       const bool hasTrackDetails = candFunc::getHasTrackDetails(gamma_cand);
1862       if (hasTrackDetails) {
1863         get(dnn::pfCand_gamma_hasTrackDetails + fill_index_offset_PFg) =
1864             sp.scale(hasTrackDetails, dnn::pfCand_gamma_hasTrackDetails - PFg_index_offset);
1865         get(dnn::pfCand_gamma_dxy + fill_index_offset_PFg) =
1866             sp.scale(candFunc::getTauDxy(gamma_cand), dnn::pfCand_gamma_dxy - PFg_index_offset);
1867         get(dnn::pfCand_gamma_dxy_sig + fill_index_offset_PFg) =
1868             sp.scale(std::abs(candFunc::getTauDxy(gamma_cand)) / gamma_cand.dxyError(),
1869                      dnn::pfCand_gamma_dxy_sig - PFg_index_offset);
1870         get(dnn::pfCand_gamma_dz + fill_index_offset_PFg) =
1871             sp.scale(candFunc::getTauDz(gamma_cand), dnn::pfCand_gamma_dz - PFg_index_offset);
1872         get(dnn::pfCand_gamma_dz_sig + fill_index_offset_PFg) =
1873             sp.scale(std::abs(candFunc::getTauDz(gamma_cand)) / gamma_cand.dzError(),
1874                      dnn::pfCand_gamma_dz_sig - PFg_index_offset);
1875         get(dnn::pfCand_gamma_track_chi2_ndof + fill_index_offset_PFg) =
1876             candFunc::getPseudoTrack(gamma_cand).ndof() > 0
1877                 ? sp.scale(candFunc::getPseudoTrack(gamma_cand).chi2() / candFunc::getPseudoTrack(gamma_cand).ndof(),
1878                            dnn::pfCand_gamma_track_chi2_ndof - PFg_index_offset)
1879                 : 0;
1880         get(dnn::pfCand_gamma_track_ndof + fill_index_offset_PFg) =
1881             candFunc::getPseudoTrack(gamma_cand).ndof() > 0
1882                 ? sp.scale(candFunc::getPseudoTrack(gamma_cand).ndof(), dnn::pfCand_gamma_track_ndof - PFg_index_offset)
1883                 : 0;
1884       }
1885     }
1886     if (valid_index_ele) {
1887       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_e, is_inner));
1888       size_t index_ele = cell_map.at(CellObjectType::Electron);
1889       const auto& ele = electrons->at(index_ele);
1890 
1891       get(dnn::ele_valid + fill_index_offset_e) = sp.scale(valid_index_ele, dnn::ele_valid - e_index_offset);
1892       get(dnn::ele_rel_pt + fill_index_offset_e) =
1893           sp.scale(ele.polarP4().pt() / tau.polarP4().pt(), dnn::ele_rel_pt - e_index_offset);
1894       get(dnn::ele_deta + fill_index_offset_e) =
1895           sp.scale(ele.polarP4().eta() - tau.polarP4().eta(), dnn::ele_deta - e_index_offset);
1896       get(dnn::ele_dphi + fill_index_offset_e) =
1897           sp.scale(dPhi(tau.polarP4(), ele.polarP4()), dnn::ele_dphi - e_index_offset);
1898 
1899       float cc_ele_energy, cc_gamma_energy;
1900       int cc_n_gamma;
1901       const bool cc_valid = calculateElectronClusterVarsV2(ele, cc_ele_energy, cc_gamma_energy, cc_n_gamma);
1902       if (cc_valid) {
1903         get(dnn::ele_cc_valid + fill_index_offset_e) = sp.scale(cc_valid, dnn::ele_cc_valid - e_index_offset);
1904         get(dnn::ele_cc_ele_rel_energy + fill_index_offset_e) =
1905             sp.scale(cc_ele_energy / ele.polarP4().pt(), dnn::ele_cc_ele_rel_energy - e_index_offset);
1906         get(dnn::ele_cc_gamma_rel_energy + fill_index_offset_e) =
1907             sp.scale(cc_gamma_energy / cc_ele_energy, dnn::ele_cc_gamma_rel_energy - e_index_offset);
1908         get(dnn::ele_cc_n_gamma + fill_index_offset_e) = sp.scale(cc_n_gamma, dnn::ele_cc_n_gamma - e_index_offset);
1909       }
1910       get(dnn::ele_rel_trackMomentumAtVtx + fill_index_offset_e) =
1911           sp.scale(ele.trackMomentumAtVtx().R() / ele.polarP4().pt(), dnn::ele_rel_trackMomentumAtVtx - e_index_offset);
1912       get(dnn::ele_rel_trackMomentumAtCalo + fill_index_offset_e) = sp.scale(
1913           ele.trackMomentumAtCalo().R() / ele.polarP4().pt(), dnn::ele_rel_trackMomentumAtCalo - e_index_offset);
1914       get(dnn::ele_rel_trackMomentumOut + fill_index_offset_e) =
1915           sp.scale(ele.trackMomentumOut().R() / ele.polarP4().pt(), dnn::ele_rel_trackMomentumOut - e_index_offset);
1916       get(dnn::ele_rel_trackMomentumAtEleClus + fill_index_offset_e) = sp.scale(
1917           ele.trackMomentumAtEleClus().R() / ele.polarP4().pt(), dnn::ele_rel_trackMomentumAtEleClus - e_index_offset);
1918       get(dnn::ele_rel_trackMomentumAtVtxWithConstraint + fill_index_offset_e) =
1919           sp.scale(ele.trackMomentumAtVtxWithConstraint().R() / ele.polarP4().pt(),
1920                    dnn::ele_rel_trackMomentumAtVtxWithConstraint - e_index_offset);
1921       get(dnn::ele_rel_ecalEnergy + fill_index_offset_e) =
1922           sp.scale(ele.ecalEnergy() / ele.polarP4().pt(), dnn::ele_rel_ecalEnergy - e_index_offset);
1923       get(dnn::ele_ecalEnergy_sig + fill_index_offset_e) =
1924           sp.scale(ele.ecalEnergy() / ele.ecalEnergyError(), dnn::ele_ecalEnergy_sig - e_index_offset);
1925       get(dnn::ele_eSuperClusterOverP + fill_index_offset_e) =
1926           sp.scale(ele.eSuperClusterOverP(), dnn::ele_eSuperClusterOverP - e_index_offset);
1927       get(dnn::ele_eSeedClusterOverP + fill_index_offset_e) =
1928           sp.scale(ele.eSeedClusterOverP(), dnn::ele_eSeedClusterOverP - e_index_offset);
1929       get(dnn::ele_eSeedClusterOverPout + fill_index_offset_e) =
1930           sp.scale(ele.eSeedClusterOverPout(), dnn::ele_eSeedClusterOverPout - e_index_offset);
1931       get(dnn::ele_eEleClusterOverPout + fill_index_offset_e) =
1932           sp.scale(ele.eEleClusterOverPout(), dnn::ele_eEleClusterOverPout - e_index_offset);
1933       get(dnn::ele_deltaEtaSuperClusterTrackAtVtx + fill_index_offset_e) =
1934           sp.scale(ele.deltaEtaSuperClusterTrackAtVtx(), dnn::ele_deltaEtaSuperClusterTrackAtVtx - e_index_offset);
1935       get(dnn::ele_deltaEtaSeedClusterTrackAtCalo + fill_index_offset_e) =
1936           sp.scale(ele.deltaEtaSeedClusterTrackAtCalo(), dnn::ele_deltaEtaSeedClusterTrackAtCalo - e_index_offset);
1937       get(dnn::ele_deltaEtaEleClusterTrackAtCalo + fill_index_offset_e) =
1938           sp.scale(ele.deltaEtaEleClusterTrackAtCalo(), dnn::ele_deltaEtaEleClusterTrackAtCalo - e_index_offset);
1939       get(dnn::ele_deltaPhiEleClusterTrackAtCalo + fill_index_offset_e) =
1940           sp.scale(ele.deltaPhiEleClusterTrackAtCalo(), dnn::ele_deltaPhiEleClusterTrackAtCalo - e_index_offset);
1941       get(dnn::ele_deltaPhiSuperClusterTrackAtVtx + fill_index_offset_e) =
1942           sp.scale(ele.deltaPhiSuperClusterTrackAtVtx(), dnn::ele_deltaPhiSuperClusterTrackAtVtx - e_index_offset);
1943       get(dnn::ele_deltaPhiSeedClusterTrackAtCalo + fill_index_offset_e) =
1944           sp.scale(ele.deltaPhiSeedClusterTrackAtCalo(), dnn::ele_deltaPhiSeedClusterTrackAtCalo - e_index_offset);
1945       get(dnn::ele_mvaInput_earlyBrem + fill_index_offset_e) =
1946           sp.scale(ele.mvaInput().earlyBrem, dnn::ele_mvaInput_earlyBrem - e_index_offset);
1947       get(dnn::ele_mvaInput_lateBrem + fill_index_offset_e) =
1948           sp.scale(ele.mvaInput().lateBrem, dnn::ele_mvaInput_lateBrem - e_index_offset);
1949       get(dnn::ele_mvaInput_sigmaEtaEta + fill_index_offset_e) =
1950           sp.scale(ele.mvaInput().sigmaEtaEta, dnn::ele_mvaInput_sigmaEtaEta - e_index_offset);
1951       get(dnn::ele_mvaInput_hadEnergy + fill_index_offset_e) =
1952           sp.scale(ele.mvaInput().hadEnergy, dnn::ele_mvaInput_hadEnergy - e_index_offset);
1953       get(dnn::ele_mvaInput_deltaEta + fill_index_offset_e) =
1954           sp.scale(ele.mvaInput().deltaEta, dnn::ele_mvaInput_deltaEta - e_index_offset);
1955       const auto& gsfTrack = ele.gsfTrack();
1956       if (gsfTrack.isNonnull()) {
1957         get(dnn::ele_gsfTrack_normalizedChi2 + fill_index_offset_e) =
1958             sp.scale(gsfTrack->normalizedChi2(), dnn::ele_gsfTrack_normalizedChi2 - e_index_offset);
1959         get(dnn::ele_gsfTrack_numberOfValidHits + fill_index_offset_e) =
1960             sp.scale(gsfTrack->numberOfValidHits(), dnn::ele_gsfTrack_numberOfValidHits - e_index_offset);
1961         get(dnn::ele_rel_gsfTrack_pt + fill_index_offset_e) =
1962             sp.scale(gsfTrack->pt() / ele.polarP4().pt(), dnn::ele_rel_gsfTrack_pt - e_index_offset);
1963         get(dnn::ele_gsfTrack_pt_sig + fill_index_offset_e) =
1964             sp.scale(gsfTrack->pt() / gsfTrack->ptError(), dnn::ele_gsfTrack_pt_sig - e_index_offset);
1965       }
1966       const auto& closestCtfTrack = ele.closestCtfTrackRef();
1967       const bool has_closestCtfTrack = closestCtfTrack.isNonnull();
1968       if (has_closestCtfTrack) {
1969         get(dnn::ele_has_closestCtfTrack + fill_index_offset_e) =
1970             sp.scale(has_closestCtfTrack, dnn::ele_has_closestCtfTrack - e_index_offset);
1971         get(dnn::ele_closestCtfTrack_normalizedChi2 + fill_index_offset_e) =
1972             sp.scale(closestCtfTrack->normalizedChi2(), dnn::ele_closestCtfTrack_normalizedChi2 - e_index_offset);
1973         get(dnn::ele_closestCtfTrack_numberOfValidHits + fill_index_offset_e) =
1974             sp.scale(closestCtfTrack->numberOfValidHits(), dnn::ele_closestCtfTrack_numberOfValidHits - e_index_offset);
1975       }
1976     }
1977   }
1978 
1979   template <typename CandidateCastType, typename TauCastType>
1980   void createMuonBlockInputs(unsigned idx,
1981                              const TauCastType& tau,
1982                              const size_t tau_index,
1983                              const edm::RefToBase<reco::BaseTau> tau_ref,
1984                              const reco::Vertex& pv,
1985                              double rho,
1986                              const std::vector<pat::Muon>* muons,
1987                              const edm::View<reco::Candidate>& pfCands,
1988                              const Cell& cell_map,
1989                              TauFunc tau_funcs,
1990                              bool is_inner) {
1991     namespace dnn = dnn_inputs_v2::MuonBlockInputs;
1992     namespace sc = deep_tau::Scaling;
1993     sc::FeatureT ft_global = sc::FeatureT::GridGlobal;
1994     sc::FeatureT ft_PFmu = sc::FeatureT::PfCand_muon;
1995     sc::FeatureT ft_mu = sc::FeatureT::Muon;
1996 
1997     // needed to remap indices from scaling vectors to those from dnn_inputs_v2::MuonBlockInputs
1998     int PFmu_index_offset = scalingParamsMap_->at(std::make_pair(ft_global, false)).mean_.size();
1999     int mu_index_offset = PFmu_index_offset + scalingParamsMap_->at(std::make_pair(ft_PFmu, false)).mean_.size();
2000 
2001     tensorflow::Tensor& inputs = *muonTensor_.at(is_inner);
2002 
2003     const auto& get = [&](int var_index) -> float& { return inputs.tensor<float, 4>()(idx, 0, 0, var_index); };
2004 
2005     const bool valid_index_pf_muon = cell_map.count(CellObjectType::PfCand_muon);
2006     const bool valid_index_muon = cell_map.count(CellObjectType::Muon);
2007 
2008     if (!cell_map.empty()) {
2009       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_global, false));
2010       get(dnn::rho) = sp.scale(rho, dnn::rho);
2011       get(dnn::tau_pt) = sp.scale(tau.polarP4().pt(), dnn::tau_pt);
2012       get(dnn::tau_eta) = sp.scale(tau.polarP4().eta(), dnn::tau_eta);
2013       get(dnn::tau_inside_ecal_crack) = sp.scale(isInEcalCrack(tau.polarP4().eta()), dnn::tau_inside_ecal_crack);
2014     }
2015     if (valid_index_pf_muon) {
2016       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_PFmu, is_inner));
2017       size_t index_pf_muon = cell_map.at(CellObjectType::PfCand_muon);
2018       const auto& muon_cand = dynamic_cast<const CandidateCastType&>(pfCands.at(index_pf_muon));
2019 
2020       get(dnn::pfCand_muon_valid) = sp.scale(valid_index_pf_muon, dnn::pfCand_muon_valid - PFmu_index_offset);
2021       get(dnn::pfCand_muon_rel_pt) =
2022           sp.scale(muon_cand.polarP4().pt() / tau.polarP4().pt(), dnn::pfCand_muon_rel_pt - PFmu_index_offset);
2023       get(dnn::pfCand_muon_deta) =
2024           sp.scale(muon_cand.polarP4().eta() - tau.polarP4().eta(), dnn::pfCand_muon_deta - PFmu_index_offset);
2025       get(dnn::pfCand_muon_dphi) =
2026           sp.scale(dPhi(tau.polarP4(), muon_cand.polarP4()), dnn::pfCand_muon_dphi - PFmu_index_offset);
2027       get(dnn::pfCand_muon_pvAssociationQuality) = sp.scale<int>(
2028           candFunc::getPvAssocationQuality(muon_cand), dnn::pfCand_muon_pvAssociationQuality - PFmu_index_offset);
2029       get(dnn::pfCand_muon_fromPV) =
2030           sp.scale<int>(candFunc::getFromPV(muon_cand), dnn::pfCand_muon_fromPV - PFmu_index_offset);
2031       get(dnn::pfCand_muon_puppiWeight) = is_inner ? sp.scale(candFunc::getPuppiWeight(muon_cand, 0.9786588f),
2032                                                               dnn::pfCand_muon_puppiWeight - PFmu_index_offset)
2033                                                    : sp.scale(candFunc::getPuppiWeight(muon_cand, 0.8132477f),
2034                                                               dnn::pfCand_muon_puppiWeight - PFmu_index_offset);
2035       get(dnn::pfCand_muon_charge) = sp.scale(muon_cand.charge(), dnn::pfCand_muon_charge - PFmu_index_offset);
2036       get(dnn::pfCand_muon_lostInnerHits) =
2037           sp.scale<int>(candFunc::getLostInnerHits(muon_cand, 0), dnn::pfCand_muon_lostInnerHits - PFmu_index_offset);
2038       get(dnn::pfCand_muon_numberOfPixelHits) = sp.scale(candFunc::getNumberOfPixelHits(muon_cand, 0),
2039                                                          dnn::pfCand_muon_numberOfPixelHits - PFmu_index_offset);
2040       get(dnn::pfCand_muon_vertex_dx) =
2041           sp.scale(muon_cand.vertex().x() - pv.position().x(), dnn::pfCand_muon_vertex_dx - PFmu_index_offset);
2042       get(dnn::pfCand_muon_vertex_dy) =
2043           sp.scale(muon_cand.vertex().y() - pv.position().y(), dnn::pfCand_muon_vertex_dy - PFmu_index_offset);
2044       get(dnn::pfCand_muon_vertex_dz) =
2045           sp.scale(muon_cand.vertex().z() - pv.position().z(), dnn::pfCand_muon_vertex_dz - PFmu_index_offset);
2046       get(dnn::pfCand_muon_vertex_dx_tauFL) =
2047           sp.scale(muon_cand.vertex().x() - pv.position().x() - tau_funcs.getFlightLength(tau, tau_index).x(),
2048                    dnn::pfCand_muon_vertex_dx_tauFL - PFmu_index_offset);
2049       get(dnn::pfCand_muon_vertex_dy_tauFL) =
2050           sp.scale(muon_cand.vertex().y() - pv.position().y() - tau_funcs.getFlightLength(tau, tau_index).y(),
2051                    dnn::pfCand_muon_vertex_dy_tauFL - PFmu_index_offset);
2052       get(dnn::pfCand_muon_vertex_dz_tauFL) =
2053           sp.scale(muon_cand.vertex().z() - pv.position().z() - tau_funcs.getFlightLength(tau, tau_index).z(),
2054                    dnn::pfCand_muon_vertex_dz_tauFL - PFmu_index_offset);
2055 
2056       const bool hasTrackDetails = candFunc::getHasTrackDetails(muon_cand);
2057       if (hasTrackDetails) {
2058         get(dnn::pfCand_muon_hasTrackDetails) =
2059             sp.scale(hasTrackDetails, dnn::pfCand_muon_hasTrackDetails - PFmu_index_offset);
2060         get(dnn::pfCand_muon_dxy) = sp.scale(candFunc::getTauDxy(muon_cand), dnn::pfCand_muon_dxy - PFmu_index_offset);
2061         get(dnn::pfCand_muon_dxy_sig) = sp.scale(std::abs(candFunc::getTauDxy(muon_cand)) / muon_cand.dxyError(),
2062                                                  dnn::pfCand_muon_dxy_sig - PFmu_index_offset);
2063         get(dnn::pfCand_muon_dz) = sp.scale(candFunc::getTauDz(muon_cand), dnn::pfCand_muon_dz - PFmu_index_offset);
2064         get(dnn::pfCand_muon_dz_sig) = sp.scale(std::abs(candFunc::getTauDz(muon_cand)) / muon_cand.dzError(),
2065                                                 dnn::pfCand_muon_dz_sig - PFmu_index_offset);
2066         get(dnn::pfCand_muon_track_chi2_ndof) =
2067             candFunc::getPseudoTrack(muon_cand).ndof() > 0
2068                 ? sp.scale(candFunc::getPseudoTrack(muon_cand).chi2() / candFunc::getPseudoTrack(muon_cand).ndof(),
2069                            dnn::pfCand_muon_track_chi2_ndof - PFmu_index_offset)
2070                 : 0;
2071         get(dnn::pfCand_muon_track_ndof) =
2072             candFunc::getPseudoTrack(muon_cand).ndof() > 0
2073                 ? sp.scale(candFunc::getPseudoTrack(muon_cand).ndof(), dnn::pfCand_muon_track_ndof - PFmu_index_offset)
2074                 : 0;
2075       }
2076     }
2077     if (valid_index_muon) {
2078       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_mu, is_inner));
2079       size_t index_muon = cell_map.at(CellObjectType::Muon);
2080       const auto& muon = muons->at(index_muon);
2081 
2082       get(dnn::muon_valid) = sp.scale(valid_index_muon, dnn::muon_valid - mu_index_offset);
2083       get(dnn::muon_rel_pt) = sp.scale(muon.polarP4().pt() / tau.polarP4().pt(), dnn::muon_rel_pt - mu_index_offset);
2084       get(dnn::muon_deta) = sp.scale(muon.polarP4().eta() - tau.polarP4().eta(), dnn::muon_deta - mu_index_offset);
2085       get(dnn::muon_dphi) = sp.scale(dPhi(tau.polarP4(), muon.polarP4()), dnn::muon_dphi - mu_index_offset);
2086       get(dnn::muon_dxy) = sp.scale(muon.dB(pat::Muon::PV2D), dnn::muon_dxy - mu_index_offset);
2087       get(dnn::muon_dxy_sig) =
2088           sp.scale(std::abs(muon.dB(pat::Muon::PV2D)) / muon.edB(pat::Muon::PV2D), dnn::muon_dxy_sig - mu_index_offset);
2089 
2090       const bool normalizedChi2_valid = muon.globalTrack().isNonnull() && muon.normChi2() >= 0;
2091       if (normalizedChi2_valid) {
2092         get(dnn::muon_normalizedChi2_valid) =
2093             sp.scale(normalizedChi2_valid, dnn::muon_normalizedChi2_valid - mu_index_offset);
2094         get(dnn::muon_normalizedChi2) = sp.scale(muon.normChi2(), dnn::muon_normalizedChi2 - mu_index_offset);
2095         if (muon.innerTrack().isNonnull())
2096           get(dnn::muon_numberOfValidHits) =
2097               sp.scale(muon.numberOfValidHits(), dnn::muon_numberOfValidHits - mu_index_offset);
2098       }
2099       get(dnn::muon_segmentCompatibility) =
2100           sp.scale(muon.segmentCompatibility(), dnn::muon_segmentCompatibility - mu_index_offset);
2101       get(dnn::muon_caloCompatibility) =
2102           sp.scale(muon.caloCompatibility(), dnn::muon_caloCompatibility - mu_index_offset);
2103 
2104       const bool pfEcalEnergy_valid = muon.pfEcalEnergy() >= 0;
2105       if (pfEcalEnergy_valid) {
2106         get(dnn::muon_pfEcalEnergy_valid) =
2107             sp.scale(pfEcalEnergy_valid, dnn::muon_pfEcalEnergy_valid - mu_index_offset);
2108         get(dnn::muon_rel_pfEcalEnergy) =
2109             sp.scale(muon.pfEcalEnergy() / muon.polarP4().pt(), dnn::muon_rel_pfEcalEnergy - mu_index_offset);
2110       }
2111 
2112       MuonHitMatchV2 hit_match(muon);
2113       static const std::map<int, std::pair<int, int>> muonMatchHitVars = {
2114           {MuonSubdetId::DT, {dnn::muon_n_matches_DT_1, dnn::muon_n_hits_DT_1}},
2115           {MuonSubdetId::CSC, {dnn::muon_n_matches_CSC_1, dnn::muon_n_hits_CSC_1}},
2116           {MuonSubdetId::RPC, {dnn::muon_n_matches_RPC_1, dnn::muon_n_hits_RPC_1}}};
2117 
2118       for (int subdet : hit_match.MuonHitMatchV2::consideredSubdets()) {
2119         const auto& matchHitVar = muonMatchHitVars.at(subdet);
2120         for (int station = MuonHitMatchV2::first_station_id; station <= MuonHitMatchV2::last_station_id; ++station) {
2121           const unsigned n_matches = hit_match.nMatches(subdet, station);
2122           const unsigned n_hits = hit_match.nHits(subdet, station);
2123           get(matchHitVar.first + station - 1) = sp.scale(n_matches, matchHitVar.first + station - 1 - mu_index_offset);
2124           get(matchHitVar.second + station - 1) = sp.scale(n_hits, matchHitVar.second + station - 1 - mu_index_offset);
2125         }
2126       }
2127     }
2128   }
2129 
2130   template <typename CandidateCastType, typename TauCastType>
2131   void createHadronsBlockInputs(unsigned idx,
2132                                 const TauCastType& tau,
2133                                 const size_t tau_index,
2134                                 const edm::RefToBase<reco::BaseTau> tau_ref,
2135                                 const reco::Vertex& pv,
2136                                 double rho,
2137                                 const edm::View<reco::Candidate>& pfCands,
2138                                 const Cell& cell_map,
2139                                 TauFunc tau_funcs,
2140                                 bool is_inner) {
2141     namespace dnn = dnn_inputs_v2::HadronBlockInputs;
2142     namespace sc = deep_tau::Scaling;
2143     sc::FeatureT ft_global = sc::FeatureT::GridGlobal;
2144     sc::FeatureT ft_PFchH = sc::FeatureT::PfCand_chHad;
2145     sc::FeatureT ft_PFnH = sc::FeatureT::PfCand_nHad;
2146 
2147     // needed to remap indices from scaling vectors to those from dnn_inputs_v2::HadronBlockInputs
2148     int PFchH_index_offset = scalingParamsMap_->at(std::make_pair(ft_global, false)).mean_.size();
2149     int PFnH_index_offset = PFchH_index_offset + scalingParamsMap_->at(std::make_pair(ft_PFchH, false)).mean_.size();
2150 
2151     tensorflow::Tensor& inputs = *hadronsTensor_.at(is_inner);
2152 
2153     const auto& get = [&](int var_index) -> float& { return inputs.tensor<float, 4>()(idx, 0, 0, var_index); };
2154 
2155     const bool valid_chH = cell_map.count(CellObjectType::PfCand_chargedHadron);
2156     const bool valid_nH = cell_map.count(CellObjectType::PfCand_neutralHadron);
2157 
2158     if (!cell_map.empty()) {
2159       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_global, false));
2160       get(dnn::rho) = sp.scale(rho, dnn::rho);
2161       get(dnn::tau_pt) = sp.scale(tau.polarP4().pt(), dnn::tau_pt);
2162       get(dnn::tau_eta) = sp.scale(tau.polarP4().eta(), dnn::tau_eta);
2163       get(dnn::tau_inside_ecal_crack) = sp.scale(isInEcalCrack(tau.polarP4().eta()), dnn::tau_inside_ecal_crack);
2164     }
2165     if (valid_chH) {
2166       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_PFchH, is_inner));
2167       size_t index_chH = cell_map.at(CellObjectType::PfCand_chargedHadron);
2168       const auto& chH_cand = dynamic_cast<const CandidateCastType&>(pfCands.at(index_chH));
2169 
2170       get(dnn::pfCand_chHad_valid) = sp.scale(valid_chH, dnn::pfCand_chHad_valid - PFchH_index_offset);
2171       get(dnn::pfCand_chHad_rel_pt) =
2172           sp.scale(chH_cand.polarP4().pt() / tau.polarP4().pt(), dnn::pfCand_chHad_rel_pt - PFchH_index_offset);
2173       get(dnn::pfCand_chHad_deta) =
2174           sp.scale(chH_cand.polarP4().eta() - tau.polarP4().eta(), dnn::pfCand_chHad_deta - PFchH_index_offset);
2175       get(dnn::pfCand_chHad_dphi) =
2176           sp.scale(dPhi(tau.polarP4(), chH_cand.polarP4()), dnn::pfCand_chHad_dphi - PFchH_index_offset);
2177       get(dnn::pfCand_chHad_leadChargedHadrCand) =
2178           sp.scale(&chH_cand == dynamic_cast<const CandidateCastType*>(tau.leadChargedHadrCand().get()),
2179                    dnn::pfCand_chHad_leadChargedHadrCand - PFchH_index_offset);
2180       get(dnn::pfCand_chHad_pvAssociationQuality) = sp.scale<int>(
2181           candFunc::getPvAssocationQuality(chH_cand), dnn::pfCand_chHad_pvAssociationQuality - PFchH_index_offset);
2182       get(dnn::pfCand_chHad_fromPV) =
2183           sp.scale<int>(candFunc::getFromPV(chH_cand), dnn::pfCand_chHad_fromPV - PFchH_index_offset);
2184       const float default_chH_pw_inner = 0.7614090f;
2185       const float default_chH_pw_outer = 0.1974930f;
2186       get(dnn::pfCand_chHad_puppiWeight) = is_inner ? sp.scale(candFunc::getPuppiWeight(chH_cand, default_chH_pw_inner),
2187                                                                dnn::pfCand_chHad_puppiWeight - PFchH_index_offset)
2188                                                     : sp.scale(candFunc::getPuppiWeight(chH_cand, default_chH_pw_outer),
2189                                                                dnn::pfCand_chHad_puppiWeight - PFchH_index_offset);
2190       get(dnn::pfCand_chHad_puppiWeightNoLep) =
2191           is_inner ? sp.scale(candFunc::getPuppiWeightNoLep(chH_cand, default_chH_pw_inner),
2192                               dnn::pfCand_chHad_puppiWeightNoLep - PFchH_index_offset)
2193                    : sp.scale(candFunc::getPuppiWeightNoLep(chH_cand, default_chH_pw_outer),
2194                               dnn::pfCand_chHad_puppiWeightNoLep - PFchH_index_offset);
2195       get(dnn::pfCand_chHad_charge) = sp.scale(chH_cand.charge(), dnn::pfCand_chHad_charge - PFchH_index_offset);
2196       get(dnn::pfCand_chHad_lostInnerHits) =
2197           sp.scale<int>(candFunc::getLostInnerHits(chH_cand, 0), dnn::pfCand_chHad_lostInnerHits - PFchH_index_offset);
2198       get(dnn::pfCand_chHad_numberOfPixelHits) = sp.scale(candFunc::getNumberOfPixelHits(chH_cand, 0),
2199                                                           dnn::pfCand_chHad_numberOfPixelHits - PFchH_index_offset);
2200       get(dnn::pfCand_chHad_vertex_dx) =
2201           sp.scale(chH_cand.vertex().x() - pv.position().x(), dnn::pfCand_chHad_vertex_dx - PFchH_index_offset);
2202       get(dnn::pfCand_chHad_vertex_dy) =
2203           sp.scale(chH_cand.vertex().y() - pv.position().y(), dnn::pfCand_chHad_vertex_dy - PFchH_index_offset);
2204       get(dnn::pfCand_chHad_vertex_dz) =
2205           sp.scale(chH_cand.vertex().z() - pv.position().z(), dnn::pfCand_chHad_vertex_dz - PFchH_index_offset);
2206       get(dnn::pfCand_chHad_vertex_dx_tauFL) =
2207           sp.scale(chH_cand.vertex().x() - pv.position().x() - tau_funcs.getFlightLength(tau, tau_index).x(),
2208                    dnn::pfCand_chHad_vertex_dx_tauFL - PFchH_index_offset);
2209       get(dnn::pfCand_chHad_vertex_dy_tauFL) =
2210           sp.scale(chH_cand.vertex().y() - pv.position().y() - tau_funcs.getFlightLength(tau, tau_index).y(),
2211                    dnn::pfCand_chHad_vertex_dy_tauFL - PFchH_index_offset);
2212       get(dnn::pfCand_chHad_vertex_dz_tauFL) =
2213           sp.scale(chH_cand.vertex().z() - pv.position().z() - tau_funcs.getFlightLength(tau, tau_index).z(),
2214                    dnn::pfCand_chHad_vertex_dz_tauFL - PFchH_index_offset);
2215 
2216       const bool hasTrackDetails = candFunc::getHasTrackDetails(chH_cand);
2217       if (hasTrackDetails) {
2218         get(dnn::pfCand_chHad_hasTrackDetails) =
2219             sp.scale(hasTrackDetails, dnn::pfCand_chHad_hasTrackDetails - PFchH_index_offset);
2220         get(dnn::pfCand_chHad_dxy) =
2221             sp.scale(candFunc::getTauDxy(chH_cand), dnn::pfCand_chHad_dxy - PFchH_index_offset);
2222         get(dnn::pfCand_chHad_dxy_sig) = sp.scale(std::abs(candFunc::getTauDxy(chH_cand)) / chH_cand.dxyError(),
2223                                                   dnn::pfCand_chHad_dxy_sig - PFchH_index_offset);
2224         get(dnn::pfCand_chHad_dz) = sp.scale(candFunc::getTauDz(chH_cand), dnn::pfCand_chHad_dz - PFchH_index_offset);
2225         get(dnn::pfCand_chHad_dz_sig) = sp.scale(std::abs(candFunc::getTauDz(chH_cand)) / chH_cand.dzError(),
2226                                                  dnn::pfCand_chHad_dz_sig - PFchH_index_offset);
2227         get(dnn::pfCand_chHad_track_chi2_ndof) =
2228             candFunc::getPseudoTrack(chH_cand).ndof() > 0
2229                 ? sp.scale(candFunc::getPseudoTrack(chH_cand).chi2() / candFunc::getPseudoTrack(chH_cand).ndof(),
2230                            dnn::pfCand_chHad_track_chi2_ndof - PFchH_index_offset)
2231                 : 0;
2232         get(dnn::pfCand_chHad_track_ndof) =
2233             candFunc::getPseudoTrack(chH_cand).ndof() > 0
2234                 ? sp.scale(candFunc::getPseudoTrack(chH_cand).ndof(), dnn::pfCand_chHad_track_ndof - PFchH_index_offset)
2235                 : 0;
2236       }
2237       float hcal_fraction = candFunc::getHCalFraction(chH_cand, disable_hcalFraction_workaround_);
2238       get(dnn::pfCand_chHad_hcalFraction) =
2239           sp.scale(hcal_fraction, dnn::pfCand_chHad_hcalFraction - PFchH_index_offset);
2240       get(dnn::pfCand_chHad_rawCaloFraction) =
2241           sp.scale(candFunc::getRawCaloFraction(chH_cand), dnn::pfCand_chHad_rawCaloFraction - PFchH_index_offset);
2242     }
2243     if (valid_nH) {
2244       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_PFnH, is_inner));
2245       size_t index_nH = cell_map.at(CellObjectType::PfCand_neutralHadron);
2246       const auto& nH_cand = dynamic_cast<const CandidateCastType&>(pfCands.at(index_nH));
2247 
2248       get(dnn::pfCand_nHad_valid) = sp.scale(valid_nH, dnn::pfCand_nHad_valid - PFnH_index_offset);
2249       get(dnn::pfCand_nHad_rel_pt) =
2250           sp.scale(nH_cand.polarP4().pt() / tau.polarP4().pt(), dnn::pfCand_nHad_rel_pt - PFnH_index_offset);
2251       get(dnn::pfCand_nHad_deta) =
2252           sp.scale(nH_cand.polarP4().eta() - tau.polarP4().eta(), dnn::pfCand_nHad_deta - PFnH_index_offset);
2253       get(dnn::pfCand_nHad_dphi) =
2254           sp.scale(dPhi(tau.polarP4(), nH_cand.polarP4()), dnn::pfCand_nHad_dphi - PFnH_index_offset);
2255       get(dnn::pfCand_nHad_puppiWeight) = is_inner ? sp.scale(candFunc::getPuppiWeight(nH_cand, 0.9798355f),
2256                                                               dnn::pfCand_nHad_puppiWeight - PFnH_index_offset)
2257                                                    : sp.scale(candFunc::getPuppiWeight(nH_cand, 0.7813260f),
2258                                                               dnn::pfCand_nHad_puppiWeight - PFnH_index_offset);
2259       get(dnn::pfCand_nHad_puppiWeightNoLep) = is_inner
2260                                                    ? sp.scale(candFunc::getPuppiWeightNoLep(nH_cand, 0.9046796f),
2261                                                               dnn::pfCand_nHad_puppiWeightNoLep - PFnH_index_offset)
2262                                                    : sp.scale(candFunc::getPuppiWeightNoLep(nH_cand, 0.6554860f),
2263                                                               dnn::pfCand_nHad_puppiWeightNoLep - PFnH_index_offset);
2264       float hcal_fraction = candFunc::getHCalFraction(nH_cand, disable_hcalFraction_workaround_);
2265       get(dnn::pfCand_nHad_hcalFraction) = sp.scale(hcal_fraction, dnn::pfCand_nHad_hcalFraction - PFnH_index_offset);
2266     }
2267   }
2268 
2269   static void calculateElectronClusterVars(const pat::Electron* ele, float& elecEe, float& elecEgamma) {
2270     if (ele) {
2271       elecEe = elecEgamma = 0;
2272       auto superCluster = ele->superCluster();
2273       if (superCluster.isNonnull() && superCluster.isAvailable() && superCluster->clusters().isNonnull() &&
2274           superCluster->clusters().isAvailable()) {
2275         for (auto iter = superCluster->clustersBegin(); iter != superCluster->clustersEnd(); ++iter) {
2276           const double energy = (*iter)->energy();
2277           if (iter == superCluster->clustersBegin())
2278             elecEe += energy;
2279           else
2280             elecEgamma += energy;
2281         }
2282       }
2283     } else {
2284       elecEe = elecEgamma = default_value;
2285     }
2286   }
2287 
2288   template <typename CandidateCollection, typename TauCastType>
2289   static void processSignalPFComponents(const TauCastType& tau,
2290                                         const CandidateCollection& candidates,
2291                                         LorentzVectorXYZ& p4_inner,
2292                                         LorentzVectorXYZ& p4_outer,
2293                                         float& pt_inner,
2294                                         float& dEta_inner,
2295                                         float& dPhi_inner,
2296                                         float& m_inner,
2297                                         float& pt_outer,
2298                                         float& dEta_outer,
2299                                         float& dPhi_outer,
2300                                         float& m_outer,
2301                                         float& n_inner,
2302                                         float& n_outer) {
2303     p4_inner = LorentzVectorXYZ(0, 0, 0, 0);
2304     p4_outer = LorentzVectorXYZ(0, 0, 0, 0);
2305     n_inner = 0;
2306     n_outer = 0;
2307 
2308     const double innerSigCone_radius = getInnerSignalConeRadius(tau.pt());
2309     for (const auto& cand : candidates) {
2310       const double dR = reco::deltaR(cand->p4(), tau.leadChargedHadrCand()->p4());
2311       const bool isInside_innerSigCone = dR < innerSigCone_radius;
2312       if (isInside_innerSigCone) {
2313         p4_inner += cand->p4();
2314         ++n_inner;
2315       } else {
2316         p4_outer += cand->p4();
2317         ++n_outer;
2318       }
2319     }
2320 
2321     pt_inner = n_inner != 0 ? p4_inner.Pt() : default_value;
2322     dEta_inner = n_inner != 0 ? dEta(p4_inner, tau.p4()) : default_value;
2323     dPhi_inner = n_inner != 0 ? dPhi(p4_inner, tau.p4()) : default_value;
2324     m_inner = n_inner != 0 ? p4_inner.mass() : default_value;
2325 
2326     pt_outer = n_outer != 0 ? p4_outer.Pt() : default_value;
2327     dEta_outer = n_outer != 0 ? dEta(p4_outer, tau.p4()) : default_value;
2328     dPhi_outer = n_outer != 0 ? dPhi(p4_outer, tau.p4()) : default_value;
2329     m_outer = n_outer != 0 ? p4_outer.mass() : default_value;
2330   }
2331 
2332   template <typename CandidateCollection, typename TauCastType>
2333   static void processIsolationPFComponents(const TauCastType& tau,
2334                                            const CandidateCollection& candidates,
2335                                            LorentzVectorXYZ& p4,
2336                                            float& pt,
2337                                            float& d_eta,
2338                                            float& d_phi,
2339                                            float& m,
2340                                            float& n) {
2341     p4 = LorentzVectorXYZ(0, 0, 0, 0);
2342     n = 0;
2343 
2344     for (const auto& cand : candidates) {
2345       p4 += cand->p4();
2346       ++n;
2347     }
2348 
2349     pt = n != 0 ? p4.Pt() : default_value;
2350     d_eta = n != 0 ? dEta(p4, tau.p4()) : default_value;
2351     d_phi = n != 0 ? dPhi(p4, tau.p4()) : default_value;
2352     m = n != 0 ? p4.mass() : default_value;
2353   }
2354 
2355   static double getInnerSignalConeRadius(double pt) {
2356     static constexpr double min_pt = 30., min_radius = 0.05, cone_opening_coef = 3.;
2357     // This is equivalent of the original formula (std::max(std::min(0.1, 3.0/pt), 0.05)
2358     return std::max(cone_opening_coef / std::max(pt, min_pt), min_radius);
2359   }
2360 
2361   // Copied from https://github.com/cms-sw/cmssw/blob/CMSSW_9_4_X/RecoTauTag/RecoTau/plugins/PATTauDiscriminationByMVAIsolationRun2.cc#L218
2362   template <typename TauCastType>
2363   static bool calculateGottfriedJacksonAngleDifference(const TauCastType& tau,
2364                                                        const size_t tau_index,
2365                                                        double& gj_diff,
2366                                                        TauFunc tau_funcs) {
2367     if (tau_funcs.getHasSecondaryVertex(tau, tau_index)) {
2368       static constexpr double mTau = 1.77682;
2369       const double mAOne = tau.p4().M();
2370       const double pAOneMag = tau.p();
2371       const double argumentThetaGJmax = (std::pow(mTau, 2) - std::pow(mAOne, 2)) / (2 * mTau * pAOneMag);
2372       const double argumentThetaGJmeasured = tau.p4().Vect().Dot(tau_funcs.getFlightLength(tau, tau_index)) /
2373                                              (pAOneMag * tau_funcs.getFlightLength(tau, tau_index).R());
2374       if (std::abs(argumentThetaGJmax) <= 1. && std::abs(argumentThetaGJmeasured) <= 1.) {
2375         double thetaGJmax = std::asin(argumentThetaGJmax);
2376         double thetaGJmeasured = std::acos(argumentThetaGJmeasured);
2377         gj_diff = thetaGJmeasured - thetaGJmax;
2378         return true;
2379       }
2380     }
2381     return false;
2382   }
2383 
2384   template <typename TauCastType>
2385   static float calculateGottfriedJacksonAngleDifference(const TauCastType& tau,
2386                                                         const size_t tau_index,
2387                                                         TauFunc tau_funcs) {
2388     double gj_diff;
2389     if (calculateGottfriedJacksonAngleDifference(tau, tau_index, gj_diff, tau_funcs))
2390       return static_cast<float>(gj_diff);
2391     return default_value;
2392   }
2393 
2394   static bool isInEcalCrack(double eta) {
2395     const double abs_eta = std::abs(eta);
2396     return abs_eta > 1.46 && abs_eta < 1.558;
2397   }
2398 
2399   template <typename TauCastType>
2400   static const pat::Electron* findMatchedElectron(const TauCastType& tau,
2401                                                   const std::vector<pat::Electron>* electrons,
2402                                                   double deltaR) {
2403     const double dR2 = deltaR * deltaR;
2404     const pat::Electron* matched_ele = nullptr;
2405     for (const auto& ele : *electrons) {
2406       if (reco::deltaR2(tau.p4(), ele.p4()) < dR2 && (!matched_ele || matched_ele->pt() < ele.pt())) {
2407         matched_ele = &ele;
2408       }
2409     }
2410     return matched_ele;
2411   }
2412 
2413 private:
2414   edm::EDGetTokenT<std::vector<pat::Electron>> electrons_token_;
2415   edm::EDGetTokenT<std::vector<pat::Muon>> muons_token_;
2416   edm::EDGetTokenT<double> rho_token_;
2417   edm::EDGetTokenT<reco::TauDiscriminatorContainer> basicTauDiscriminators_inputToken_;
2418   edm::EDGetTokenT<reco::TauDiscriminatorContainer> basicTauDiscriminatorsdR03_inputToken_;
2419   edm::EDGetTokenT<edm::AssociationVector<reco::PFTauRefProd, std::vector<reco::PFTauTransverseImpactParameterRef>>>
2420       pfTauTransverseImpactParameters_token_;
2421   std::string input_layer_, output_layer_;
2422   const unsigned version_;
2423   const unsigned sub_version_;
2424   const int debug_level;
2425   const bool disable_dxy_pca_;
2426   const bool disable_hcalFraction_workaround_;
2427   const bool disable_CellIndex_workaround_;
2428   std::unique_ptr<tensorflow::Tensor> tauBlockTensor_;
2429   std::array<std::unique_ptr<tensorflow::Tensor>, 2> eGammaTensor_, muonTensor_, hadronsTensor_, convTensor_,
2430       zeroOutputTensor_;
2431   const std::map<std::pair<deep_tau::Scaling::FeatureT, bool>, deep_tau::Scaling::ScalingParams>* scalingParamsMap_;
2432   const bool save_inputs_;
2433   std::ofstream* json_file_;
2434   bool is_first_block_;
2435   int file_counter_;
2436   std::vector<int> tauInputs_indices_;
2437 
2438   //boolean to check if discriminator indices are already mapped
2439   bool discrIndicesMapped_ = false;
2440   std::map<BasicDiscriminator, size_t> basicDiscrIndexMap_;
2441   std::map<BasicDiscriminator, size_t> basicDiscrdR03IndexMap_;
2442 };
2443 
// Register DeepTauId with the framework's module plugin factory (DEFINE_FWK_MODULE from
// MakerMacros.h) so that configurations can instantiate it by name.
#include "FWCore/Framework/interface/MakerMacros.h"
DEFINE_FWK_MODULE(DeepTauId);