Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-02-08 03:00:25

0001 #ifndef DeepTauIdBase_H
0002 #define DeepTauIdBase_H
0003 
0004 /*
0005  * \class DeepTauIdBase
0006  *
0007  * Base class for DeepTauId producers: DeepTauId and DeepTauIdSONIC
0008  *
0009  */
0010 
0011 #include "DataFormats/PatCandidates/interface/Electron.h"
0012 #include "DataFormats/PatCandidates/interface/Muon.h"
0013 #include "DataFormats/PatCandidates/interface/Tau.h"
0014 #include "DataFormats/TauReco/interface/TauDiscriminatorContainer.h"
0015 #include "DataFormats/TauReco/interface/PFTauDiscriminator.h"
0016 #include "DataFormats/PatCandidates/interface/PATTauDiscriminator.h"
0017 #include "DataFormats/Common/interface/View.h"
0018 #include "DataFormats/Common/interface/RefToBase.h"
0019 #include "DataFormats/Provenance/interface/ProductProvenance.h"
0020 #include "DataFormats/Provenance/interface/ProcessHistoryID.h"
0021 #include "FWCore/Common/interface/Provenance.h"
0022 #include "FWCore/Framework/interface/Event.h"
0023 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0024 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0025 #include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
0026 #include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
0027 #include "RecoTauTag/RecoTau/interface/PFRecoTauClusterVariables.h"
0028 #include "RecoTauTag/RecoTau/interface/DeepTauScaling.h"
0029 #include "RecoTauTag/RecoTau/interface/TauWPThreshold.h"
0030 #include "DataFormats/TauReco/interface/PFTauTransverseImpactParameterAssociation.h"
0031 #include "CommonTools/Utils/interface/StringObjectFunction.h"
0032 #include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
0033 #include "tensorflow/core/util/memmapped_file_system.h"
0034 
0035 #include <Math/VectorUtil.h>
0036 #include <map>
0037 #include <fstream>
0038 #include "oneapi/tbb/concurrent_unordered_set.h"
0039 
// Names shared by the DeepTau producers.
0040 namespace deep_tau {
0041   enum BasicDiscriminator {  // keys of TauFunc::indexMap / indexMapdR03 further down in this file
0042     ChargedIsoPtSum,
0043     NeutralIsoPtSum,
0044     NeutralIsoPtSumWeight,
0045     FootprintCorrection,
0046     PhotonPtSumOutsideSignalCone,
0047     PUcorrPtSum
0048   };
0049 
0050   constexpr int NumberOfOutputs = 4;  // NOTE(review): presumably the number of network output nodes - confirm
0051 }  // namespace deep_tau
0052 
0053 namespace {
0054 
0055   namespace dnn_inputs_v2 {
// Size constants of the v2 network inputs.
// NOTE(review): presumably the inner/outer grids are number_of_*_cell cells per side - confirm against the code that fills them.
0056     constexpr int number_of_inner_cell = 11;
0057     constexpr int number_of_outer_cell = 21;
0058     constexpr int number_of_conv_features = 64;
// Index layout of the tau-level input block: each enumerator is the integer
// index of one input variable, numbered consecutively from rho = 0.
// Do not reorder: the numeric values are the contract with whatever consumes them.
0059     namespace TauBlockInputs {
0060       enum vars {
0061         rho = 0,
0062         tau_pt,
0063         tau_eta,
0064         tau_phi,
0065         tau_mass,
0066         tau_E_over_pt,
0067         tau_charge,
0068         tau_n_charged_prongs,
0069         tau_n_neutral_prongs,
0070         chargedIsoPtSum,
0071         chargedIsoPtSumdR03_over_dR05,
0072         footprintCorrection,
0073         neutralIsoPtSum,
0074         neutralIsoPtSumWeight_over_neutralIsoPtSum,
0075         neutralIsoPtSumWeightdR03_over_neutralIsoPtSum,
0076         neutralIsoPtSumdR03_over_dR05,
0077         photonPtSumOutsideSignalCone,
0078         puCorrPtSum,
0079         tau_dxy_pca_x,
0080         tau_dxy_pca_y,
0081         tau_dxy_pca_z,
0082         tau_dxy_valid,
0083         tau_dxy,
0084         tau_dxy_sig,
0085         tau_ip3d_valid,
0086         tau_ip3d,
0087         tau_ip3d_sig,
0088         tau_dz,
0089         tau_dz_sig_valid,
0090         tau_dz_sig,
0091         tau_flightLength_x,
0092         tau_flightLength_y,
0093         tau_flightLength_z,
0094         tau_flightLength_sig,
0095         tau_pt_weighted_deta_strip,
0096         tau_pt_weighted_dphi_strip,
0097         tau_pt_weighted_dr_signal,
0098         tau_pt_weighted_dr_iso,
0099         tau_leadingTrackNormChi2,
0100         tau_e_ratio_valid,
0101         tau_e_ratio,
0102         tau_gj_angle_diff_valid,
0103         tau_gj_angle_diff,
0104         tau_n_photons,
0105         tau_emFraction,
0106         tau_inside_ecal_crack,
0107         leadChargedCand_etaAtEcalEntrance_minus_tau_eta,
0108         NumberOfInputs  // sentinel: equals the number of variables listed above
0109       };
0110       inline std::vector<int> varsToDrop = {
0111           tau_phi, tau_dxy_pca_x, tau_dxy_pca_y, tau_dxy_pca_z};  // indices of vars to be dropped in the full var enum
0112     }                                                             // namespace TauBlockInputs
0113 
// Index layout of the e/gamma input block: each enumerator is the integer
// index of one input variable, numbered consecutively from rho = 0.
// The first four entries (rho, tau_pt, tau_eta, tau_inside_ecal_crack) also open
// the muon and hadron blocks below; the rest is grouped by name prefix
// (pfCand_ele_*, ele_*, pfCand_gamma_*). Do not reorder.
0114     namespace EgammaBlockInputs {
0115       enum vars {
0116         rho = 0,
0117         tau_pt,
0118         tau_eta,
0119         tau_inside_ecal_crack,
0120         pfCand_ele_valid,
0121         pfCand_ele_rel_pt,
0122         pfCand_ele_deta,
0123         pfCand_ele_dphi,
0124         pfCand_ele_pvAssociationQuality,
0125         pfCand_ele_puppiWeight,
0126         pfCand_ele_charge,
0127         pfCand_ele_lostInnerHits,
0128         pfCand_ele_numberOfPixelHits,
0129         pfCand_ele_vertex_dx,
0130         pfCand_ele_vertex_dy,
0131         pfCand_ele_vertex_dz,
0132         pfCand_ele_vertex_dx_tauFL,
0133         pfCand_ele_vertex_dy_tauFL,
0134         pfCand_ele_vertex_dz_tauFL,
0135         pfCand_ele_hasTrackDetails,
0136         pfCand_ele_dxy,
0137         pfCand_ele_dxy_sig,
0138         pfCand_ele_dz,
0139         pfCand_ele_dz_sig,
0140         pfCand_ele_track_chi2_ndof,
0141         pfCand_ele_track_ndof,
0142         ele_valid,
0143         ele_rel_pt,
0144         ele_deta,
0145         ele_dphi,
0146         ele_cc_valid,
0147         ele_cc_ele_rel_energy,
0148         ele_cc_gamma_rel_energy,
0149         ele_cc_n_gamma,
0150         ele_rel_trackMomentumAtVtx,
0151         ele_rel_trackMomentumAtCalo,
0152         ele_rel_trackMomentumOut,
0153         ele_rel_trackMomentumAtEleClus,
0154         ele_rel_trackMomentumAtVtxWithConstraint,
0155         ele_rel_ecalEnergy,
0156         ele_ecalEnergy_sig,
0157         ele_eSuperClusterOverP,
0158         ele_eSeedClusterOverP,
0159         ele_eSeedClusterOverPout,
0160         ele_eEleClusterOverPout,
0161         ele_deltaEtaSuperClusterTrackAtVtx,
0162         ele_deltaEtaSeedClusterTrackAtCalo,
0163         ele_deltaEtaEleClusterTrackAtCalo,
0164         ele_deltaPhiEleClusterTrackAtCalo,
0165         ele_deltaPhiSuperClusterTrackAtVtx,
0166         ele_deltaPhiSeedClusterTrackAtCalo,
0167         ele_mvaInput_earlyBrem,
0168         ele_mvaInput_lateBrem,
0169         ele_mvaInput_sigmaEtaEta,
0170         ele_mvaInput_hadEnergy,
0171         ele_mvaInput_deltaEta,
0172         ele_gsfTrack_normalizedChi2,
0173         ele_gsfTrack_numberOfValidHits,
0174         ele_rel_gsfTrack_pt,
0175         ele_gsfTrack_pt_sig,
0176         ele_has_closestCtfTrack,
0177         ele_closestCtfTrack_normalizedChi2,
0178         ele_closestCtfTrack_numberOfValidHits,
0179         pfCand_gamma_valid,
0180         pfCand_gamma_rel_pt,
0181         pfCand_gamma_deta,
0182         pfCand_gamma_dphi,
0183         pfCand_gamma_pvAssociationQuality,
0184         pfCand_gamma_fromPV,
0185         pfCand_gamma_puppiWeight,
0186         pfCand_gamma_puppiWeightNoLep,
0187         pfCand_gamma_lostInnerHits,
0188         pfCand_gamma_numberOfPixelHits,
0189         pfCand_gamma_vertex_dx,
0190         pfCand_gamma_vertex_dy,
0191         pfCand_gamma_vertex_dz,
0192         pfCand_gamma_vertex_dx_tauFL,
0193         pfCand_gamma_vertex_dy_tauFL,
0194         pfCand_gamma_vertex_dz_tauFL,
0195         pfCand_gamma_hasTrackDetails,
0196         pfCand_gamma_dxy,
0197         pfCand_gamma_dxy_sig,
0198         pfCand_gamma_dz,
0199         pfCand_gamma_dz_sig,
0200         pfCand_gamma_track_chi2_ndof,
0201         pfCand_gamma_track_ndof,
0202         NumberOfInputs  // sentinel: equals the number of variables listed above
0203       };
0204     }
0205 
// Index layout of the muon input block: each enumerator is the integer index
// of one input variable, numbered consecutively from rho = 0.
// The trailing muon_n_matches_* / muon_n_hits_* entries are laid out as
// DT, CSC, RPC with four numbered entries (_1.._4) each. Do not reorder.
0206     namespace MuonBlockInputs {
0207       enum vars {
0208         rho = 0,
0209         tau_pt,
0210         tau_eta,
0211         tau_inside_ecal_crack,
0212         pfCand_muon_valid,
0213         pfCand_muon_rel_pt,
0214         pfCand_muon_deta,
0215         pfCand_muon_dphi,
0216         pfCand_muon_pvAssociationQuality,
0217         pfCand_muon_fromPV,
0218         pfCand_muon_puppiWeight,
0219         pfCand_muon_charge,
0220         pfCand_muon_lostInnerHits,
0221         pfCand_muon_numberOfPixelHits,
0222         pfCand_muon_vertex_dx,
0223         pfCand_muon_vertex_dy,
0224         pfCand_muon_vertex_dz,
0225         pfCand_muon_vertex_dx_tauFL,
0226         pfCand_muon_vertex_dy_tauFL,
0227         pfCand_muon_vertex_dz_tauFL,
0228         pfCand_muon_hasTrackDetails,
0229         pfCand_muon_dxy,
0230         pfCand_muon_dxy_sig,
0231         pfCand_muon_dz,
0232         pfCand_muon_dz_sig,
0233         pfCand_muon_track_chi2_ndof,
0234         pfCand_muon_track_ndof,
0235         muon_valid,
0236         muon_rel_pt,
0237         muon_deta,
0238         muon_dphi,
0239         muon_dxy,
0240         muon_dxy_sig,
0241         muon_normalizedChi2_valid,
0242         muon_normalizedChi2,
0243         muon_numberOfValidHits,
0244         muon_segmentCompatibility,
0245         muon_caloCompatibility,
0246         muon_pfEcalEnergy_valid,
0247         muon_rel_pfEcalEnergy,
0248         muon_n_matches_DT_1,
0249         muon_n_matches_DT_2,
0250         muon_n_matches_DT_3,
0251         muon_n_matches_DT_4,
0252         muon_n_matches_CSC_1,
0253         muon_n_matches_CSC_2,
0254         muon_n_matches_CSC_3,
0255         muon_n_matches_CSC_4,
0256         muon_n_matches_RPC_1,
0257         muon_n_matches_RPC_2,
0258         muon_n_matches_RPC_3,
0259         muon_n_matches_RPC_4,
0260         muon_n_hits_DT_1,
0261         muon_n_hits_DT_2,
0262         muon_n_hits_DT_3,
0263         muon_n_hits_DT_4,
0264         muon_n_hits_CSC_1,
0265         muon_n_hits_CSC_2,
0266         muon_n_hits_CSC_3,
0267         muon_n_hits_CSC_4,
0268         muon_n_hits_RPC_1,
0269         muon_n_hits_RPC_2,
0270         muon_n_hits_RPC_3,
0271         muon_n_hits_RPC_4,
0272         NumberOfInputs  // sentinel: equals the number of variables listed above
0273       };
0274     }
0275 
// Index layout of the hadron input block: each enumerator is the integer index
// of one input variable, numbered consecutively from rho = 0.
// Entries are grouped by name prefix (pfCand_chHad_*, then pfCand_nHad_*). Do not reorder.
0276     namespace HadronBlockInputs {
0277       enum vars {
0278         rho = 0,
0279         tau_pt,
0280         tau_eta,
0281         tau_inside_ecal_crack,
0282         pfCand_chHad_valid,
0283         pfCand_chHad_rel_pt,
0284         pfCand_chHad_deta,
0285         pfCand_chHad_dphi,
0286         pfCand_chHad_leadChargedHadrCand,
0287         pfCand_chHad_pvAssociationQuality,
0288         pfCand_chHad_fromPV,
0289         pfCand_chHad_puppiWeight,
0290         pfCand_chHad_puppiWeightNoLep,
0291         pfCand_chHad_charge,
0292         pfCand_chHad_lostInnerHits,
0293         pfCand_chHad_numberOfPixelHits,
0294         pfCand_chHad_vertex_dx,
0295         pfCand_chHad_vertex_dy,
0296         pfCand_chHad_vertex_dz,
0297         pfCand_chHad_vertex_dx_tauFL,
0298         pfCand_chHad_vertex_dy_tauFL,
0299         pfCand_chHad_vertex_dz_tauFL,
0300         pfCand_chHad_hasTrackDetails,
0301         pfCand_chHad_dxy,
0302         pfCand_chHad_dxy_sig,
0303         pfCand_chHad_dz,
0304         pfCand_chHad_dz_sig,
0305         pfCand_chHad_track_chi2_ndof,
0306         pfCand_chHad_track_ndof,
0307         pfCand_chHad_hcalFraction,
0308         pfCand_chHad_rawCaloFraction,
0309         pfCand_nHad_valid,
0310         pfCand_nHad_rel_pt,
0311         pfCand_nHad_deta,
0312         pfCand_nHad_dphi,
0313         pfCand_nHad_puppiWeight,
0314         pfCand_nHad_puppiWeightNoLep,
0315         pfCand_nHad_hcalFraction,
0316         NumberOfInputs  // sentinel: equals the number of variables listed above
0317       };
0318     }
0319   }  // namespace dnn_inputs_v2
0320 
0321   inline float getTauID(const pat::Tau& tau,
0322                         const std::string& tauID,
0323                         float default_value = -999.,
0324                         bool assert_input = true) {
0325     static tbb::concurrent_unordered_set<std::string> isFirstWarning;
0326     if (tau.isTauIDAvailable(tauID)) {
0327       return tau.tauID(tauID);
0328     } else {
0329       if (assert_input) {
0330         throw cms::Exception("DeepTauId")
0331             << "Exception in <getTauID>: No tauID '" << tauID << "' available in pat::Tau given as function argument.";
0332       }
0333       if (isFirstWarning.insert(tauID).second) {
0334         edm::LogWarning("DeepTauID") << "Warning in <getTauID>: No tauID '" << tauID
0335                                      << "' available in pat::Tau given as function argument."
0336                                      << " Using default_value = " << default_value << " instead." << std::endl;
0337       }
0338       return default_value;
0339     }
0340   }
0341 
0342   struct TauFunc {
0343     const reco::TauDiscriminatorContainer* basicTauDiscriminatorCollection;
0344     const reco::TauDiscriminatorContainer* basicTauDiscriminatordR03Collection;
0345     const edm::AssociationVector<reco::PFTauRefProd, std::vector<reco::PFTauTransverseImpactParameterRef>>*
0346         pfTauTransverseImpactParameters;
0347 
0348     using BasicDiscr = deep_tau::BasicDiscriminator;
0349     std::map<BasicDiscr, size_t> indexMap;
0350     std::map<BasicDiscr, size_t> indexMapdR03;
0351 
0352     const float getChargedIsoPtSum(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0353       return (*basicTauDiscriminatorCollection)[tau_ref].rawValues.at(indexMap.at(BasicDiscr::ChargedIsoPtSum));
0354     }
0355     const float getChargedIsoPtSum(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0356       return getTauID(tau, "chargedIsoPtSum");
0357     }
0358     const float getChargedIsoPtSumdR03(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0359       return (*basicTauDiscriminatordR03Collection)[tau_ref].rawValues.at(indexMapdR03.at(BasicDiscr::ChargedIsoPtSum));
0360     }
0361     const float getChargedIsoPtSumdR03(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0362       return getTauID(tau, "chargedIsoPtSumdR03");
0363     }
0364     const float getFootprintCorrectiondR03(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0365       return (*basicTauDiscriminatordR03Collection)[tau_ref].rawValues.at(
0366           indexMapdR03.at(BasicDiscr::FootprintCorrection));
0367     }
0368     const float getFootprintCorrectiondR03(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0369       return getTauID(tau, "footprintCorrectiondR03");
0370     }
0371     const float getFootprintCorrection(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0372       return getTauID(tau, "footprintCorrection");
0373     }
0374     const float getNeutralIsoPtSum(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0375       return (*basicTauDiscriminatorCollection)[tau_ref].rawValues.at(indexMap.at(BasicDiscr::NeutralIsoPtSum));
0376     }
0377     const float getNeutralIsoPtSum(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0378       return getTauID(tau, "neutralIsoPtSum");
0379     }
0380     const float getNeutralIsoPtSumdR03(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0381       return (*basicTauDiscriminatordR03Collection)[tau_ref].rawValues.at(indexMapdR03.at(BasicDiscr::NeutralIsoPtSum));
0382     }
0383     const float getNeutralIsoPtSumdR03(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0384       return getTauID(tau, "neutralIsoPtSumdR03");
0385     }
0386     const float getNeutralIsoPtSumWeight(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0387       return (*basicTauDiscriminatorCollection)[tau_ref].rawValues.at(indexMap.at(BasicDiscr::NeutralIsoPtSumWeight));
0388     }
0389     const float getNeutralIsoPtSumWeight(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0390       return getTauID(tau, "neutralIsoPtSumWeight");
0391     }
0392     const float getNeutralIsoPtSumdR03Weight(const reco::PFTau& tau,
0393                                              const edm::RefToBase<reco::BaseTau> tau_ref) const {
0394       return (*basicTauDiscriminatordR03Collection)[tau_ref].rawValues.at(
0395           indexMapdR03.at(BasicDiscr::NeutralIsoPtSumWeight));
0396     }
0397     const float getNeutralIsoPtSumdR03Weight(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0398       return getTauID(tau, "neutralIsoPtSumWeightdR03");
0399     }
0400     const float getPhotonPtSumOutsideSignalCone(const reco::PFTau& tau,
0401                                                 const edm::RefToBase<reco::BaseTau> tau_ref) const {
0402       return (*basicTauDiscriminatorCollection)[tau_ref].rawValues.at(
0403           indexMap.at(BasicDiscr::PhotonPtSumOutsideSignalCone));
0404     }
0405     const float getPhotonPtSumOutsideSignalCone(const pat::Tau& tau,
0406                                                 const edm::RefToBase<reco::BaseTau> tau_ref) const {
0407       return getTauID(tau, "photonPtSumOutsideSignalCone");
0408     }
0409     const float getPhotonPtSumOutsideSignalConedR03(const reco::PFTau& tau,
0410                                                     const edm::RefToBase<reco::BaseTau> tau_ref) const {
0411       return (*basicTauDiscriminatordR03Collection)[tau_ref].rawValues.at(
0412           indexMapdR03.at(BasicDiscr::PhotonPtSumOutsideSignalCone));
0413     }
0414     const float getPhotonPtSumOutsideSignalConedR03(const pat::Tau& tau,
0415                                                     const edm::RefToBase<reco::BaseTau> tau_ref) const {
0416       return getTauID(tau, "photonPtSumOutsideSignalConedR03");
0417     }
0418     const float getPuCorrPtSum(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0419       return (*basicTauDiscriminatorCollection)[tau_ref].rawValues.at(indexMap.at(BasicDiscr::PUcorrPtSum));
0420     }
0421     const float getPuCorrPtSum(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0422       return getTauID(tau, "puCorrPtSum");
0423     }
0424 
0425     auto getdxyPCA(const reco::PFTau& tau, const size_t tau_index) const {
0426       return pfTauTransverseImpactParameters->value(tau_index)->dxy_PCA();
0427     }
0428     auto getdxyPCA(const pat::Tau& tau, const size_t tau_index) const { return tau.dxy_PCA(); }
0429     auto getdxy(const reco::PFTau& tau, const size_t tau_index) const {
0430       return pfTauTransverseImpactParameters->value(tau_index)->dxy();
0431     }
0432     auto getdxy(const pat::Tau& tau, const size_t tau_index) const { return tau.dxy(); }
0433     auto getdxyError(const reco::PFTau& tau, const size_t tau_index) const {
0434       return pfTauTransverseImpactParameters->value(tau_index)->dxy_error();
0435     }
0436     auto getdxyError(const pat::Tau& tau, const size_t tau_index) const { return tau.dxy_error(); }
0437     auto getdxySig(const reco::PFTau& tau, const size_t tau_index) const {
0438       return pfTauTransverseImpactParameters->value(tau_index)->dxy_Sig();
0439     }
0440     auto getdxySig(const pat::Tau& tau, const size_t tau_index) const { return tau.dxy_Sig(); }
0441     auto getip3d(const reco::PFTau& tau, const size_t tau_index) const {
0442       return pfTauTransverseImpactParameters->value(tau_index)->ip3d();
0443     }
0444     auto getip3d(const pat::Tau& tau, const size_t tau_index) const { return tau.ip3d(); }
0445     auto getip3dError(const reco::PFTau& tau, const size_t tau_index) const {
0446       return pfTauTransverseImpactParameters->value(tau_index)->ip3d_error();
0447     }
0448     auto getip3dError(const pat::Tau& tau, const size_t tau_index) const { return tau.ip3d_error(); }
0449     auto getip3dSig(const reco::PFTau& tau, const size_t tau_index) const {
0450       return pfTauTransverseImpactParameters->value(tau_index)->ip3d_Sig();
0451     }
0452     auto getip3dSig(const pat::Tau& tau, const size_t tau_index) const { return tau.ip3d_Sig(); }
0453     auto getHasSecondaryVertex(const reco::PFTau& tau, const size_t tau_index) const {
0454       return pfTauTransverseImpactParameters->value(tau_index)->hasSecondaryVertex();
0455     }
0456     auto getHasSecondaryVertex(const pat::Tau& tau, const size_t tau_index) const { return tau.hasSecondaryVertex(); }
0457     auto getFlightLength(const reco::PFTau& tau, const size_t tau_index) const {
0458       return pfTauTransverseImpactParameters->value(tau_index)->flightLength();
0459     }
0460     auto getFlightLength(const pat::Tau& tau, const size_t tau_index) const { return tau.flightLength(); }
0461     auto getFlightLengthSig(const reco::PFTau& tau, const size_t tau_index) const {
0462       return pfTauTransverseImpactParameters->value(tau_index)->flightLengthSig();
0463     }
0464     auto getFlightLengthSig(const pat::Tau& tau, const size_t tau_index) const { return tau.flightLengthSig(); }
0465 
0466     auto getLeadingTrackNormChi2(const reco::PFTau& tau) { return reco::tau::lead_track_chi2(tau); }
0467     auto getLeadingTrackNormChi2(const pat::Tau& tau) { return tau.leadingTrackNormChi2(); }
0468     auto getEmFraction(const pat::Tau& tau) { return tau.emFraction_MVA(); }
0469     auto getEmFraction(const reco::PFTau& tau) { return tau.emFraction(); }
0470     auto getEtaAtEcalEntrance(const pat::Tau& tau) { return tau.etaAtEcalEntranceLeadChargedCand(); }
0471     auto getEtaAtEcalEntrance(const reco::PFTau& tau) {
0472       return tau.leadPFChargedHadrCand()->positionAtECALEntrance().eta();
0473     }
0474     auto getEcalEnergyLeadingChargedHadr(const reco::PFTau& tau) { return tau.leadPFChargedHadrCand()->ecalEnergy(); }
0475     auto getEcalEnergyLeadingChargedHadr(const pat::Tau& tau) { return tau.ecalEnergyLeadChargedHadrCand(); }
0476     auto getHcalEnergyLeadingChargedHadr(const reco::PFTau& tau) { return tau.leadPFChargedHadrCand()->hcalEnergy(); }
0477     auto getHcalEnergyLeadingChargedHadr(const pat::Tau& tau) { return tau.hcalEnergyLeadChargedHadrCand(); }
0478 
0479     template <typename PreDiscrType>
0480     bool passPrediscriminants(const PreDiscrType prediscriminants,
0481                               const size_t andPrediscriminants,
0482                               const edm::RefToBase<reco::BaseTau> tau_ref) {
0483       bool passesPrediscriminants = (andPrediscriminants ? 1 : 0);
0484       // check tau passes prediscriminants
0485       size_t nPrediscriminants = prediscriminants.size();
0486       for (size_t iDisc = 0; iDisc < nPrediscriminants; ++iDisc) {
0487         // current discriminant result for this tau
0488         double discResult = (*prediscriminants[iDisc].handle)[tau_ref];
0489         uint8_t thisPasses = (discResult > prediscriminants[iDisc].cut) ? 1 : 0;
0490 
0491         // if we are using the AND option, as soon as one fails,
0492         // the result is FAIL and we can quit looping.
0493         // if we are using the OR option as soon as one passes,
0494         // the result is pass and we can quit looping
0495 
0496         // truth table
0497         //        |   result (thisPasses)
0498         //        |     F     |     T
0499         //-----------------------------------
0500         // AND(T) | res=fails |  continue
0501         //        |  break    |
0502         //-----------------------------------
0503         // OR (F) |  continue | res=passes
0504         //        |           |  break
0505 
0506         if (thisPasses ^ andPrediscriminants)  //XOR
0507         {
0508           passesPrediscriminants = (andPrediscriminants ? 0 : 1);  //NOR
0509           break;
0510         }
0511       }
0512       return passesPrediscriminants;
0513     }
0514   };
0515 
0516   namespace candFunc {
0517     inline auto getTauDz(const reco::PFCandidate& cand) { return cand.bestTrack()->dz(); }
0518     inline auto getTauDz(const pat::PackedCandidate& cand) { return cand.dz(); }
0519     inline auto getTauDZSigValid(const reco::PFCandidate& cand) {
0520       return cand.bestTrack() != nullptr && std::isnormal(cand.bestTrack()->dz()) && std::isnormal(cand.dzError()) &&
0521              cand.dzError() > 0;
0522     }
0523     inline auto getTauDZSigValid(const pat::PackedCandidate& cand) {
0524       return cand.hasTrackDetails() && std::isnormal(cand.dz()) && std::isnormal(cand.dzError()) && cand.dzError() > 0;
0525     }
0526     inline auto getTauDxy(const reco::PFCandidate& cand) { return cand.bestTrack()->dxy(); }
0527     inline auto getTauDxy(const pat::PackedCandidate& cand) { return cand.dxy(); }
0528     inline auto getPvAssocationQuality(const reco::PFCandidate& cand) { return 0.7013f; }
0529     inline auto getPvAssocationQuality(const pat::PackedCandidate& cand) { return cand.pvAssociationQuality(); }
0530     inline auto getPuppiWeight(const reco::PFCandidate& cand, const float aod_value) { return aod_value; }
0531     inline auto getPuppiWeight(const pat::PackedCandidate& cand, const float aod_value) { return cand.puppiWeight(); }
0532     inline auto getPuppiWeightNoLep(const reco::PFCandidate& cand, const float aod_value) { return aod_value; }
0533     inline auto getPuppiWeightNoLep(const pat::PackedCandidate& cand, const float aod_value) {
0534       return cand.puppiWeightNoLep();
0535     }
0536     inline auto getLostInnerHits(const reco::PFCandidate& cand, float default_value) {
0537       return cand.bestTrack() != nullptr
0538                  ? cand.bestTrack()->hitPattern().numberOfLostHits(reco::HitPattern::MISSING_INNER_HITS)
0539                  : default_value;
0540     }
0541     inline auto getLostInnerHits(const pat::PackedCandidate& cand, float default_value) { return cand.lostInnerHits(); }
0542     inline auto getNumberOfPixelHits(const reco::PFCandidate& cand, float default_value) {
0543       return cand.bestTrack() != nullptr
0544                  ? cand.bestTrack()->hitPattern().numberOfLostHits(reco::HitPattern::MISSING_INNER_HITS)
0545                  : default_value;
0546     }
0547     inline auto getNumberOfPixelHits(const pat::PackedCandidate& cand, float default_value) {
0548       return cand.numberOfPixelHits();
0549     }
0550     inline auto getHasTrackDetails(const reco::PFCandidate& cand) { return cand.bestTrack() != nullptr; }
0551     inline auto getHasTrackDetails(const pat::PackedCandidate& cand) { return cand.hasTrackDetails(); }
0552     inline auto getPseudoTrack(const reco::PFCandidate& cand) { return *cand.bestTrack(); }
0553     inline auto getPseudoTrack(const pat::PackedCandidate& cand) { return cand.pseudoTrack(); }
0554     inline auto getFromPV(const reco::PFCandidate& cand) { return 0.9994f; }
0555     inline auto getFromPV(const pat::PackedCandidate& cand) { return cand.fromPV(); }
0556     inline auto getHCalFraction(const reco::PFCandidate& cand, bool disable_hcalFraction_workaround) {
0557       return cand.rawHcalEnergy() / (cand.rawHcalEnergy() + cand.rawEcalEnergy());
0558     }
0559     inline auto getHCalFraction(const pat::PackedCandidate& cand, bool disable_hcalFraction_workaround) {
0560       float hcal_fraction = 0.;
0561       if (disable_hcalFraction_workaround) {
0562         // CV: use consistent definition for pfCand_chHad_hcalFraction
0563         //     in DeepTauId.cc code and in TauMLTools/Production/plugins/TauTupleProducer.cc
0564         hcal_fraction = cand.hcalFraction();
0565       } else {
0566         // CV: backwards compatibility with DeepTau training v2p1 used during Run 2
0567         if (cand.pdgId() == 1 || cand.pdgId() == 130) {
0568           hcal_fraction = cand.hcalFraction();
0569         } else if (cand.isIsolatedChargedHadron()) {
0570           hcal_fraction = cand.rawHcalFraction();
0571         }
0572       }
0573       return hcal_fraction;
0574     }
0575     inline auto getRawCaloFraction(const reco::PFCandidate& cand) {
0576       return (cand.rawEcalEnergy() + cand.rawHcalEnergy()) / cand.energy();
0577     }
0578     inline auto getRawCaloFraction(const pat::PackedCandidate& cand) { return cand.rawCaloFraction(); }
0579   };  // namespace candFunc
0580 
/// Pseudorapidity difference `p4.eta() - tau_p4.eta()`, narrowed to float.
/// Both arguments only need an `eta()` member.
template <typename LVector1, typename LVector2>
float dEta(const LVector1& p4, const LVector2& tau_p4) {
  const auto eta_difference = p4.eta() - tau_p4.eta();
  return static_cast<float>(eta_difference);
}
0585 
0586   template <typename LVector1, typename LVector2>
0587   float dPhi(const LVector1& p4_1, const LVector2& p4_2) {
0588     return static_cast<float>(reco::deltaPhi(p4_2.phi(), p4_1.phi()));
0589   }
0590 
0591   struct MuonHitMatchV2 {
// Station ids run from first_station_id to last_station_id (inclusive);
// a CountArray holds one counter per station and a CountMap one CountArray per subdetector id.
0592     static constexpr size_t n_muon_stations = 4;
0593     static constexpr int first_station_id = 1;
0594     static constexpr int last_station_id = first_station_id + n_muon_stations - 1;
0595     using CountArray = std::array<unsigned, n_muon_stations>;
0596     using CountMap = std::map<int, CountArray>;
0597 
0598     const std::vector<int>& consideredSubdets() {
0599       static const std::vector<int> subdets = {MuonSubdetId::DT, MuonSubdetId::CSC, MuonSubdetId::RPC};
0600       return subdets;
0601     }
0602 
0603     const std::string& subdetName(int subdet) {
0604       static const std::map<int, std::string> subdet_names = {
0605           {MuonSubdetId::DT, "DT"}, {MuonSubdetId::CSC, "CSC"}, {MuonSubdetId::RPC, "RPC"}};
0606       if (!subdet_names.count(subdet))
0607         throw cms::Exception("MuonHitMatch") << "Subdet name for subdet id " << subdet << " not found.";
0608       return subdet_names.at(subdet);
0609     }
0610 
0611     size_t getStationIndex(int station, bool throw_exception) const {
0612       if (station < first_station_id || station > last_station_id) {
0613         if (throw_exception)
0614           throw cms::Exception("MuonHitMatch") << "Station id is out of range";
0615         return std::numeric_limits<size_t>::max();
0616       }
0617       return static_cast<size_t>(station - 1);
0618     }
0619 
0620     MuonHitMatchV2(const pat::Muon& muon) {
0621       for (int subdet : consideredSubdets()) {
0622         n_matches[subdet].fill(0);
0623         n_hits[subdet].fill(0);
0624       }
0625 
0626       countMatches(muon, n_matches);
0627       countHits(muon, n_hits);
0628     }
0629 
0630     void countMatches(const pat::Muon& muon, CountMap& n_matches) {
0631       for (const auto& segment : muon.matches()) {
0632         if (segment.segmentMatches.empty() && segment.rpcMatches.empty())
0633           continue;
0634         if (n_matches.count(segment.detector())) {
0635           const size_t station_index = getStationIndex(segment.station(), true);
0636           ++n_matches.at(segment.detector()).at(station_index);
0637         }
0638       }
0639     }
0640 
0641     void countHits(const pat::Muon& muon, CountMap& n_hits) {
0642       if (muon.outerTrack().isNonnull()) {
0643         const auto& hit_pattern = muon.outerTrack()->hitPattern();
0644         for (int hit_index = 0; hit_index < hit_pattern.numberOfAllHits(reco::HitPattern::TRACK_HITS); ++hit_index) {
0645           auto hit_id = hit_pattern.getHitPattern(reco::HitPattern::TRACK_HITS, hit_index);
0646           if (hit_id == 0)
0647             break;
0648           if (hit_pattern.muonHitFilter(hit_id) && (hit_pattern.getHitType(hit_id) == TrackingRecHit::valid ||
0649                                                     hit_pattern.getHitType(hit_id) == TrackingRecHit::bad)) {
0650             const size_t station_index = getStationIndex(hit_pattern.getMuonStation(hit_id), false);
0651             if (station_index < n_muon_stations) {
0652               CountArray* muon_n_hits = nullptr;
0653               if (hit_pattern.muonDTHitFilter(hit_id))
0654                 muon_n_hits = &n_hits.at(MuonSubdetId::DT);
0655               else if (hit_pattern.muonCSCHitFilter(hit_id))
0656                 muon_n_hits = &n_hits.at(MuonSubdetId::CSC);
0657               else if (hit_pattern.muonRPCHitFilter(hit_id))
0658                 muon_n_hits = &n_hits.at(MuonSubdetId::RPC);
0659 
0660               if (muon_n_hits)
0661                 ++muon_n_hits->at(station_index);
0662             }
0663           }
0664         }
0665       }
0666     }
0667 
0668     unsigned nMatches(int subdet, int station) const {
0669       if (!n_matches.count(subdet))
0670         throw cms::Exception("MuonHitMatch") << "Subdet " << subdet << " not found.";
0671       const size_t station_index = getStationIndex(station, true);
0672       return n_matches.at(subdet).at(station_index);
0673     }
0674 
0675     unsigned nHits(int subdet, int station) const {
0676       if (!n_hits.count(subdet))
0677         throw cms::Exception("MuonHitMatch") << "Subdet " << subdet << " not found.";
0678       const size_t station_index = getStationIndex(station, true);
0679       return n_hits.at(subdet).at(station_index);
0680     }
0681 
0682     unsigned countMuonStationsWithMatches(int first_station, int last_station) const {
0683       static const std::map<int, std::vector<bool>> masks = {
0684           {MuonSubdetId::DT, {false, false, false, false}},
0685           {MuonSubdetId::CSC, {true, false, false, false}},
0686           {MuonSubdetId::RPC, {false, false, false, false}},
0687       };
0688       const size_t first_station_index = getStationIndex(first_station, true);
0689       const size_t last_station_index = getStationIndex(last_station, true);
0690       unsigned cnt = 0;
0691       for (size_t n = first_station_index; n <= last_station_index; ++n) {
0692         for (const auto& match : n_matches) {
0693           if (!masks.at(match.first).at(n) && match.second.at(n) > 0)
0694             ++cnt;
0695         }
0696       }
0697       return cnt;
0698     }
0699 
0700     unsigned countMuonStationsWithHits(int first_station, int last_station) const {
0701       static const std::map<int, std::vector<bool>> masks = {
0702           {MuonSubdetId::DT, {false, false, false, false}},
0703           {MuonSubdetId::CSC, {false, false, false, false}},
0704           {MuonSubdetId::RPC, {false, false, false, false}},
0705       };
0706 
0707       const size_t first_station_index = getStationIndex(first_station, true);
0708       const size_t last_station_index = getStationIndex(last_station, true);
0709       unsigned cnt = 0;
0710       for (size_t n = first_station_index; n <= last_station_index; ++n) {
0711         for (const auto& hit : n_hits) {
0712           if (!masks.at(hit.first).at(n) && hit.second.at(n) > 0)
0713             ++cnt;
0714         }
0715       }
0716       return cnt;
0717     }
0718 
0719   private:
0720     CountMap n_matches, n_hits;
0721   };
0722 
0723   enum class CellObjectType {
0724     PfCand_electron,
0725     PfCand_muon,
0726     PfCand_chargedHadron,
0727     PfCand_neutralHadron,
0728     PfCand_gamma,
0729     Electron,
0730     Muon,
0731     Other
0732   };
0733 
0734   template <typename Object>
0735   inline CellObjectType GetCellObjectType(const Object&);
0736   template <>
0737   inline CellObjectType GetCellObjectType(const pat::Electron&) {
0738     return CellObjectType::Electron;
0739   }
0740   template <>
0741   inline CellObjectType GetCellObjectType(const pat::Muon&) {
0742     return CellObjectType::Muon;
0743   }
0744 
0745   template <>
0746   inline CellObjectType GetCellObjectType(reco::Candidate const& cand) {
0747     static const std::map<int, CellObjectType> obj_types = {{11, CellObjectType::PfCand_electron},
0748                                                             {13, CellObjectType::PfCand_muon},
0749                                                             {22, CellObjectType::PfCand_gamma},
0750                                                             {130, CellObjectType::PfCand_neutralHadron},
0751                                                             {211, CellObjectType::PfCand_chargedHadron}};
0752 
0753     auto iter = obj_types.find(std::abs(cand.pdgId()));
0754     if (iter == obj_types.end())
0755       return CellObjectType::Other;
0756     return iter->second;
0757   }
0758 
0759   using Cell = std::map<CellObjectType, size_t>;
0760   struct CellIndex {
0761     int eta, phi;
0762 
0763     bool operator<(const CellIndex& other) const {
0764       if (eta != other.eta)
0765         return eta < other.eta;
0766       return phi < other.phi;
0767     }
0768   };
0769 
0770   class CellGrid {
0771   public:
0772     using Map = std::map<CellIndex, Cell>;
0773     using const_iterator = Map::const_iterator;
0774 
0775     CellGrid(unsigned n_cells_eta,
0776              unsigned n_cells_phi,
0777              double cell_size_eta,
0778              double cell_size_phi,
0779              bool disable_CellIndex_workaround)
0780         : nCellsEta(n_cells_eta),
0781           nCellsPhi(n_cells_phi),
0782           nTotal(nCellsEta * nCellsPhi),
0783           cellSizeEta(cell_size_eta),
0784           cellSizePhi(cell_size_phi),
0785           disable_CellIndex_workaround_(disable_CellIndex_workaround) {
0786       if (nCellsEta % 2 != 1 || nCellsEta < 1)
0787         throw cms::Exception("DeepTauId") << "Invalid number of eta cells.";
0788       if (nCellsPhi % 2 != 1 || nCellsPhi < 1)
0789         throw cms::Exception("DeepTauId") << "Invalid number of phi cells.";
0790       if (cellSizeEta <= 0 || cellSizePhi <= 0)
0791         throw cms::Exception("DeepTauId") << "Invalid cell size.";
0792     }
0793 
0794     int maxEtaIndex() const { return static_cast<int>((nCellsEta - 1) / 2); }
0795     int maxPhiIndex() const { return static_cast<int>((nCellsPhi - 1) / 2); }
0796     double maxDeltaEta() const { return cellSizeEta * (0.5 + maxEtaIndex()); }
0797     double maxDeltaPhi() const { return cellSizePhi * (0.5 + maxPhiIndex()); }
0798     int getEtaTensorIndex(const CellIndex& cellIndex) const { return cellIndex.eta + maxEtaIndex(); }
0799     int getPhiTensorIndex(const CellIndex& cellIndex) const { return cellIndex.phi + maxPhiIndex(); }
0800 
0801     bool tryGetCellIndex(double deltaEta, double deltaPhi, CellIndex& cellIndex) const {
0802       const auto getCellIndex = [this](double x, double maxX, double size, int& index) {
0803         const double absX = std::abs(x);
0804         if (absX > maxX)
0805           return false;
0806         double absIndex;
0807         if (disable_CellIndex_workaround_) {
0808           // CV: use consistent definition for CellIndex
0809           //     in DeepTauId.cc code and new DeepTau trainings
0810           absIndex = std::floor(absX / size + 0.5);
0811         } else {
0812           // CV: backwards compatibility with DeepTau training v2p1 used during Run 2
0813           absIndex = std::floor(std::abs(absX / size - 0.5));
0814         }
0815         index = static_cast<int>(std::copysign(absIndex, x));
0816         return true;
0817       };
0818 
0819       return getCellIndex(deltaEta, maxDeltaEta(), cellSizeEta, cellIndex.eta) &&
0820              getCellIndex(deltaPhi, maxDeltaPhi(), cellSizePhi, cellIndex.phi);
0821     }
0822 
0823     size_t num_valid_cells() const { return cells.size(); }
0824     Cell& operator[](const CellIndex& cellIndex) { return cells[cellIndex]; }
0825     const Cell& at(const CellIndex& cellIndex) const { return cells.at(cellIndex); }
0826     size_t count(const CellIndex& cellIndex) const { return cells.count(cellIndex); }
0827     const_iterator find(const CellIndex& cellIndex) const { return cells.find(cellIndex); }
0828     const_iterator begin() const { return cells.begin(); }
0829     const_iterator end() const { return cells.end(); }
0830 
0831   public:
0832     const unsigned nCellsEta, nCellsPhi, nTotal;
0833     const double cellSizeEta, cellSizePhi;
0834 
0835   private:
0836     std::map<CellIndex, Cell> cells;
0837     const bool disable_CellIndex_workaround_;
0838   };
0839 }  // anonymous namespace
0840 
0841 template <class Producer>
0842 class DeepTauIdBase : public Producer {
0843 public:
0844   using TauDiscriminator = reco::TauDiscriminatorContainer;
0845   using TauCollection = edm::View<reco::BaseTau>;
0846   using CandidateCollection = edm::View<reco::Candidate>;
0847   using TauRef = edm::Ref<TauCollection>;
0848   using TauRefProd = edm::RefProd<TauCollection>;
0849   using ElectronCollection = pat::ElectronCollection;
0850   using MuonCollection = pat::MuonCollection;
0851   using LorentzVectorXYZ = ROOT::Math::LorentzVector<ROOT::Math::PxPyPzE4D<double>>;
0852   using Cutter = tau::TauWPThreshold;
0853   using CutterPtr = std::unique_ptr<Cutter>;
0854   using WPList = std::vector<CutterPtr>;
0855 
0856   struct IDOutput {
0857     std::vector<size_t> num_, den_;
0858 
0859     IDOutput(const std::vector<size_t>& num, const std::vector<size_t>& den) : num_(num), den_(den) {}
0860 
0861     // for direct inference, read the output tensorflow::Tensor
0862     float read_value(const tensorflow::Tensor& pred, size_t tau_index, size_t elem) const {
0863       return pred.matrix<float>()(tau_index, elem);
0864     }
0865     // for SONIC, read the output vector
0866     float read_value(const std::vector<std::vector<float>>& pred, size_t tau_index, size_t elem) const {
0867       return pred.at(tau_index).at(elem);
0868     }
0869 
0870     template <typename PredType>
0871     std::unique_ptr<TauDiscriminator> get_value(const edm::Handle<TauCollection>& taus,
0872                                                 const PredType& pred,
0873                                                 const WPList* working_points,
0874                                                 bool is_online) const {
0875       std::vector<reco::SingleTauDiscriminatorContainer> outputbuffer(taus->size());
0876 
0877       for (size_t tau_index = 0; tau_index < taus->size(); ++tau_index) {
0878         float x = 0;
0879         for (size_t num_elem : num_)
0880           x += read_value(pred, tau_index, num_elem);
0881         if (x != 0 && !den_.empty()) {
0882           float den_val = 0;
0883           for (size_t den_elem : den_)
0884             den_val += read_value(pred, tau_index, den_elem);
0885           x = den_val != 0 ? x / den_val : std::numeric_limits<float>::max();
0886         }
0887         outputbuffer[tau_index].rawValues.push_back(x);
0888         if (working_points) {
0889           for (const auto& wp : *working_points) {
0890             const bool pass = x > (*wp)(taus->at(tau_index), is_online);
0891             outputbuffer[tau_index].workingPoints.push_back(pass);
0892           }
0893         }
0894       }
0895       std::unique_ptr<TauDiscriminator> output = std::make_unique<TauDiscriminator>();
0896       reco::TauDiscriminatorContainer::Filler filler(*output);
0897       filler.insert(taus, outputbuffer.begin(), outputbuffer.end());
0898       filler.fill();
0899       return output;
0900     }
0901   };
0902 
0903   using IDOutputCollection = std::map<std::string, IDOutput>;
0904 
0905   static constexpr float default_value = -999.;
0906 
0907   static const IDOutputCollection& GetIDOutputs() {
0908     static constexpr size_t e_index = 0, mu_index = 1, tau_index = 2, jet_index = 3;
0909     static const IDOutputCollection idoutputs_ = {
0910         {"VSe", IDOutput({tau_index}, {e_index, tau_index})},
0911         {"VSmu", IDOutput({tau_index}, {mu_index, tau_index})},
0912         {"VSjet", IDOutput({tau_index}, {jet_index, tau_index})},
0913     };
0914     return idoutputs_;
0915   }
0916 
0917   using BasicDiscriminator = deep_tau::BasicDiscriminator;
0918 
0919   const std::map<BasicDiscriminator, size_t> matchDiscriminatorIndices(
0920       edm::Event const& event,
0921       edm::EDGetTokenT<reco::TauDiscriminatorContainer> discriminatorContainerToken,
0922       std::vector<BasicDiscriminator> requiredDiscr) {
0923     std::map<std::string, size_t> discrIndexMapStr;
0924     auto const aHandle = event.getHandle(discriminatorContainerToken);
0925     auto const aProv = aHandle.provenance();
0926     if (aProv == nullptr)
0927       aHandle.whyFailed()->raise();
0928     const auto& psetsFromProvenance = edm::parameterSet(aProv->stable(), event.processHistory());
0929     auto const idlist = psetsFromProvenance.getParameter<std::vector<edm::ParameterSet>>("IDdefinitions");
0930     for (size_t j = 0; j < idlist.size(); ++j) {
0931       std::string idname = idlist[j].getParameter<std::string>("IDname");
0932       if (discrIndexMapStr.count(idname)) {
0933         throw cms::Exception("DeepTauId")
0934             << "basic discriminator " << idname << " appears more than once in the input.";
0935       }
0936       discrIndexMapStr[idname] = j;
0937     }
0938 
0939     //translate to a map of <BasicDiscriminator, index> and check if all discriminators are present
0940     std::map<BasicDiscriminator, size_t> discrIndexMap;
0941     for (size_t i = 0; i < requiredDiscr.size(); i++) {
0942       if (discrIndexMapStr.find(stringFromDiscriminator_.at(requiredDiscr[i])) == discrIndexMapStr.end())
0943         throw cms::Exception("DeepTauId") << "Basic Discriminator " << stringFromDiscriminator_.at(requiredDiscr[i])
0944                                           << " was not provided in the config file.";
0945       else
0946         discrIndexMap[requiredDiscr[i]] = discrIndexMapStr[stringFromDiscriminator_.at(requiredDiscr[i])];
0947     }
0948     return discrIndexMap;
0949   }
0950 
0951   static void fillDescriptionsHelper(edm::ParameterSetDescription& desc) {
0952     desc.add<edm::InputTag>("electrons", edm::InputTag("slimmedElectrons"));
0953     desc.add<edm::InputTag>("muons", edm::InputTag("slimmedMuons"));
0954     desc.add<edm::InputTag>("taus", edm::InputTag("slimmedTaus"));
0955     desc.add<edm::InputTag>("pfcands", edm::InputTag("packedPFCandidates"));
0956     desc.add<edm::InputTag>("vertices", edm::InputTag("offlineSlimmedPrimaryVertices"));
0957     desc.add<edm::InputTag>("rho", edm::InputTag("fixedGridRhoAll"));
0958     desc.add<bool>("mem_mapped", false);
0959     desc.add<unsigned>("year", 2017);
0960     desc.add<unsigned>("version", 2);
0961     desc.add<unsigned>("sub_version", 1);
0962     desc.add<int>("debug_level", 0);
0963     desc.add<bool>("disable_dxy_pca", false);
0964     desc.add<bool>("disable_hcalFraction_workaround", false);
0965     desc.add<bool>("disable_CellIndex_workaround", false);
0966     desc.add<bool>("save_inputs", false);
0967     desc.add<bool>("is_online", false);
0968 
0969     desc.add<std::vector<std::string>>("VSeWP", {"-1."});
0970     desc.add<std::vector<std::string>>("VSmuWP", {"-1."});
0971     desc.add<std::vector<std::string>>("VSjetWP", {"-1."});
0972 
0973     desc.addUntracked<edm::InputTag>("basicTauDiscriminators", edm::InputTag("basicTauDiscriminators"));
0974     desc.addUntracked<edm::InputTag>("basicTauDiscriminatorsdR03", edm::InputTag("basicTauDiscriminatorsdR03"));
0975     desc.add<edm::InputTag>("pfTauTransverseImpactParameters", edm::InputTag("hpsPFTauTransverseImpactParameters"));
0976 
0977     {
0978       edm::ParameterSetDescription pset_Prediscriminants;
0979       pset_Prediscriminants.add<std::string>("BooleanOperator", "and");
0980       {
0981         edm::ParameterSetDescription psd1;
0982         psd1.add<double>("cut");
0983         psd1.add<edm::InputTag>("Producer");
0984         pset_Prediscriminants.addOptional<edm::ParameterSetDescription>("decayMode", psd1);
0985       }
0986       desc.add<edm::ParameterSetDescription>("Prediscriminants", pset_Prediscriminants);
0987     }
0988   }
0989 
0990 public:
0991   DeepTauIdBase(const edm::ParameterSet& cfg)
0992       : Producer(cfg),
0993         tausToken_(this->template consumes<TauCollection>(cfg.getParameter<edm::InputTag>("taus"))),
0994         pfcandToken_(this->template consumes<CandidateCollection>(cfg.getParameter<edm::InputTag>("pfcands"))),
0995         vtxToken_(this->template consumes<reco::VertexCollection>(cfg.getParameter<edm::InputTag>("vertices"))),
0996         is_online_(cfg.getParameter<bool>("is_online")),
0997         idoutputs_(GetIDOutputs()),
0998         electrons_token_(
0999             this->template consumes<std::vector<pat::Electron>>(cfg.getParameter<edm::InputTag>("electrons"))),
1000         muons_token_(this->template consumes<std::vector<pat::Muon>>(cfg.getParameter<edm::InputTag>("muons"))),
1001         rho_token_(this->template consumes<double>(cfg.getParameter<edm::InputTag>("rho"))),
1002         basicTauDiscriminators_inputToken_(this->template consumes<reco::TauDiscriminatorContainer>(
1003             cfg.getUntrackedParameter<edm::InputTag>("basicTauDiscriminators"))),
1004         basicTauDiscriminatorsdR03_inputToken_(this->template consumes<reco::TauDiscriminatorContainer>(
1005             cfg.getUntrackedParameter<edm::InputTag>("basicTauDiscriminatorsdR03"))),
1006         pfTauTransverseImpactParameters_token_(
1007             this->template consumes<
1008                 edm::AssociationVector<reco::PFTauRefProd, std::vector<reco::PFTauTransverseImpactParameterRef>>>(
1009                 cfg.getParameter<edm::InputTag>("pfTauTransverseImpactParameters"))),
1010         year_(cfg.getParameter<unsigned>("year")),
1011         version_(cfg.getParameter<unsigned>("version")),
1012         sub_version_(cfg.getParameter<unsigned>("sub_version")),
1013         debug_level(cfg.getParameter<int>("debug_level")),
1014         disable_dxy_pca_(cfg.getParameter<bool>("disable_dxy_pca")),
1015         disable_hcalFraction_workaround_(cfg.getParameter<bool>("disable_hcalFraction_workaround")),
1016         disable_CellIndex_workaround_(cfg.getParameter<bool>("disable_CellIndex_workaround")),
1017         save_inputs_(cfg.getParameter<bool>("save_inputs")),
1018         json_file_(nullptr),
1019         file_counter_(0) {
1020     for (const auto& output_desc : idoutputs_) {
1021       this->template produces<TauDiscriminator>(output_desc.first);
1022       const auto& cut_list = cfg.getParameter<std::vector<std::string>>(output_desc.first + "WP");
1023       for (const std::string& cut_str : cut_list) {
1024         workingPoints_[output_desc.first].push_back(std::make_unique<Cutter>(cut_str));
1025       }
1026     }
1027 
1028     // prediscriminant operator
1029     // require the tau to pass the following prediscriminants
1030     const edm::ParameterSet& prediscriminantConfig = cfg.getParameter<edm::ParameterSet>("Prediscriminants");
1031 
1032     // determine boolean operator used on the prediscriminants
1033     std::string pdBoolOperator = prediscriminantConfig.getParameter<std::string>("BooleanOperator");
1034     // convert string to lowercase
1035     transform(pdBoolOperator.begin(), pdBoolOperator.end(), pdBoolOperator.begin(), ::tolower);
1036 
1037     if (pdBoolOperator == "and") {
1038       andPrediscriminants_ = 0x1;  //use chars instead of bools so we can do a bitwise trick later
1039     } else if (pdBoolOperator == "or") {
1040       andPrediscriminants_ = 0x0;
1041     } else {
1042       throw cms::Exception("TauDiscriminationProducerBase")
1043           << "PrediscriminantBooleanOperator defined incorrectly, options are: AND,OR";
1044     }
1045 
1046     // get the list of prediscriminants
1047     std::vector<std::string> prediscriminantsNames =
1048         prediscriminantConfig.getParameterNamesForType<edm::ParameterSet>();
1049 
1050     for (auto const& iDisc : prediscriminantsNames) {
1051       const edm::ParameterSet& iPredisc = prediscriminantConfig.getParameter<edm::ParameterSet>(iDisc);
1052       const edm::InputTag& label = iPredisc.getParameter<edm::InputTag>("Producer");
1053       double cut = iPredisc.getParameter<double>("cut");
1054 
1055       if (is_online_) {
1056         TauDiscInfo<reco::PFTauDiscriminator> thisDiscriminator;
1057         thisDiscriminator.label = label;
1058         thisDiscriminator.cut = cut;
1059         thisDiscriminator.disc_token = this->template consumes<reco::PFTauDiscriminator>(label);
1060         recoPrediscriminants_.push_back(thisDiscriminator);
1061       } else {
1062         TauDiscInfo<pat::PATTauDiscriminator> thisDiscriminator;
1063         thisDiscriminator.label = label;
1064         thisDiscriminator.cut = cut;
1065         thisDiscriminator.disc_token = this->template consumes<pat::PATTauDiscriminator>(label);
1066         patPrediscriminants_.push_back(thisDiscriminator);
1067       }
1068     }
1069     if (version_ == 2) {
1070       using namespace dnn_inputs_v2;
1071       namespace sc = deep_tau::Scaling;
1072       tauInputs_indices_.resize(TauBlockInputs::NumberOfInputs);
1073       std::iota(std::begin(tauInputs_indices_), std::end(tauInputs_indices_), 0);
1074 
1075       if (sub_version_ == 1) {
1076         scalingParamsMap_ = &sc::scalingParamsMap_v2p1;
1077       } else if (sub_version_ == 5) {
1078         std::sort(TauBlockInputs::varsToDrop.begin(), TauBlockInputs::varsToDrop.end());
1079         for (auto v : TauBlockInputs::varsToDrop) {
1080           tauInputs_indices_.at(v) = TauBlockInputs::NumberOfInputs - TauBlockInputs::varsToDrop.size();
1081           for (std::size_t i = v + 1; i < TauBlockInputs::NumberOfInputs; ++i)
1082             tauInputs_indices_.at(i) -= 1;  // shift all the following indices by 1
1083         }
1084         if (year_ == 2026) {
1085           scalingParamsMap_ = &sc::scalingParamsMap_PhaseIIv2p5;
1086         } else {
1087           scalingParamsMap_ = &sc::scalingParamsMap_v2p5;
1088         }
1089       } else
1090         throw cms::Exception("DeepTauId") << "subversion " << sub_version_ << " is not supported.";
1091 
1092       std::map<std::vector<bool>, std::vector<sc::FeatureT>> GridFeatureTypes_map = {
1093           {{false}, {sc::FeatureT::TauFlat, sc::FeatureT::GridGlobal}},  // feature types without inner/outer grid split
1094           {{false, true},
1095            {sc::FeatureT::PfCand_electron,
1096             sc::FeatureT::PfCand_muon,  // feature types with inner/outer grid split
1097             sc::FeatureT::PfCand_chHad,
1098             sc::FeatureT::PfCand_nHad,
1099             sc::FeatureT::PfCand_gamma,
1100             sc::FeatureT::Electron,
1101             sc::FeatureT::Muon}}};
1102 
1103       // check that sizes of mean/std/lim_min/lim_max vectors are equal between each other
1104       for (const auto& p : GridFeatureTypes_map) {
1105         for (auto is_inner : p.first) {
1106           for (auto featureType : p.second) {
1107             const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(featureType, is_inner));
1108             if (!(sp.mean_.size() == sp.std_.size() && sp.mean_.size() == sp.lim_min_.size() &&
1109                   sp.mean_.size() == sp.lim_max_.size()))
1110               throw cms::Exception("DeepTauId") << "sizes of scaling parameter vectors do not match between each other";
1111           }
1112         }
1113       }
1114     } else {
1115       throw cms::Exception("DeepTauId") << "version " << version_ << " is not supported.";
1116     }
1117   }
1118 
1119   template <typename ConsumeType>
1120   struct TauDiscInfo {
1121     edm::InputTag label;
1122     edm::Handle<ConsumeType> handle;
1123     edm::EDGetTokenT<ConsumeType> disc_token;
1124     double cut;
1125     void fill(const edm::Event& evt) { evt.getByToken(disc_token, handle); }
1126   };
1127 
1128   // select boolean operation on prediscriminants (and = 0x01, or = 0x00)
1129   uint8_t andPrediscriminants_;
1130   std::vector<TauDiscInfo<pat::PATTauDiscriminator>> patPrediscriminants_;
1131   std::vector<TauDiscInfo<reco::PFTauDiscriminator>> recoPrediscriminants_;
1132 
1133 protected:
1134   static constexpr float pi = M_PI;
1135 
1136   template <typename PredType>
1137   void createOutputs(edm::Event& event, const PredType& pred, edm::Handle<TauCollection> taus) {
1138     for (const auto& output_desc : idoutputs_) {
1139       const WPList* working_points = nullptr;
1140       if (workingPoints_.find(output_desc.first) != workingPoints_.end()) {
1141         working_points = &workingPoints_.at(output_desc.first);
1142       }
1143       auto result = output_desc.second.get_value(taus, pred, working_points, is_online_);
1144       event.put(std::move(result), output_desc.first);
1145     }
1146   }
1147 
1148   template <typename T>
1149   static float getValue(T value) {
1150     return std::isnormal(value) ? static_cast<float>(value) : 0.f;
1151   }
1152 
1153   template <typename T>
1154   static float getValueLinear(T value, float min_value, float max_value, bool positive) {
1155     const float fixed_value = getValue(value);
1156     const float clamped_value = std::clamp(fixed_value, min_value, max_value);
1157     float transformed_value = (clamped_value - min_value) / (max_value - min_value);
1158     if (!positive)
1159       transformed_value = transformed_value * 2 - 1;
1160     return transformed_value;
1161   }
1162 
1163   template <typename T>
1164   static float getValueNorm(T value, float mean, float sigma, float n_sigmas_max = 5) {
1165     const float fixed_value = getValue(value);
1166     const float norm_value = (fixed_value - mean) / sigma;
1167     return std::clamp(norm_value, -n_sigmas_max, n_sigmas_max);
1168   }
1169 
1170   static bool isAbove(double value, double min) { return std::isnormal(value) && value > min; }
1171 
1172   static bool calculateElectronClusterVarsV2(const pat::Electron& ele,
1173                                              float& cc_ele_energy,
1174                                              float& cc_gamma_energy,
1175                                              int& cc_n_gamma) {
1176     cc_ele_energy = cc_gamma_energy = 0;
1177     cc_n_gamma = 0;
1178     const auto& superCluster = ele.superCluster();
1179     if (superCluster.isNonnull() && superCluster.isAvailable() && superCluster->clusters().isNonnull() &&
1180         superCluster->clusters().isAvailable()) {
1181       for (auto iter = superCluster->clustersBegin(); iter != superCluster->clustersEnd(); ++iter) {
1182         const float energy = static_cast<float>((*iter)->energy());
1183         if (iter == superCluster->clustersBegin())
1184           cc_ele_energy += energy;
1185         else {
1186           cc_gamma_energy += energy;
1187           ++cc_n_gamma;
1188         }
1189       }
1190       return true;
1191     } else
1192       return false;
1193   }
1194 
1195 protected:
1196   // load prediscriminators
1197   void loadPrediscriminants(edm::Event const& event, edm::Handle<TauCollection> const& taus) {
1198     edm::ProductID tauProductID = taus.id();
1199     size_t nPrediscriminants =
1200         patPrediscriminants_.empty() ? recoPrediscriminants_.size() : patPrediscriminants_.size();
1201     for (size_t iDisc = 0; iDisc < nPrediscriminants; ++iDisc) {
1202       edm::ProductID discKeyId;
1203       if (is_online_) {
1204         recoPrediscriminants_[iDisc].fill(event);
1205         discKeyId = recoPrediscriminants_[iDisc].handle->keyProduct().id();
1206       } else {
1207         patPrediscriminants_[iDisc].fill(event);
1208         discKeyId = patPrediscriminants_[iDisc].handle->keyProduct().id();
1209       }
1210 
1211       // Check to make sure the product is correct for the discriminator.
1212       // If not, throw a more informative exception.
1213       if (tauProductID != discKeyId) {
1214         throw cms::Exception("MisconfiguredPrediscriminant")
1215             << "The tau collection has product ID: " << tauProductID
1216             << " but the pre-discriminator is keyed with product ID: " << discKeyId << std::endl;
1217       }
1218     }
1219   }
1220 
1221   template <typename Collection, typename TauCastType>
1222   void fillGrids(const TauCastType& tau, const Collection& objects, CellGrid& inner_grid, CellGrid& outer_grid) {
1223     static constexpr double outer_dR2 = 0.25;  //0.5^2
1224     const double inner_radius = getInnerSignalConeRadius(tau.polarP4().pt());
1225     const double inner_dR2 = std::pow(inner_radius, 2);
1226 
1227     const auto addObject = [&](size_t n, double deta, double dphi, CellGrid& grid) {
1228       const auto& obj = objects.at(n);
1229       const CellObjectType obj_type = GetCellObjectType(obj);
1230       if (obj_type == CellObjectType::Other)
1231         return;
1232       CellIndex cell_index;
1233       if (grid.tryGetCellIndex(deta, dphi, cell_index)) {
1234         Cell& cell = grid[cell_index];
1235         auto iter = cell.find(obj_type);
1236         if (iter != cell.end()) {
1237           const auto& prev_obj = objects.at(iter->second);
1238           if (obj.polarP4().pt() > prev_obj.polarP4().pt())
1239             iter->second = n;
1240         } else {
1241           cell[obj_type] = n;
1242         }
1243       }
1244     };
1245 
1246     for (size_t n = 0; n < objects.size(); ++n) {
1247       const auto& obj = objects.at(n);
1248       const double deta = obj.polarP4().eta() - tau.polarP4().eta();
1249       const double dphi = reco::deltaPhi(obj.polarP4().phi(), tau.polarP4().phi());
1250       const double dR2 = std::pow(deta, 2) + std::pow(dphi, 2);
1251       if (dR2 < inner_dR2)
1252         addObject(n, deta, dphi, inner_grid);
1253       if (dR2 < outer_dR2)
1254         addObject(n, deta, dphi, outer_grid);
1255     }
1256   }
1257 
1258   template <typename CandidateCastType, typename TauCastType, typename TauBlockType>
1259   void createTauBlockInputs(const TauCastType& tau,
1260                             const size_t& tau_index,
1261                             const edm::RefToBase<reco::BaseTau> tau_ref,
1262                             const reco::Vertex& pv,
1263                             double rho,
1264                             TauFunc tau_funcs,
1265                             TauBlockType& tauBlockInputs) {
1266     namespace dnn = dnn_inputs_v2::TauBlockInputs;
1267     namespace sc = deep_tau::Scaling;
1268     namespace candFunc = candFunc;
1269     sc::FeatureT ft = sc::FeatureT::TauFlat;
1270     const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft, false));
1271 
1272     const auto& get = [&](int var_index) -> float& {
1273       if constexpr (std::is_same_v<TauBlockType, std::vector<float>::iterator>) {
1274         return *(tauBlockInputs + tauInputs_indices_.at(var_index));
1275       } else {
1276         return ((tensorflow::Tensor)tauBlockInputs).matrix<float>()(0, tauInputs_indices_.at(var_index));
1277       }
1278     };
1279 
1280     auto leadChargedHadrCand = dynamic_cast<const CandidateCastType*>(tau.leadChargedHadrCand().get());
1281 
1282     get(dnn::rho) = sp.scale(rho, tauInputs_indices_[dnn::rho]);
1283     get(dnn::tau_pt) = sp.scale(tau.polarP4().pt(), tauInputs_indices_[dnn::tau_pt]);
1284     get(dnn::tau_eta) = sp.scale(tau.polarP4().eta(), tauInputs_indices_[dnn::tau_eta]);
1285     if (sub_version_ == 1) {
1286       get(dnn::tau_phi) = getValueLinear(tau.polarP4().phi(), -pi, pi, false);
1287     }
1288     get(dnn::tau_mass) = sp.scale(tau.polarP4().mass(), tauInputs_indices_[dnn::tau_mass]);
1289     get(dnn::tau_E_over_pt) = sp.scale(tau.p4().energy() / tau.p4().pt(), tauInputs_indices_[dnn::tau_E_over_pt]);
1290     get(dnn::tau_charge) = sp.scale(tau.charge(), tauInputs_indices_[dnn::tau_charge]);
1291     get(dnn::tau_n_charged_prongs) = sp.scale(tau.decayMode() / 5 + 1, tauInputs_indices_[dnn::tau_n_charged_prongs]);
1292     get(dnn::tau_n_neutral_prongs) = sp.scale(tau.decayMode() % 5, tauInputs_indices_[dnn::tau_n_neutral_prongs]);
1293     get(dnn::chargedIsoPtSum) =
1294         sp.scale(tau_funcs.getChargedIsoPtSum(tau, tau_ref), tauInputs_indices_[dnn::chargedIsoPtSum]);
1295     get(dnn::chargedIsoPtSumdR03_over_dR05) =
1296         sp.scale(tau_funcs.getChargedIsoPtSumdR03(tau, tau_ref) / tau_funcs.getChargedIsoPtSum(tau, tau_ref),
1297                  tauInputs_indices_[dnn::chargedIsoPtSumdR03_over_dR05]);
1298     if (sub_version_ == 1)
1299       get(dnn::footprintCorrection) =
1300           sp.scale(tau_funcs.getFootprintCorrectiondR03(tau, tau_ref), tauInputs_indices_[dnn::footprintCorrection]);
1301     else if (sub_version_ == 5) {
1302       if (is_online_)
1303         get(dnn::footprintCorrection) =
1304             sp.scale(tau_funcs.getFootprintCorrectiondR03(tau, tau_ref), tauInputs_indices_[dnn::footprintCorrection]);
1305       else
1306         get(dnn::footprintCorrection) =
1307             sp.scale(tau_funcs.getFootprintCorrection(tau, tau_ref), tauInputs_indices_[dnn::footprintCorrection]);
1308     }
1309     get(dnn::neutralIsoPtSum) =
1310         sp.scale(tau_funcs.getNeutralIsoPtSum(tau, tau_ref), tauInputs_indices_[dnn::neutralIsoPtSum]);
1311     get(dnn::neutralIsoPtSumWeight_over_neutralIsoPtSum) =
1312         sp.scale(tau_funcs.getNeutralIsoPtSumWeight(tau, tau_ref) / tau_funcs.getNeutralIsoPtSum(tau, tau_ref),
1313                  tauInputs_indices_[dnn::neutralIsoPtSumWeight_over_neutralIsoPtSum]);
1314     get(dnn::neutralIsoPtSumWeightdR03_over_neutralIsoPtSum) =
1315         sp.scale(tau_funcs.getNeutralIsoPtSumdR03Weight(tau, tau_ref) / tau_funcs.getNeutralIsoPtSum(tau, tau_ref),
1316                  tauInputs_indices_[dnn::neutralIsoPtSumWeightdR03_over_neutralIsoPtSum]);
1317     get(dnn::neutralIsoPtSumdR03_over_dR05) =
1318         sp.scale(tau_funcs.getNeutralIsoPtSumdR03(tau, tau_ref) / tau_funcs.getNeutralIsoPtSum(tau, tau_ref),
1319                  tauInputs_indices_[dnn::neutralIsoPtSumdR03_over_dR05]);
1320     get(dnn::photonPtSumOutsideSignalCone) = sp.scale(tau_funcs.getPhotonPtSumOutsideSignalCone(tau, tau_ref),
1321                                                       tauInputs_indices_[dnn::photonPtSumOutsideSignalCone]);
1322     get(dnn::puCorrPtSum) = sp.scale(tau_funcs.getPuCorrPtSum(tau, tau_ref), tauInputs_indices_[dnn::puCorrPtSum]);
1323     // The global PCA coordinates were used as inputs during the NN training, but it was decided to disable
1324     // them for the inference, because modeling of dxy_PCA in MC poorly describes the data, and x and y coordinates
1325     // in data results outside of the expected 5 std. dev. input validity range. On the other hand,
1326     // these coordinates are strongly era-dependent. Kept as comment to document what NN expects.
1327     if (sub_version_ == 1) {
1328       if (!disable_dxy_pca_) {
1329         auto const pca = tau_funcs.getdxyPCA(tau, tau_index);
1330         get(dnn::tau_dxy_pca_x) = sp.scale(pca.x(), tauInputs_indices_[dnn::tau_dxy_pca_x]);
1331         get(dnn::tau_dxy_pca_y) = sp.scale(pca.y(), tauInputs_indices_[dnn::tau_dxy_pca_y]);
1332         get(dnn::tau_dxy_pca_z) = sp.scale(pca.z(), tauInputs_indices_[dnn::tau_dxy_pca_z]);
1333       } else {
1334         get(dnn::tau_dxy_pca_x) = 0;
1335         get(dnn::tau_dxy_pca_y) = 0;
1336         get(dnn::tau_dxy_pca_z) = 0;
1337       }
1338     }
1339 
1340     const bool tau_dxy_valid =
1341         isAbove(tau_funcs.getdxy(tau, tau_index), -10) && isAbove(tau_funcs.getdxyError(tau, tau_index), 0);
1342     if (tau_dxy_valid) {
1343       get(dnn::tau_dxy_valid) = sp.scale(tau_dxy_valid, tauInputs_indices_[dnn::tau_dxy_valid]);
1344       get(dnn::tau_dxy) = sp.scale(tau_funcs.getdxy(tau, tau_index), tauInputs_indices_[dnn::tau_dxy]);
1345       get(dnn::tau_dxy_sig) =
1346           sp.scale(std::abs(tau_funcs.getdxy(tau, tau_index)) / tau_funcs.getdxyError(tau, tau_index),
1347                    tauInputs_indices_[dnn::tau_dxy_sig]);
1348     }
1349     const bool tau_ip3d_valid =
1350         isAbove(tau_funcs.getip3d(tau, tau_index), -10) && isAbove(tau_funcs.getip3dError(tau, tau_index), 0);
1351     if (tau_ip3d_valid) {
1352       get(dnn::tau_ip3d_valid) = sp.scale(tau_ip3d_valid, tauInputs_indices_[dnn::tau_ip3d_valid]);
1353       get(dnn::tau_ip3d) = sp.scale(tau_funcs.getip3d(tau, tau_index), tauInputs_indices_[dnn::tau_ip3d]);
1354       get(dnn::tau_ip3d_sig) =
1355           sp.scale(std::abs(tau_funcs.getip3d(tau, tau_index)) / tau_funcs.getip3dError(tau, tau_index),
1356                    tauInputs_indices_[dnn::tau_ip3d_sig]);
1357     }
1358     if (leadChargedHadrCand) {
1359       const bool hasTrackDetails = candFunc::getHasTrackDetails(*leadChargedHadrCand);
1360       const float tau_dz = (is_online_ && !hasTrackDetails) ? 0 : candFunc::getTauDz(*leadChargedHadrCand);
1361       get(dnn::tau_dz) = sp.scale(tau_dz, tauInputs_indices_[dnn::tau_dz]);
1362       get(dnn::tau_dz_sig_valid) =
1363           sp.scale(candFunc::getTauDZSigValid(*leadChargedHadrCand), tauInputs_indices_[dnn::tau_dz_sig_valid]);
1364       const double dzError = hasTrackDetails ? leadChargedHadrCand->dzError() : -999.;
1365       get(dnn::tau_dz_sig) = sp.scale(std::abs(tau_dz) / dzError, tauInputs_indices_[dnn::tau_dz_sig]);
1366     }
1367     get(dnn::tau_flightLength_x) =
1368         sp.scale(tau_funcs.getFlightLength(tau, tau_index).x(), tauInputs_indices_[dnn::tau_flightLength_x]);
1369     get(dnn::tau_flightLength_y) =
1370         sp.scale(tau_funcs.getFlightLength(tau, tau_index).y(), tauInputs_indices_[dnn::tau_flightLength_y]);
1371     get(dnn::tau_flightLength_z) =
1372         sp.scale(tau_funcs.getFlightLength(tau, tau_index).z(), tauInputs_indices_[dnn::tau_flightLength_z]);
1373     if (sub_version_ == 1)
1374       get(dnn::tau_flightLength_sig) = 0.55756444;  //This value is set due to a bug in the training
1375     else if (sub_version_ == 5)
1376       get(dnn::tau_flightLength_sig) =
1377           sp.scale(tau_funcs.getFlightLengthSig(tau, tau_index), tauInputs_indices_[dnn::tau_flightLength_sig]);
1378 
1379     get(dnn::tau_pt_weighted_deta_strip) = sp.scale(reco::tau::pt_weighted_deta_strip(tau, tau.decayMode()),
1380                                                     tauInputs_indices_[dnn::tau_pt_weighted_deta_strip]);
1381 
1382     get(dnn::tau_pt_weighted_dphi_strip) = sp.scale(reco::tau::pt_weighted_dphi_strip(tau, tau.decayMode()),
1383                                                     tauInputs_indices_[dnn::tau_pt_weighted_dphi_strip]);
1384     get(dnn::tau_pt_weighted_dr_signal) = sp.scale(reco::tau::pt_weighted_dr_signal(tau, tau.decayMode()),
1385                                                    tauInputs_indices_[dnn::tau_pt_weighted_dr_signal]);
1386     get(dnn::tau_pt_weighted_dr_iso) =
1387         sp.scale(reco::tau::pt_weighted_dr_iso(tau, tau.decayMode()), tauInputs_indices_[dnn::tau_pt_weighted_dr_iso]);
1388     get(dnn::tau_leadingTrackNormChi2) =
1389         sp.scale(tau_funcs.getLeadingTrackNormChi2(tau), tauInputs_indices_[dnn::tau_leadingTrackNormChi2]);
1390     const auto eratio = reco::tau::eratio(tau);
1391     const bool tau_e_ratio_valid = std::isnormal(eratio) && eratio > 0.f;
1392     get(dnn::tau_e_ratio_valid) = sp.scale(tau_e_ratio_valid, tauInputs_indices_[dnn::tau_e_ratio_valid]);
1393     get(dnn::tau_e_ratio) = tau_e_ratio_valid ? sp.scale(eratio, tauInputs_indices_[dnn::tau_e_ratio]) : 0.f;
1394     const double gj_angle_diff = calculateGottfriedJacksonAngleDifference(tau, tau_index, tau_funcs);
1395     const bool tau_gj_angle_diff_valid = (std::isnormal(gj_angle_diff) || gj_angle_diff == 0) && gj_angle_diff >= 0;
1396     get(dnn::tau_gj_angle_diff_valid) =
1397         sp.scale(tau_gj_angle_diff_valid, tauInputs_indices_[dnn::tau_gj_angle_diff_valid]);
1398     get(dnn::tau_gj_angle_diff) =
1399         tau_gj_angle_diff_valid ? sp.scale(gj_angle_diff, tauInputs_indices_[dnn::tau_gj_angle_diff]) : 0;
1400     get(dnn::tau_n_photons) = sp.scale(reco::tau::n_photons_total(tau), tauInputs_indices_[dnn::tau_n_photons]);
1401     get(dnn::tau_emFraction) = sp.scale(tau_funcs.getEmFraction(tau), tauInputs_indices_[dnn::tau_emFraction]);
1402 
1403     get(dnn::tau_inside_ecal_crack) =
1404         sp.scale(isInEcalCrack(tau.p4().eta()), tauInputs_indices_[dnn::tau_inside_ecal_crack]);
1405     get(dnn::leadChargedCand_etaAtEcalEntrance_minus_tau_eta) =
1406         sp.scale(tau_funcs.getEtaAtEcalEntrance(tau) - tau.p4().eta(),
1407                  tauInputs_indices_[dnn::leadChargedCand_etaAtEcalEntrance_minus_tau_eta]);
1408   }
1409 
1410   template <typename CandidateCastType, typename TauCastType, typename EgammaBlockType>
1411   void createEgammaBlockInputs(unsigned idx,
1412                                const TauCastType& tau,
1413                                const size_t tau_index,
1414                                const edm::RefToBase<reco::BaseTau> tau_ref,
1415                                const reco::Vertex& pv,
1416                                double rho,
1417                                const std::vector<pat::Electron>* electrons,
1418                                const edm::View<reco::Candidate>& pfCands,
1419                                const Cell& cell_map,
1420                                TauFunc tau_funcs,
1421                                bool is_inner,
1422                                EgammaBlockType& egammaBlockInputs) {
1423     namespace dnn = dnn_inputs_v2::EgammaBlockInputs;
1424     namespace sc = deep_tau::Scaling;
1425     namespace candFunc = candFunc;
1426     sc::FeatureT ft_global = sc::FeatureT::GridGlobal;
1427     sc::FeatureT ft_PFe = sc::FeatureT::PfCand_electron;
1428     sc::FeatureT ft_PFg = sc::FeatureT::PfCand_gamma;
1429     sc::FeatureT ft_e = sc::FeatureT::Electron;
1430 
1431     // needed to remap indices from scaling vectors to those from dnn_inputs_v2::EgammaBlockInputs
1432     int PFe_index_offset = scalingParamsMap_->at(std::make_pair(ft_global, false)).mean_.size();
1433     int e_index_offset = PFe_index_offset + scalingParamsMap_->at(std::make_pair(ft_PFe, false)).mean_.size();
1434     int PFg_index_offset = e_index_offset + scalingParamsMap_->at(std::make_pair(ft_e, false)).mean_.size();
1435 
1436     // to account for swapped order of PfCand_gamma and Electron blocks for v2p5 training w.r.t. v2p1
1437     int fill_index_offset_e = 0;
1438     int fill_index_offset_PFg = 0;
1439     if (sub_version_ == 5) {
1440       fill_index_offset_e =
1441           scalingParamsMap_->at(std::make_pair(ft_PFg, false)).mean_.size();  // size of PF gamma features
1442       fill_index_offset_PFg =
1443           -scalingParamsMap_->at(std::make_pair(ft_e, false)).mean_.size();  // size of Electron features
1444     }
1445 
1446     const auto& get = [&](int var_index) -> float& {
1447       if constexpr (std::is_same_v<EgammaBlockType, std::vector<float>::iterator>) {
1448         return *(egammaBlockInputs + var_index);
1449       } else {
1450         return ((tensorflow::Tensor)egammaBlockInputs).tensor<float, 4>()(idx, 0, 0, var_index);
1451       }
1452     };
1453 
1454     const bool valid_index_pf_ele = cell_map.count(CellObjectType::PfCand_electron);
1455     const bool valid_index_pf_gamma = cell_map.count(CellObjectType::PfCand_gamma);
1456     const bool valid_index_ele = cell_map.count(CellObjectType::Electron);
1457 
1458     if (!cell_map.empty()) {
1459       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_global, false));
1460       get(dnn::rho) = sp.scale(rho, dnn::rho);
1461       get(dnn::tau_pt) = sp.scale(tau.polarP4().pt(), dnn::tau_pt);
1462       get(dnn::tau_eta) = sp.scale(tau.polarP4().eta(), dnn::tau_eta);
1463       get(dnn::tau_inside_ecal_crack) = sp.scale(isInEcalCrack(tau.polarP4().eta()), dnn::tau_inside_ecal_crack);
1464     }
1465     if (valid_index_pf_ele) {
1466       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_PFe, is_inner));
1467       size_t index_pf_ele = cell_map.at(CellObjectType::PfCand_electron);
1468       const auto& ele_cand = dynamic_cast<const CandidateCastType&>(pfCands.at(index_pf_ele));
1469 
1470       get(dnn::pfCand_ele_valid) = sp.scale(valid_index_pf_ele, dnn::pfCand_ele_valid - PFe_index_offset);
1471       get(dnn::pfCand_ele_rel_pt) =
1472           sp.scale(ele_cand.polarP4().pt() / tau.polarP4().pt(), dnn::pfCand_ele_rel_pt - PFe_index_offset);
1473       get(dnn::pfCand_ele_deta) =
1474           sp.scale(ele_cand.polarP4().eta() - tau.polarP4().eta(), dnn::pfCand_ele_deta - PFe_index_offset);
1475       get(dnn::pfCand_ele_dphi) =
1476           sp.scale(dPhi(tau.polarP4(), ele_cand.polarP4()), dnn::pfCand_ele_dphi - PFe_index_offset);
1477       get(dnn::pfCand_ele_pvAssociationQuality) = sp.scale<int>(
1478           candFunc::getPvAssocationQuality(ele_cand), dnn::pfCand_ele_pvAssociationQuality - PFe_index_offset);
1479       get(dnn::pfCand_ele_puppiWeight) = is_inner ? sp.scale(candFunc::getPuppiWeight(ele_cand, 0.9906834f),
1480                                                              dnn::pfCand_ele_puppiWeight - PFe_index_offset)
1481                                                   : sp.scale(candFunc::getPuppiWeight(ele_cand, 0.9669586f),
1482                                                              dnn::pfCand_ele_puppiWeight - PFe_index_offset);
1483       get(dnn::pfCand_ele_charge) = sp.scale(ele_cand.charge(), dnn::pfCand_ele_charge - PFe_index_offset);
1484       get(dnn::pfCand_ele_lostInnerHits) =
1485           sp.scale<int>(candFunc::getLostInnerHits(ele_cand, 0), dnn::pfCand_ele_lostInnerHits - PFe_index_offset);
1486       get(dnn::pfCand_ele_numberOfPixelHits) =
1487           sp.scale(candFunc::getNumberOfPixelHits(ele_cand, 0), dnn::pfCand_ele_numberOfPixelHits - PFe_index_offset);
1488       get(dnn::pfCand_ele_vertex_dx) =
1489           sp.scale(ele_cand.vertex().x() - pv.position().x(), dnn::pfCand_ele_vertex_dx - PFe_index_offset);
1490       get(dnn::pfCand_ele_vertex_dy) =
1491           sp.scale(ele_cand.vertex().y() - pv.position().y(), dnn::pfCand_ele_vertex_dy - PFe_index_offset);
1492       get(dnn::pfCand_ele_vertex_dz) =
1493           sp.scale(ele_cand.vertex().z() - pv.position().z(), dnn::pfCand_ele_vertex_dz - PFe_index_offset);
1494       get(dnn::pfCand_ele_vertex_dx_tauFL) =
1495           sp.scale(ele_cand.vertex().x() - pv.position().x() - tau_funcs.getFlightLength(tau, tau_index).x(),
1496                    dnn::pfCand_ele_vertex_dx_tauFL - PFe_index_offset);
1497       get(dnn::pfCand_ele_vertex_dy_tauFL) =
1498           sp.scale(ele_cand.vertex().y() - pv.position().y() - tau_funcs.getFlightLength(tau, tau_index).y(),
1499                    dnn::pfCand_ele_vertex_dy_tauFL - PFe_index_offset);
1500       get(dnn::pfCand_ele_vertex_dz_tauFL) =
1501           sp.scale(ele_cand.vertex().z() - pv.position().z() - tau_funcs.getFlightLength(tau, tau_index).z(),
1502                    dnn::pfCand_ele_vertex_dz_tauFL - PFe_index_offset);
1503 
1504       const bool hasTrackDetails = candFunc::getHasTrackDetails(ele_cand);
1505       if (hasTrackDetails) {
1506         get(dnn::pfCand_ele_hasTrackDetails) =
1507             sp.scale(hasTrackDetails, dnn::pfCand_ele_hasTrackDetails - PFe_index_offset);
1508         get(dnn::pfCand_ele_dxy) = sp.scale(candFunc::getTauDxy(ele_cand), dnn::pfCand_ele_dxy - PFe_index_offset);
1509         get(dnn::pfCand_ele_dxy_sig) = sp.scale(std::abs(candFunc::getTauDxy(ele_cand)) / ele_cand.dxyError(),
1510                                                 dnn::pfCand_ele_dxy_sig - PFe_index_offset);
1511         get(dnn::pfCand_ele_dz) = sp.scale(candFunc::getTauDz(ele_cand), dnn::pfCand_ele_dz - PFe_index_offset);
1512         get(dnn::pfCand_ele_dz_sig) = sp.scale(std::abs(candFunc::getTauDz(ele_cand)) / ele_cand.dzError(),
1513                                                dnn::pfCand_ele_dz_sig - PFe_index_offset);
1514         get(dnn::pfCand_ele_track_chi2_ndof) =
1515             candFunc::getPseudoTrack(ele_cand).ndof() > 0
1516                 ? sp.scale(candFunc::getPseudoTrack(ele_cand).chi2() / candFunc::getPseudoTrack(ele_cand).ndof(),
1517                            dnn::pfCand_ele_track_chi2_ndof - PFe_index_offset)
1518                 : 0;
1519         get(dnn::pfCand_ele_track_ndof) =
1520             candFunc::getPseudoTrack(ele_cand).ndof() > 0
1521                 ? sp.scale(candFunc::getPseudoTrack(ele_cand).ndof(), dnn::pfCand_ele_track_ndof - PFe_index_offset)
1522                 : 0;
1523       }
1524     }
1525     if (valid_index_pf_gamma) {
1526       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_PFg, is_inner));
1527       size_t index_pf_gamma = cell_map.at(CellObjectType::PfCand_gamma);
1528       const auto& gamma_cand = dynamic_cast<const CandidateCastType&>(pfCands.at(index_pf_gamma));
1529 
1530       get(dnn::pfCand_gamma_valid + fill_index_offset_PFg) =
1531           sp.scale(valid_index_pf_gamma, dnn::pfCand_gamma_valid - PFg_index_offset);
1532       get(dnn::pfCand_gamma_rel_pt + fill_index_offset_PFg) =
1533           sp.scale(gamma_cand.polarP4().pt() / tau.polarP4().pt(), dnn::pfCand_gamma_rel_pt - PFg_index_offset);
1534       get(dnn::pfCand_gamma_deta + fill_index_offset_PFg) =
1535           sp.scale(gamma_cand.polarP4().eta() - tau.polarP4().eta(), dnn::pfCand_gamma_deta - PFg_index_offset);
1536       get(dnn::pfCand_gamma_dphi + fill_index_offset_PFg) =
1537           sp.scale(dPhi(tau.polarP4(), gamma_cand.polarP4()), dnn::pfCand_gamma_dphi - PFg_index_offset);
1538       get(dnn::pfCand_gamma_pvAssociationQuality + fill_index_offset_PFg) = sp.scale<int>(
1539           candFunc::getPvAssocationQuality(gamma_cand), dnn::pfCand_gamma_pvAssociationQuality - PFg_index_offset);
1540       get(dnn::pfCand_gamma_fromPV + fill_index_offset_PFg) =
1541           sp.scale<int>(candFunc::getFromPV(gamma_cand), dnn::pfCand_gamma_fromPV - PFg_index_offset);
1542       get(dnn::pfCand_gamma_puppiWeight + fill_index_offset_PFg) =
1543           is_inner ? sp.scale(candFunc::getPuppiWeight(gamma_cand, 0.9084110f),
1544                               dnn::pfCand_gamma_puppiWeight - PFg_index_offset)
1545                    : sp.scale(candFunc::getPuppiWeight(gamma_cand, 0.4211567f),
1546                               dnn::pfCand_gamma_puppiWeight - PFg_index_offset);
1547       get(dnn::pfCand_gamma_puppiWeightNoLep + fill_index_offset_PFg) =
1548           is_inner ? sp.scale(candFunc::getPuppiWeightNoLep(gamma_cand, 0.8857716f),
1549                               dnn::pfCand_gamma_puppiWeightNoLep - PFg_index_offset)
1550                    : sp.scale(candFunc::getPuppiWeightNoLep(gamma_cand, 0.3822604f),
1551                               dnn::pfCand_gamma_puppiWeightNoLep - PFg_index_offset);
1552       get(dnn::pfCand_gamma_lostInnerHits + fill_index_offset_PFg) =
1553           sp.scale<int>(candFunc::getLostInnerHits(gamma_cand, 0), dnn::pfCand_gamma_lostInnerHits - PFg_index_offset);
1554       get(dnn::pfCand_gamma_numberOfPixelHits + fill_index_offset_PFg) = sp.scale(
1555           candFunc::getNumberOfPixelHits(gamma_cand, 0), dnn::pfCand_gamma_numberOfPixelHits - PFg_index_offset);
1556       get(dnn::pfCand_gamma_vertex_dx + fill_index_offset_PFg) =
1557           sp.scale(gamma_cand.vertex().x() - pv.position().x(), dnn::pfCand_gamma_vertex_dx - PFg_index_offset);
1558       get(dnn::pfCand_gamma_vertex_dy + fill_index_offset_PFg) =
1559           sp.scale(gamma_cand.vertex().y() - pv.position().y(), dnn::pfCand_gamma_vertex_dy - PFg_index_offset);
1560       get(dnn::pfCand_gamma_vertex_dz + fill_index_offset_PFg) =
1561           sp.scale(gamma_cand.vertex().z() - pv.position().z(), dnn::pfCand_gamma_vertex_dz - PFg_index_offset);
1562       get(dnn::pfCand_gamma_vertex_dx_tauFL + fill_index_offset_PFg) =
1563           sp.scale(gamma_cand.vertex().x() - pv.position().x() - tau_funcs.getFlightLength(tau, tau_index).x(),
1564                    dnn::pfCand_gamma_vertex_dx_tauFL - PFg_index_offset);
1565       get(dnn::pfCand_gamma_vertex_dy_tauFL + fill_index_offset_PFg) =
1566           sp.scale(gamma_cand.vertex().y() - pv.position().y() - tau_funcs.getFlightLength(tau, tau_index).y(),
1567                    dnn::pfCand_gamma_vertex_dy_tauFL - PFg_index_offset);
1568       get(dnn::pfCand_gamma_vertex_dz_tauFL + fill_index_offset_PFg) =
1569           sp.scale(gamma_cand.vertex().z() - pv.position().z() - tau_funcs.getFlightLength(tau, tau_index).z(),
1570                    dnn::pfCand_gamma_vertex_dz_tauFL - PFg_index_offset);
1571       const bool hasTrackDetails = candFunc::getHasTrackDetails(gamma_cand);
1572       if (hasTrackDetails) {
1573         get(dnn::pfCand_gamma_hasTrackDetails + fill_index_offset_PFg) =
1574             sp.scale(hasTrackDetails, dnn::pfCand_gamma_hasTrackDetails - PFg_index_offset);
1575         get(dnn::pfCand_gamma_dxy + fill_index_offset_PFg) =
1576             sp.scale(candFunc::getTauDxy(gamma_cand), dnn::pfCand_gamma_dxy - PFg_index_offset);
1577         get(dnn::pfCand_gamma_dxy_sig + fill_index_offset_PFg) =
1578             sp.scale(std::abs(candFunc::getTauDxy(gamma_cand)) / gamma_cand.dxyError(),
1579                      dnn::pfCand_gamma_dxy_sig - PFg_index_offset);
1580         get(dnn::pfCand_gamma_dz + fill_index_offset_PFg) =
1581             sp.scale(candFunc::getTauDz(gamma_cand), dnn::pfCand_gamma_dz - PFg_index_offset);
1582         get(dnn::pfCand_gamma_dz_sig + fill_index_offset_PFg) =
1583             sp.scale(std::abs(candFunc::getTauDz(gamma_cand)) / gamma_cand.dzError(),
1584                      dnn::pfCand_gamma_dz_sig - PFg_index_offset);
1585         get(dnn::pfCand_gamma_track_chi2_ndof + fill_index_offset_PFg) =
1586             candFunc::getPseudoTrack(gamma_cand).ndof() > 0
1587                 ? sp.scale(candFunc::getPseudoTrack(gamma_cand).chi2() / candFunc::getPseudoTrack(gamma_cand).ndof(),
1588                            dnn::pfCand_gamma_track_chi2_ndof - PFg_index_offset)
1589                 : 0;
1590         get(dnn::pfCand_gamma_track_ndof + fill_index_offset_PFg) =
1591             candFunc::getPseudoTrack(gamma_cand).ndof() > 0
1592                 ? sp.scale(candFunc::getPseudoTrack(gamma_cand).ndof(), dnn::pfCand_gamma_track_ndof - PFg_index_offset)
1593                 : 0;
1594       }
1595     }
1596     if (valid_index_ele) {
1597       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_e, is_inner));
1598       size_t index_ele = cell_map.at(CellObjectType::Electron);
1599       const auto& ele = electrons->at(index_ele);
1600 
1601       get(dnn::ele_valid + fill_index_offset_e) = sp.scale(valid_index_ele, dnn::ele_valid - e_index_offset);
1602       get(dnn::ele_rel_pt + fill_index_offset_e) =
1603           sp.scale(ele.polarP4().pt() / tau.polarP4().pt(), dnn::ele_rel_pt - e_index_offset);
1604       get(dnn::ele_deta + fill_index_offset_e) =
1605           sp.scale(ele.polarP4().eta() - tau.polarP4().eta(), dnn::ele_deta - e_index_offset);
1606       get(dnn::ele_dphi + fill_index_offset_e) =
1607           sp.scale(dPhi(tau.polarP4(), ele.polarP4()), dnn::ele_dphi - e_index_offset);
1608 
1609       float cc_ele_energy, cc_gamma_energy;
1610       int cc_n_gamma;
1611       const bool cc_valid = calculateElectronClusterVarsV2(ele, cc_ele_energy, cc_gamma_energy, cc_n_gamma);
1612       if (cc_valid) {
1613         get(dnn::ele_cc_valid + fill_index_offset_e) = sp.scale(cc_valid, dnn::ele_cc_valid - e_index_offset);
1614         get(dnn::ele_cc_ele_rel_energy + fill_index_offset_e) =
1615             sp.scale(cc_ele_energy / ele.polarP4().pt(), dnn::ele_cc_ele_rel_energy - e_index_offset);
1616         get(dnn::ele_cc_gamma_rel_energy + fill_index_offset_e) =
1617             sp.scale(cc_gamma_energy / cc_ele_energy, dnn::ele_cc_gamma_rel_energy - e_index_offset);
1618         get(dnn::ele_cc_n_gamma + fill_index_offset_e) = sp.scale(cc_n_gamma, dnn::ele_cc_n_gamma - e_index_offset);
1619       }
1620       get(dnn::ele_rel_trackMomentumAtVtx + fill_index_offset_e) =
1621           sp.scale(ele.trackMomentumAtVtx().R() / ele.polarP4().pt(), dnn::ele_rel_trackMomentumAtVtx - e_index_offset);
1622       get(dnn::ele_rel_trackMomentumAtCalo + fill_index_offset_e) = sp.scale(
1623           ele.trackMomentumAtCalo().R() / ele.polarP4().pt(), dnn::ele_rel_trackMomentumAtCalo - e_index_offset);
1624       get(dnn::ele_rel_trackMomentumOut + fill_index_offset_e) =
1625           sp.scale(ele.trackMomentumOut().R() / ele.polarP4().pt(), dnn::ele_rel_trackMomentumOut - e_index_offset);
1626       get(dnn::ele_rel_trackMomentumAtEleClus + fill_index_offset_e) = sp.scale(
1627           ele.trackMomentumAtEleClus().R() / ele.polarP4().pt(), dnn::ele_rel_trackMomentumAtEleClus - e_index_offset);
1628       get(dnn::ele_rel_trackMomentumAtVtxWithConstraint + fill_index_offset_e) =
1629           sp.scale(ele.trackMomentumAtVtxWithConstraint().R() / ele.polarP4().pt(),
1630                    dnn::ele_rel_trackMomentumAtVtxWithConstraint - e_index_offset);
1631       get(dnn::ele_rel_ecalEnergy + fill_index_offset_e) =
1632           sp.scale(ele.ecalEnergy() / ele.polarP4().pt(), dnn::ele_rel_ecalEnergy - e_index_offset);
1633       get(dnn::ele_ecalEnergy_sig + fill_index_offset_e) =
1634           sp.scale(ele.ecalEnergy() / ele.ecalEnergyError(), dnn::ele_ecalEnergy_sig - e_index_offset);
1635       get(dnn::ele_eSuperClusterOverP + fill_index_offset_e) =
1636           sp.scale(ele.eSuperClusterOverP(), dnn::ele_eSuperClusterOverP - e_index_offset);
1637       get(dnn::ele_eSeedClusterOverP + fill_index_offset_e) =
1638           sp.scale(ele.eSeedClusterOverP(), dnn::ele_eSeedClusterOverP - e_index_offset);
1639       get(dnn::ele_eSeedClusterOverPout + fill_index_offset_e) =
1640           sp.scale(ele.eSeedClusterOverPout(), dnn::ele_eSeedClusterOverPout - e_index_offset);
1641       get(dnn::ele_eEleClusterOverPout + fill_index_offset_e) =
1642           sp.scale(ele.eEleClusterOverPout(), dnn::ele_eEleClusterOverPout - e_index_offset);
1643       get(dnn::ele_deltaEtaSuperClusterTrackAtVtx + fill_index_offset_e) =
1644           sp.scale(ele.deltaEtaSuperClusterTrackAtVtx(), dnn::ele_deltaEtaSuperClusterTrackAtVtx - e_index_offset);
1645       get(dnn::ele_deltaEtaSeedClusterTrackAtCalo + fill_index_offset_e) =
1646           sp.scale(ele.deltaEtaSeedClusterTrackAtCalo(), dnn::ele_deltaEtaSeedClusterTrackAtCalo - e_index_offset);
1647       get(dnn::ele_deltaEtaEleClusterTrackAtCalo + fill_index_offset_e) =
1648           sp.scale(ele.deltaEtaEleClusterTrackAtCalo(), dnn::ele_deltaEtaEleClusterTrackAtCalo - e_index_offset);
1649       get(dnn::ele_deltaPhiEleClusterTrackAtCalo + fill_index_offset_e) =
1650           sp.scale(ele.deltaPhiEleClusterTrackAtCalo(), dnn::ele_deltaPhiEleClusterTrackAtCalo - e_index_offset);
1651       get(dnn::ele_deltaPhiSuperClusterTrackAtVtx + fill_index_offset_e) =
1652           sp.scale(ele.deltaPhiSuperClusterTrackAtVtx(), dnn::ele_deltaPhiSuperClusterTrackAtVtx - e_index_offset);
1653       get(dnn::ele_deltaPhiSeedClusterTrackAtCalo + fill_index_offset_e) =
1654           sp.scale(ele.deltaPhiSeedClusterTrackAtCalo(), dnn::ele_deltaPhiSeedClusterTrackAtCalo - e_index_offset);
1655       const bool mva_valid =
1656           (ele.mvaInput().earlyBrem > -2) ||
1657           (year_ !=
1658            2026);  // Known issue that input can be invalid in Phase2 samples (early/lateBrem==-2, hadEnergy==0, sigmaEtaEta/deltaEta==3.40282e+38). Unknown if also in Run2/3, so don't change there
1659       if (mva_valid) {
1660         get(dnn::ele_mvaInput_earlyBrem + fill_index_offset_e) =
1661             sp.scale(ele.mvaInput().earlyBrem, dnn::ele_mvaInput_earlyBrem - e_index_offset);
1662         get(dnn::ele_mvaInput_lateBrem + fill_index_offset_e) =
1663             sp.scale(ele.mvaInput().lateBrem, dnn::ele_mvaInput_lateBrem - e_index_offset);
1664         get(dnn::ele_mvaInput_sigmaEtaEta + fill_index_offset_e) =
1665             sp.scale(ele.mvaInput().sigmaEtaEta, dnn::ele_mvaInput_sigmaEtaEta - e_index_offset);
1666         get(dnn::ele_mvaInput_hadEnergy + fill_index_offset_e) =
1667             sp.scale(ele.mvaInput().hadEnergy, dnn::ele_mvaInput_hadEnergy - e_index_offset);
1668         get(dnn::ele_mvaInput_deltaEta + fill_index_offset_e) =
1669             sp.scale(ele.mvaInput().deltaEta, dnn::ele_mvaInput_deltaEta - e_index_offset);
1670       }
1671       const auto& gsfTrack = ele.gsfTrack();
1672       if (gsfTrack.isNonnull()) {
1673         get(dnn::ele_gsfTrack_normalizedChi2 + fill_index_offset_e) =
1674             sp.scale(gsfTrack->normalizedChi2(), dnn::ele_gsfTrack_normalizedChi2 - e_index_offset);
1675         get(dnn::ele_gsfTrack_numberOfValidHits + fill_index_offset_e) =
1676             sp.scale(gsfTrack->numberOfValidHits(), dnn::ele_gsfTrack_numberOfValidHits - e_index_offset);
1677         get(dnn::ele_rel_gsfTrack_pt + fill_index_offset_e) =
1678             sp.scale(gsfTrack->pt() / ele.polarP4().pt(), dnn::ele_rel_gsfTrack_pt - e_index_offset);
1679         get(dnn::ele_gsfTrack_pt_sig + fill_index_offset_e) =
1680             sp.scale(gsfTrack->pt() / gsfTrack->ptError(), dnn::ele_gsfTrack_pt_sig - e_index_offset);
1681       }
1682       const auto& closestCtfTrack = ele.closestCtfTrackRef();
1683       const bool has_closestCtfTrack = closestCtfTrack.isNonnull();
1684       if (has_closestCtfTrack) {
1685         get(dnn::ele_has_closestCtfTrack + fill_index_offset_e) =
1686             sp.scale(has_closestCtfTrack, dnn::ele_has_closestCtfTrack - e_index_offset);
1687         get(dnn::ele_closestCtfTrack_normalizedChi2 + fill_index_offset_e) =
1688             sp.scale(closestCtfTrack->normalizedChi2(), dnn::ele_closestCtfTrack_normalizedChi2 - e_index_offset);
1689         get(dnn::ele_closestCtfTrack_numberOfValidHits + fill_index_offset_e) =
1690             sp.scale(closestCtfTrack->numberOfValidHits(), dnn::ele_closestCtfTrack_numberOfValidHits - e_index_offset);
1691       }
1692     }
1693   }
1694 
1695   template <typename CandidateCastType, typename TauCastType, typename MuonBlockType>
1696   void createMuonBlockInputs(unsigned idx,
1697                              const TauCastType& tau,
1698                              const size_t tau_index,
1699                              const edm::RefToBase<reco::BaseTau> tau_ref,
1700                              const reco::Vertex& pv,
1701                              double rho,
1702                              const std::vector<pat::Muon>* muons,
1703                              const edm::View<reco::Candidate>& pfCands,
1704                              const Cell& cell_map,
1705                              TauFunc tau_funcs,
1706                              bool is_inner,
1707                              MuonBlockType& muonBlockInputs) {
1708     namespace dnn = dnn_inputs_v2::MuonBlockInputs;
1709     namespace sc = deep_tau::Scaling;
1710     namespace candFunc = candFunc;
1711     using MuonHitMatchV2 = MuonHitMatchV2;
1712     sc::FeatureT ft_global = sc::FeatureT::GridGlobal;
1713     sc::FeatureT ft_PFmu = sc::FeatureT::PfCand_muon;
1714     sc::FeatureT ft_mu = sc::FeatureT::Muon;
1715 
1716     // needed to remap indices from scaling vectors to those from dnn_inputs_v2::MuonBlockInputs
1717     int PFmu_index_offset = scalingParamsMap_->at(std::make_pair(ft_global, false)).mean_.size();
1718     int mu_index_offset = PFmu_index_offset + scalingParamsMap_->at(std::make_pair(ft_PFmu, false)).mean_.size();
1719 
1720     const auto& get = [&](int var_index) -> float& {
1721       if constexpr (std::is_same_v<MuonBlockType, std::vector<float>::iterator>) {
1722         return *(muonBlockInputs + var_index);
1723       } else {
1724         return ((tensorflow::Tensor)muonBlockInputs).tensor<float, 4>()(idx, 0, 0, var_index);
1725       }
1726     };
1727 
1728     const bool valid_index_pf_muon = cell_map.count(CellObjectType::PfCand_muon);
1729     const bool valid_index_muon = cell_map.count(CellObjectType::Muon);
1730 
1731     if (!cell_map.empty()) {
1732       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_global, false));
1733       get(dnn::rho) = sp.scale(rho, dnn::rho);
1734       get(dnn::tau_pt) = sp.scale(tau.polarP4().pt(), dnn::tau_pt);
1735       get(dnn::tau_eta) = sp.scale(tau.polarP4().eta(), dnn::tau_eta);
1736       get(dnn::tau_inside_ecal_crack) = sp.scale(isInEcalCrack(tau.polarP4().eta()), dnn::tau_inside_ecal_crack);
1737     }
1738     if (valid_index_pf_muon) {
1739       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_PFmu, is_inner));
1740       size_t index_pf_muon = cell_map.at(CellObjectType::PfCand_muon);
1741       const auto& muon_cand = dynamic_cast<const CandidateCastType&>(pfCands.at(index_pf_muon));
1742 
1743       get(dnn::pfCand_muon_valid) = sp.scale(valid_index_pf_muon, dnn::pfCand_muon_valid - PFmu_index_offset);
1744       get(dnn::pfCand_muon_rel_pt) =
1745           sp.scale(muon_cand.polarP4().pt() / tau.polarP4().pt(), dnn::pfCand_muon_rel_pt - PFmu_index_offset);
1746       get(dnn::pfCand_muon_deta) =
1747           sp.scale(muon_cand.polarP4().eta() - tau.polarP4().eta(), dnn::pfCand_muon_deta - PFmu_index_offset);
1748       get(dnn::pfCand_muon_dphi) =
1749           sp.scale(dPhi(tau.polarP4(), muon_cand.polarP4()), dnn::pfCand_muon_dphi - PFmu_index_offset);
1750       get(dnn::pfCand_muon_pvAssociationQuality) = sp.scale<int>(
1751           candFunc::getPvAssocationQuality(muon_cand), dnn::pfCand_muon_pvAssociationQuality - PFmu_index_offset);
1752       get(dnn::pfCand_muon_fromPV) =
1753           sp.scale<int>(candFunc::getFromPV(muon_cand), dnn::pfCand_muon_fromPV - PFmu_index_offset);
1754       get(dnn::pfCand_muon_puppiWeight) = is_inner ? sp.scale(candFunc::getPuppiWeight(muon_cand, 0.9786588f),
1755                                                               dnn::pfCand_muon_puppiWeight - PFmu_index_offset)
1756                                                    : sp.scale(candFunc::getPuppiWeight(muon_cand, 0.8132477f),
1757                                                               dnn::pfCand_muon_puppiWeight - PFmu_index_offset);
1758       get(dnn::pfCand_muon_charge) = sp.scale(muon_cand.charge(), dnn::pfCand_muon_charge - PFmu_index_offset);
1759       get(dnn::pfCand_muon_lostInnerHits) =
1760           sp.scale<int>(candFunc::getLostInnerHits(muon_cand, 0), dnn::pfCand_muon_lostInnerHits - PFmu_index_offset);
1761       get(dnn::pfCand_muon_numberOfPixelHits) = sp.scale(candFunc::getNumberOfPixelHits(muon_cand, 0),
1762                                                          dnn::pfCand_muon_numberOfPixelHits - PFmu_index_offset);
1763       get(dnn::pfCand_muon_vertex_dx) =
1764           sp.scale(muon_cand.vertex().x() - pv.position().x(), dnn::pfCand_muon_vertex_dx - PFmu_index_offset);
1765       get(dnn::pfCand_muon_vertex_dy) =
1766           sp.scale(muon_cand.vertex().y() - pv.position().y(), dnn::pfCand_muon_vertex_dy - PFmu_index_offset);
1767       get(dnn::pfCand_muon_vertex_dz) =
1768           sp.scale(muon_cand.vertex().z() - pv.position().z(), dnn::pfCand_muon_vertex_dz - PFmu_index_offset);
1769       get(dnn::pfCand_muon_vertex_dx_tauFL) =
1770           sp.scale(muon_cand.vertex().x() - pv.position().x() - tau_funcs.getFlightLength(tau, tau_index).x(),
1771                    dnn::pfCand_muon_vertex_dx_tauFL - PFmu_index_offset);
1772       get(dnn::pfCand_muon_vertex_dy_tauFL) =
1773           sp.scale(muon_cand.vertex().y() - pv.position().y() - tau_funcs.getFlightLength(tau, tau_index).y(),
1774                    dnn::pfCand_muon_vertex_dy_tauFL - PFmu_index_offset);
1775       get(dnn::pfCand_muon_vertex_dz_tauFL) =
1776           sp.scale(muon_cand.vertex().z() - pv.position().z() - tau_funcs.getFlightLength(tau, tau_index).z(),
1777                    dnn::pfCand_muon_vertex_dz_tauFL - PFmu_index_offset);
1778 
1779       const bool hasTrackDetails = candFunc::getHasTrackDetails(muon_cand);
1780       if (hasTrackDetails) {
1781         get(dnn::pfCand_muon_hasTrackDetails) =
1782             sp.scale(hasTrackDetails, dnn::pfCand_muon_hasTrackDetails - PFmu_index_offset);
1783         get(dnn::pfCand_muon_dxy) = sp.scale(candFunc::getTauDxy(muon_cand), dnn::pfCand_muon_dxy - PFmu_index_offset);
1784         get(dnn::pfCand_muon_dxy_sig) = sp.scale(std::abs(candFunc::getTauDxy(muon_cand)) / muon_cand.dxyError(),
1785                                                  dnn::pfCand_muon_dxy_sig - PFmu_index_offset);
1786         get(dnn::pfCand_muon_dz) = sp.scale(candFunc::getTauDz(muon_cand), dnn::pfCand_muon_dz - PFmu_index_offset);
1787         get(dnn::pfCand_muon_dz_sig) = sp.scale(std::abs(candFunc::getTauDz(muon_cand)) / muon_cand.dzError(),
1788                                                 dnn::pfCand_muon_dz_sig - PFmu_index_offset);
1789         get(dnn::pfCand_muon_track_chi2_ndof) =
1790             candFunc::getPseudoTrack(muon_cand).ndof() > 0
1791                 ? sp.scale(candFunc::getPseudoTrack(muon_cand).chi2() / candFunc::getPseudoTrack(muon_cand).ndof(),
1792                            dnn::pfCand_muon_track_chi2_ndof - PFmu_index_offset)
1793                 : 0;
1794         get(dnn::pfCand_muon_track_ndof) =
1795             candFunc::getPseudoTrack(muon_cand).ndof() > 0
1796                 ? sp.scale(candFunc::getPseudoTrack(muon_cand).ndof(), dnn::pfCand_muon_track_ndof - PFmu_index_offset)
1797                 : 0;
1798       }
1799     }
1800     if (valid_index_muon) {
1801       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_mu, is_inner));
1802       size_t index_muon = cell_map.at(CellObjectType::Muon);
1803       const auto& muon = muons->at(index_muon);
1804 
1805       get(dnn::muon_valid) = sp.scale(valid_index_muon, dnn::muon_valid - mu_index_offset);
1806       get(dnn::muon_rel_pt) = sp.scale(muon.polarP4().pt() / tau.polarP4().pt(), dnn::muon_rel_pt - mu_index_offset);
1807       get(dnn::muon_deta) = sp.scale(muon.polarP4().eta() - tau.polarP4().eta(), dnn::muon_deta - mu_index_offset);
1808       get(dnn::muon_dphi) = sp.scale(dPhi(tau.polarP4(), muon.polarP4()), dnn::muon_dphi - mu_index_offset);
1809       get(dnn::muon_dxy) = sp.scale(muon.dB(pat::Muon::PV2D), dnn::muon_dxy - mu_index_offset);
1810       get(dnn::muon_dxy_sig) =
1811           sp.scale(std::abs(muon.dB(pat::Muon::PV2D)) / muon.edB(pat::Muon::PV2D), dnn::muon_dxy_sig - mu_index_offset);
1812 
1813       const bool normalizedChi2_valid = muon.globalTrack().isNonnull() && muon.normChi2() >= 0;
1814       if (normalizedChi2_valid) {
1815         get(dnn::muon_normalizedChi2_valid) =
1816             sp.scale(normalizedChi2_valid, dnn::muon_normalizedChi2_valid - mu_index_offset);
1817         get(dnn::muon_normalizedChi2) = sp.scale(muon.normChi2(), dnn::muon_normalizedChi2 - mu_index_offset);
1818         if (muon.innerTrack().isNonnull())
1819           get(dnn::muon_numberOfValidHits) =
1820               sp.scale(muon.numberOfValidHits(), dnn::muon_numberOfValidHits - mu_index_offset);
1821       }
1822       get(dnn::muon_segmentCompatibility) =
1823           sp.scale(muon.segmentCompatibility(), dnn::muon_segmentCompatibility - mu_index_offset);
1824       get(dnn::muon_caloCompatibility) =
1825           sp.scale(muon.caloCompatibility(), dnn::muon_caloCompatibility - mu_index_offset);
1826 
1827       const bool pfEcalEnergy_valid = muon.pfEcalEnergy() >= 0;
1828       if (pfEcalEnergy_valid) {
1829         get(dnn::muon_pfEcalEnergy_valid) =
1830             sp.scale(pfEcalEnergy_valid, dnn::muon_pfEcalEnergy_valid - mu_index_offset);
1831         get(dnn::muon_rel_pfEcalEnergy) =
1832             sp.scale(muon.pfEcalEnergy() / muon.polarP4().pt(), dnn::muon_rel_pfEcalEnergy - mu_index_offset);
1833       }
1834 
1835       MuonHitMatchV2 hit_match(muon);
1836       static const std::map<int, std::pair<int, int>> muonMatchHitVars = {
1837           {MuonSubdetId::DT, {dnn::muon_n_matches_DT_1, dnn::muon_n_hits_DT_1}},
1838           {MuonSubdetId::CSC, {dnn::muon_n_matches_CSC_1, dnn::muon_n_hits_CSC_1}},
1839           {MuonSubdetId::RPC, {dnn::muon_n_matches_RPC_1, dnn::muon_n_hits_RPC_1}}};
1840 
1841       for (int subdet : hit_match.MuonHitMatchV2::consideredSubdets()) {
1842         const auto& matchHitVar = muonMatchHitVars.at(subdet);
1843         for (int station = MuonHitMatchV2::first_station_id; station <= MuonHitMatchV2::last_station_id; ++station) {
1844           const unsigned n_matches = hit_match.nMatches(subdet, station);
1845           const unsigned n_hits = hit_match.nHits(subdet, station);
1846           get(matchHitVar.first + station - 1) = sp.scale(n_matches, matchHitVar.first + station - 1 - mu_index_offset);
1847           get(matchHitVar.second + station - 1) = sp.scale(n_hits, matchHitVar.second + station - 1 - mu_index_offset);
1848         }
1849       }
1850     }
1851   }
1852 
1853   template <typename CandidateCastType, typename TauCastType, typename HadronBlockType>
1854   void createHadronsBlockInputs(unsigned idx,
1855                                 const TauCastType& tau,
1856                                 const size_t tau_index,
1857                                 const edm::RefToBase<reco::BaseTau> tau_ref,
1858                                 const reco::Vertex& pv,
1859                                 double rho,
1860                                 const edm::View<reco::Candidate>& pfCands,
1861                                 const Cell& cell_map,
1862                                 TauFunc tau_funcs,
1863                                 bool is_inner,
1864                                 HadronBlockType& hadronBlockInputs) {
1865     namespace dnn = dnn_inputs_v2::HadronBlockInputs;
1866     namespace sc = deep_tau::Scaling;
1867     namespace candFunc = candFunc;
1868     sc::FeatureT ft_global = sc::FeatureT::GridGlobal;
1869     sc::FeatureT ft_PFchH = sc::FeatureT::PfCand_chHad;
1870     sc::FeatureT ft_PFnH = sc::FeatureT::PfCand_nHad;
1871 
1872     // needed to remap indices from scaling vectors to those from dnn_inputs_v2::HadronBlockInputs
1873     int PFchH_index_offset = scalingParamsMap_->at(std::make_pair(ft_global, false)).mean_.size();
1874     int PFnH_index_offset = PFchH_index_offset + scalingParamsMap_->at(std::make_pair(ft_PFchH, false)).mean_.size();
1875 
1876     const auto& get = [&](int var_index) -> float& {
1877       if constexpr (std::is_same_v<HadronBlockType, std::vector<float>::iterator>) {
1878         return *(hadronBlockInputs + var_index);
1879       } else {
1880         return ((tensorflow::Tensor)hadronBlockInputs).tensor<float, 4>()(idx, 0, 0, var_index);
1881       }
1882     };
1883 
1884     const bool valid_chH = cell_map.count(CellObjectType::PfCand_chargedHadron);
1885     const bool valid_nH = cell_map.count(CellObjectType::PfCand_neutralHadron);
1886 
1887     if (!cell_map.empty()) {
1888       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_global, false));
1889       get(dnn::rho) = sp.scale(rho, dnn::rho);
1890       get(dnn::tau_pt) = sp.scale(tau.polarP4().pt(), dnn::tau_pt);
1891       get(dnn::tau_eta) = sp.scale(tau.polarP4().eta(), dnn::tau_eta);
1892       get(dnn::tau_inside_ecal_crack) = sp.scale(isInEcalCrack(tau.polarP4().eta()), dnn::tau_inside_ecal_crack);
1893     }
1894     if (valid_chH) {
1895       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_PFchH, is_inner));
1896       size_t index_chH = cell_map.at(CellObjectType::PfCand_chargedHadron);
1897       const auto& chH_cand = dynamic_cast<const CandidateCastType&>(pfCands.at(index_chH));
1898 
1899       get(dnn::pfCand_chHad_valid) = sp.scale(valid_chH, dnn::pfCand_chHad_valid - PFchH_index_offset);
1900       get(dnn::pfCand_chHad_rel_pt) =
1901           sp.scale(chH_cand.polarP4().pt() / tau.polarP4().pt(), dnn::pfCand_chHad_rel_pt - PFchH_index_offset);
1902       get(dnn::pfCand_chHad_deta) =
1903           sp.scale(chH_cand.polarP4().eta() - tau.polarP4().eta(), dnn::pfCand_chHad_deta - PFchH_index_offset);
1904       get(dnn::pfCand_chHad_dphi) =
1905           sp.scale(dPhi(tau.polarP4(), chH_cand.polarP4()), dnn::pfCand_chHad_dphi - PFchH_index_offset);
1906       get(dnn::pfCand_chHad_leadChargedHadrCand) =
1907           sp.scale(&chH_cand == dynamic_cast<const CandidateCastType*>(tau.leadChargedHadrCand().get()),
1908                    dnn::pfCand_chHad_leadChargedHadrCand - PFchH_index_offset);
1909       get(dnn::pfCand_chHad_pvAssociationQuality) = sp.scale<int>(
1910           candFunc::getPvAssocationQuality(chH_cand), dnn::pfCand_chHad_pvAssociationQuality - PFchH_index_offset);
1911       get(dnn::pfCand_chHad_fromPV) =
1912           sp.scale<int>(candFunc::getFromPV(chH_cand), dnn::pfCand_chHad_fromPV - PFchH_index_offset);
1913       const float default_chH_pw_inner = 0.7614090f;
1914       const float default_chH_pw_outer = 0.1974930f;
1915       get(dnn::pfCand_chHad_puppiWeight) = is_inner ? sp.scale(candFunc::getPuppiWeight(chH_cand, default_chH_pw_inner),
1916                                                                dnn::pfCand_chHad_puppiWeight - PFchH_index_offset)
1917                                                     : sp.scale(candFunc::getPuppiWeight(chH_cand, default_chH_pw_outer),
1918                                                                dnn::pfCand_chHad_puppiWeight - PFchH_index_offset);
1919       get(dnn::pfCand_chHad_puppiWeightNoLep) =
1920           is_inner ? sp.scale(candFunc::getPuppiWeightNoLep(chH_cand, default_chH_pw_inner),
1921                               dnn::pfCand_chHad_puppiWeightNoLep - PFchH_index_offset)
1922                    : sp.scale(candFunc::getPuppiWeightNoLep(chH_cand, default_chH_pw_outer),
1923                               dnn::pfCand_chHad_puppiWeightNoLep - PFchH_index_offset);
1924       get(dnn::pfCand_chHad_charge) = sp.scale(chH_cand.charge(), dnn::pfCand_chHad_charge - PFchH_index_offset);
1925       get(dnn::pfCand_chHad_lostInnerHits) =
1926           sp.scale<int>(candFunc::getLostInnerHits(chH_cand, 0), dnn::pfCand_chHad_lostInnerHits - PFchH_index_offset);
1927       get(dnn::pfCand_chHad_numberOfPixelHits) = sp.scale(candFunc::getNumberOfPixelHits(chH_cand, 0),
1928                                                           dnn::pfCand_chHad_numberOfPixelHits - PFchH_index_offset);
1929       get(dnn::pfCand_chHad_vertex_dx) =
1930           sp.scale(chH_cand.vertex().x() - pv.position().x(), dnn::pfCand_chHad_vertex_dx - PFchH_index_offset);
1931       get(dnn::pfCand_chHad_vertex_dy) =
1932           sp.scale(chH_cand.vertex().y() - pv.position().y(), dnn::pfCand_chHad_vertex_dy - PFchH_index_offset);
1933       get(dnn::pfCand_chHad_vertex_dz) =
1934           sp.scale(chH_cand.vertex().z() - pv.position().z(), dnn::pfCand_chHad_vertex_dz - PFchH_index_offset);
1935       get(dnn::pfCand_chHad_vertex_dx_tauFL) =
1936           sp.scale(chH_cand.vertex().x() - pv.position().x() - tau_funcs.getFlightLength(tau, tau_index).x(),
1937                    dnn::pfCand_chHad_vertex_dx_tauFL - PFchH_index_offset);
1938       get(dnn::pfCand_chHad_vertex_dy_tauFL) =
1939           sp.scale(chH_cand.vertex().y() - pv.position().y() - tau_funcs.getFlightLength(tau, tau_index).y(),
1940                    dnn::pfCand_chHad_vertex_dy_tauFL - PFchH_index_offset);
1941       get(dnn::pfCand_chHad_vertex_dz_tauFL) =
1942           sp.scale(chH_cand.vertex().z() - pv.position().z() - tau_funcs.getFlightLength(tau, tau_index).z(),
1943                    dnn::pfCand_chHad_vertex_dz_tauFL - PFchH_index_offset);
1944 
1945       const bool hasTrackDetails = candFunc::getHasTrackDetails(chH_cand);
1946       if (hasTrackDetails) {
1947         get(dnn::pfCand_chHad_hasTrackDetails) =
1948             sp.scale(hasTrackDetails, dnn::pfCand_chHad_hasTrackDetails - PFchH_index_offset);
1949         get(dnn::pfCand_chHad_dxy) =
1950             sp.scale(candFunc::getTauDxy(chH_cand), dnn::pfCand_chHad_dxy - PFchH_index_offset);
1951         get(dnn::pfCand_chHad_dxy_sig) = sp.scale(std::abs(candFunc::getTauDxy(chH_cand)) / chH_cand.dxyError(),
1952                                                   dnn::pfCand_chHad_dxy_sig - PFchH_index_offset);
1953         get(dnn::pfCand_chHad_dz) = sp.scale(candFunc::getTauDz(chH_cand), dnn::pfCand_chHad_dz - PFchH_index_offset);
1954         get(dnn::pfCand_chHad_dz_sig) = sp.scale(std::abs(candFunc::getTauDz(chH_cand)) / chH_cand.dzError(),
1955                                                  dnn::pfCand_chHad_dz_sig - PFchH_index_offset);
1956         get(dnn::pfCand_chHad_track_chi2_ndof) =
1957             candFunc::getPseudoTrack(chH_cand).ndof() > 0
1958                 ? sp.scale(candFunc::getPseudoTrack(chH_cand).chi2() / candFunc::getPseudoTrack(chH_cand).ndof(),
1959                            dnn::pfCand_chHad_track_chi2_ndof - PFchH_index_offset)
1960                 : 0;
1961         get(dnn::pfCand_chHad_track_ndof) =
1962             candFunc::getPseudoTrack(chH_cand).ndof() > 0
1963                 ? sp.scale(candFunc::getPseudoTrack(chH_cand).ndof(), dnn::pfCand_chHad_track_ndof - PFchH_index_offset)
1964                 : 0;
1965       }
1966       float hcal_fraction = candFunc::getHCalFraction(chH_cand, disable_hcalFraction_workaround_);
1967       get(dnn::pfCand_chHad_hcalFraction) =
1968           sp.scale(hcal_fraction, dnn::pfCand_chHad_hcalFraction - PFchH_index_offset);
1969       get(dnn::pfCand_chHad_rawCaloFraction) =
1970           sp.scale(candFunc::getRawCaloFraction(chH_cand), dnn::pfCand_chHad_rawCaloFraction - PFchH_index_offset);
1971     }
1972     if (valid_nH) {
1973       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_PFnH, is_inner));
1974       size_t index_nH = cell_map.at(CellObjectType::PfCand_neutralHadron);
1975       const auto& nH_cand = dynamic_cast<const CandidateCastType&>(pfCands.at(index_nH));
1976 
1977       get(dnn::pfCand_nHad_valid) = sp.scale(valid_nH, dnn::pfCand_nHad_valid - PFnH_index_offset);
1978       get(dnn::pfCand_nHad_rel_pt) =
1979           sp.scale(nH_cand.polarP4().pt() / tau.polarP4().pt(), dnn::pfCand_nHad_rel_pt - PFnH_index_offset);
1980       get(dnn::pfCand_nHad_deta) =
1981           sp.scale(nH_cand.polarP4().eta() - tau.polarP4().eta(), dnn::pfCand_nHad_deta - PFnH_index_offset);
1982       get(dnn::pfCand_nHad_dphi) =
1983           sp.scale(dPhi(tau.polarP4(), nH_cand.polarP4()), dnn::pfCand_nHad_dphi - PFnH_index_offset);
1984       get(dnn::pfCand_nHad_puppiWeight) = is_inner ? sp.scale(candFunc::getPuppiWeight(nH_cand, 0.9798355f),
1985                                                               dnn::pfCand_nHad_puppiWeight - PFnH_index_offset)
1986                                                    : sp.scale(candFunc::getPuppiWeight(nH_cand, 0.7813260f),
1987                                                               dnn::pfCand_nHad_puppiWeight - PFnH_index_offset);
1988       get(dnn::pfCand_nHad_puppiWeightNoLep) = is_inner
1989                                                    ? sp.scale(candFunc::getPuppiWeightNoLep(nH_cand, 0.9046796f),
1990                                                               dnn::pfCand_nHad_puppiWeightNoLep - PFnH_index_offset)
1991                                                    : sp.scale(candFunc::getPuppiWeightNoLep(nH_cand, 0.6554860f),
1992                                                               dnn::pfCand_nHad_puppiWeightNoLep - PFnH_index_offset);
1993       float hcal_fraction = candFunc::getHCalFraction(nH_cand, disable_hcalFraction_workaround_);
1994       get(dnn::pfCand_nHad_hcalFraction) = sp.scale(hcal_fraction, dnn::pfCand_nHad_hcalFraction - PFnH_index_offset);
1995     }
1996   }
1997 
1998   static void calculateElectronClusterVars(const pat::Electron* ele, float& elecEe, float& elecEgamma) {
1999     if (ele) {
2000       elecEe = elecEgamma = 0;
2001       auto superCluster = ele->superCluster();
2002       if (superCluster.isNonnull() && superCluster.isAvailable() && superCluster->clusters().isNonnull() &&
2003           superCluster->clusters().isAvailable()) {
2004         for (auto iter = superCluster->clustersBegin(); iter != superCluster->clustersEnd(); ++iter) {
2005           const double energy = (*iter)->energy();
2006           if (iter == superCluster->clustersBegin())
2007             elecEe += energy;
2008           else
2009             elecEgamma += energy;
2010         }
2011       }
2012     } else {
2013       elecEe = elecEgamma = default_value;
2014     }
2015   }
2016 
2017   template <typename CandidateCollection, typename TauCastType>
2018   static void processSignalPFComponents(const TauCastType& tau,
2019                                         const CandidateCollection& candidates,
2020                                         LorentzVectorXYZ& p4_inner,
2021                                         LorentzVectorXYZ& p4_outer,
2022                                         float& pt_inner,
2023                                         float& dEta_inner,
2024                                         float& dPhi_inner,
2025                                         float& m_inner,
2026                                         float& pt_outer,
2027                                         float& dEta_outer,
2028                                         float& dPhi_outer,
2029                                         float& m_outer,
2030                                         float& n_inner,
2031                                         float& n_outer) {
2032     p4_inner = LorentzVectorXYZ(0, 0, 0, 0);
2033     p4_outer = LorentzVectorXYZ(0, 0, 0, 0);
2034     n_inner = 0;
2035     n_outer = 0;
2036 
2037     const double innerSigCone_radius = getInnerSignalConeRadius(tau.pt());
2038     for (const auto& cand : candidates) {
2039       const double dR = reco::deltaR(cand->p4(), tau.leadChargedHadrCand()->p4());
2040       const bool isInside_innerSigCone = dR < innerSigCone_radius;
2041       if (isInside_innerSigCone) {
2042         p4_inner += cand->p4();
2043         ++n_inner;
2044       } else {
2045         p4_outer += cand->p4();
2046         ++n_outer;
2047       }
2048     }
2049 
2050     pt_inner = n_inner != 0 ? p4_inner.Pt() : default_value;
2051     dEta_inner = n_inner != 0 ? dEta(p4_inner, tau.p4()) : default_value;
2052     dPhi_inner = n_inner != 0 ? dPhi(p4_inner, tau.p4()) : default_value;
2053     m_inner = n_inner != 0 ? p4_inner.mass() : default_value;
2054 
2055     pt_outer = n_outer != 0 ? p4_outer.Pt() : default_value;
2056     dEta_outer = n_outer != 0 ? dEta(p4_outer, tau.p4()) : default_value;
2057     dPhi_outer = n_outer != 0 ? dPhi(p4_outer, tau.p4()) : default_value;
2058     m_outer = n_outer != 0 ? p4_outer.mass() : default_value;
2059   }
2060 
2061   template <typename CandidateCollection, typename TauCastType>
2062   static void processIsolationPFComponents(const TauCastType& tau,
2063                                            const CandidateCollection& candidates,
2064                                            LorentzVectorXYZ& p4,
2065                                            float& pt,
2066                                            float& d_eta,
2067                                            float& d_phi,
2068                                            float& m,
2069                                            float& n) {
2070     p4 = LorentzVectorXYZ(0, 0, 0, 0);
2071     n = 0;
2072 
2073     for (const auto& cand : candidates) {
2074       p4 += cand->p4();
2075       ++n;
2076     }
2077 
2078     pt = n != 0 ? p4.Pt() : default_value;
2079     d_eta = n != 0 ? dEta(p4, tau.p4()) : default_value;
2080     d_phi = n != 0 ? dPhi(p4, tau.p4()) : default_value;
2081     m = n != 0 ? p4.mass() : default_value;
2082   }
2083 
2084   static double getInnerSignalConeRadius(double pt) {
2085     static constexpr double min_pt = 30., min_radius = 0.05, cone_opening_coef = 3.;
2086     // This is equivalent of the original formula (std::max(std::min(0.1, 3.0/pt), 0.05)
2087     return std::max(cone_opening_coef / std::max(pt, min_pt), min_radius);
2088   }
2089 
2090   // Copied from https://github.com/cms-sw/cmssw/blob/CMSSW_9_4_X/RecoTauTag/RecoTau/plugins/PATTauDiscriminationByMVAIsolationRun2.cc#L218
2091   template <typename TauCastType>
2092   static bool calculateGottfriedJacksonAngleDifference(const TauCastType& tau,
2093                                                        const size_t tau_index,
2094                                                        double& gj_diff,
2095                                                        TauFunc tau_funcs) {
2096     if (tau_funcs.getHasSecondaryVertex(tau, tau_index)) {
2097       static constexpr double mTau = 1.77682;
2098       const double mAOne = tau.p4().M();
2099       const double pAOneMag = tau.p();
2100       const double argumentThetaGJmax = (std::pow(mTau, 2) - std::pow(mAOne, 2)) / (2 * mTau * pAOneMag);
2101       const double argumentThetaGJmeasured = tau.p4().Vect().Dot(tau_funcs.getFlightLength(tau, tau_index)) /
2102                                              (pAOneMag * tau_funcs.getFlightLength(tau, tau_index).R());
2103       if (std::abs(argumentThetaGJmax) <= 1. && std::abs(argumentThetaGJmeasured) <= 1.) {
2104         double thetaGJmax = std::asin(argumentThetaGJmax);
2105         double thetaGJmeasured = std::acos(argumentThetaGJmeasured);
2106         gj_diff = thetaGJmeasured - thetaGJmax;
2107         return true;
2108       }
2109     }
2110     return false;
2111   }
2112 
2113   template <typename TauCastType>
2114   static float calculateGottfriedJacksonAngleDifference(const TauCastType& tau,
2115                                                         const size_t tau_index,
2116                                                         TauFunc tau_funcs) {
2117     double gj_diff;
2118     if (calculateGottfriedJacksonAngleDifference(tau, tau_index, gj_diff, tau_funcs))
2119       return static_cast<float>(gj_diff);
2120     return default_value;
2121   }
2122 
2123   static bool isInEcalCrack(double eta) {
2124     const double abs_eta = std::abs(eta);
2125     return abs_eta > 1.46 && abs_eta < 1.558;
2126   }
2127 
2128   template <typename TauCastType>
2129   static const pat::Electron* findMatchedElectron(const TauCastType& tau,
2130                                                   const std::vector<pat::Electron>* electrons,
2131                                                   double deltaR) {
2132     const double dR2 = deltaR * deltaR;
2133     const pat::Electron* matched_ele = nullptr;
2134     for (const auto& ele : *electrons) {
2135       if (reco::deltaR2(tau.p4(), ele.p4()) < dR2 && (!matched_ele || matched_ele->pt() < ele.pt())) {
2136         matched_ele = &ele;
2137       }
2138     }
2139     return matched_ele;
2140   }
2141 
2142 protected:
2143   edm::EDGetTokenT<TauCollection> tausToken_;
2144   edm::EDGetTokenT<CandidateCollection> pfcandToken_;
2145   edm::EDGetTokenT<reco::VertexCollection> vtxToken_;
2146   std::map<std::string, WPList> workingPoints_;
2147   const bool is_online_;
2148   IDOutputCollection idoutputs_;
2149 
       // String associated with each basic discriminator. Note that FootprintCorrection maps to
       // "TauFootprintCorrection", i.e. the string is not always the enumerator name.
2150   const std::map<BasicDiscriminator, std::string> stringFromDiscriminator_{
2151       {BasicDiscriminator::ChargedIsoPtSum, "ChargedIsoPtSum"},
2152       {BasicDiscriminator::NeutralIsoPtSum, "NeutralIsoPtSum"},
2153       {BasicDiscriminator::NeutralIsoPtSumWeight, "NeutralIsoPtSumWeight"},
2154       {BasicDiscriminator::FootprintCorrection, "TauFootprintCorrection"},
2155       {BasicDiscriminator::PhotonPtSumOutsideSignalCone, "PhotonPtSumOutsideSignalCone"},
2156       {BasicDiscriminator::PUcorrPtSum, "PUcorrPtSum"}};
       // Two lists of required basic discriminators. They share the first four entries; the first
       // list ends with PUcorrPtSum, the dR03 list ends with FootprintCorrection instead.
2157   const std::vector<BasicDiscriminator> requiredBasicDiscriminators_{BasicDiscriminator::ChargedIsoPtSum,
2158                                                                      BasicDiscriminator::NeutralIsoPtSum,
2159                                                                      BasicDiscriminator::NeutralIsoPtSumWeight,
2160                                                                      BasicDiscriminator::PhotonPtSumOutsideSignalCone,
2161                                                                      BasicDiscriminator::PUcorrPtSum};
2162   const std::vector<BasicDiscriminator> requiredBasicDiscriminatorsdR03_{
2163       BasicDiscriminator::ChargedIsoPtSum,
2164       BasicDiscriminator::NeutralIsoPtSum,
2165       BasicDiscriminator::NeutralIsoPtSumWeight,
2166       BasicDiscriminator::PhotonPtSumOutsideSignalCone,
2167       BasicDiscriminator::FootprintCorrection};
2168 
       // Tokens for the remaining inputs: electrons, muons, rho, the two basic tau discriminator
       // containers and the tau transverse impact parameter association.
2169   edm::EDGetTokenT<std::vector<pat::Electron>> electrons_token_;
2170   edm::EDGetTokenT<std::vector<pat::Muon>> muons_token_;
2171   edm::EDGetTokenT<double> rho_token_;
2172   edm::EDGetTokenT<reco::TauDiscriminatorContainer> basicTauDiscriminators_inputToken_;
2173   edm::EDGetTokenT<reco::TauDiscriminatorContainer> basicTauDiscriminatorsdR03_inputToken_;
2174   edm::EDGetTokenT<edm::AssociationVector<reco::PFTauRefProd, std::vector<reco::PFTauTransverseImpactParameterRef>>>
2175       pfTauTransverseImpactParameters_token_;
2176   std::string input_layer_, output_layer_;
2177   const unsigned year_;
2178   const unsigned version_;
2179   const unsigned sub_version_;
       // NOTE(review): debug_level has no trailing underscore, unlike the other data members.
2180   const int debug_level;
2181   const bool disable_dxy_pca_;
       // Forwarded to candFunc::getHCalFraction when the hadron block inputs are filled.
2182   const bool disable_hcalFraction_workaround_;
2183   const bool disable_CellIndex_workaround_;
       // Scaling parameters keyed by (feature type, bool). The block-input builders look entries up
       // with is_inner as the bool for per-object features and with false for the global features
       // and for the index offsets. Raw pointer: the owner is not visible in this part of the file.
2184   const std::map<std::pair<deep_tau::Scaling::FeatureT, bool>, deep_tau::Scaling::ScalingParams>* scalingParamsMap_;
       // NOTE(review): json_file_, is_first_block_ and file_counter_ presumably serve the
       // input dump enabled by save_inputs_ - confirm. json_file_ is a raw pointer; who creates
       // and deletes the stream is not visible here.
2185   const bool save_inputs_;
2186   std::ofstream* json_file_;
2187   bool is_first_block_;
2188   int file_counter_;
2189   std::vector<int> tauInputs_indices_;
2190 
2191   //boolean to check if discriminator indices are already mapped
2192   bool discrIndicesMapped_ = false;
       // One index per basic discriminator, for the default and for the dR03 discriminators.
       // NOTE(review): presumably positions inside the TauDiscriminatorContainer - TODO confirm.
2193   std::map<BasicDiscriminator, size_t> basicDiscrIndexMap_;
2194   std::map<BasicDiscriminator, size_t> basicDiscrdR03IndexMap_;
2195 };
2196 
2197 #endif