Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-10-16 05:06:36

0001 #ifndef DeepTauIdBase_H
0002 #define DeepTauIdBase_H
0003 
0004 /*
0005  * \class DeepTauIdBase
0006  *
0007  * Base class for DeepTauId producers: DeepTauId and DeepTauIdSONIC
0008  *
0009  */
0010 
0011 #include "DataFormats/PatCandidates/interface/Electron.h"
0012 #include "DataFormats/PatCandidates/interface/Muon.h"
0013 #include "DataFormats/PatCandidates/interface/Tau.h"
0014 #include "DataFormats/TauReco/interface/TauDiscriminatorContainer.h"
0015 #include "DataFormats/TauReco/interface/PFTauDiscriminator.h"
0016 #include "DataFormats/PatCandidates/interface/PATTauDiscriminator.h"
0017 #include "DataFormats/Common/interface/View.h"
0018 #include "DataFormats/Common/interface/RefToBase.h"
0019 #include "DataFormats/Provenance/interface/ProductProvenance.h"
0020 #include "DataFormats/Provenance/interface/ProcessHistoryID.h"
0021 #include "FWCore/Common/interface/Provenance.h"
0022 #include "FWCore/Framework/interface/Event.h"
0023 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0024 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0025 #include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
0026 #include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
0027 #include "RecoTauTag/RecoTau/interface/PFRecoTauClusterVariables.h"
0028 #include "RecoTauTag/RecoTau/interface/DeepTauScaling.h"
0029 #include "RecoTauTag/RecoTau/interface/TauWPThreshold.h"
0030 #include "DataFormats/TauReco/interface/PFTauTransverseImpactParameterAssociation.h"
0031 #include "CommonTools/Utils/interface/StringObjectFunction.h"
0032 #include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
0033 #include "tensorflow/core/util/memmapped_file_system.h"
0034 
0035 #include <Math/VectorUtil.h>
0036 #include <map>
0037 #include <fstream>
0038 #include "oneapi/tbb/concurrent_unordered_set.h"
0039 
namespace deep_tau {
  /// Identifiers of the "basic" tau isolation discriminators. They are used as keys of the
  /// index maps (BasicDiscriminator -> position in rawValues) held by TauFunc.
  enum BasicDiscriminator {
    ChargedIsoPtSum = 0,
    NeutralIsoPtSum = 1,
    NeutralIsoPtSumWeight = 2,
    FootprintCorrection = 3,
    PhotonPtSumOutsideSignalCone = 4,
    PUcorrPtSum = 5
  };

  /// Number of outputs per tau (presumably the DNN output scores -- TODO confirm against the producers).
  constexpr int NumberOfOutputs{4};
}  // namespace deep_tau
0052 
0053 namespace {
0054 
0055   namespace dnn_inputs_v2 {
0056     constexpr int number_of_inner_cell = 11;
0057     constexpr int number_of_outer_cell = 21;
0058     constexpr int number_of_conv_features = 64;
0059     namespace TauBlockInputs {
0060       enum vars {
0061         rho = 0,
0062         tau_pt,
0063         tau_eta,
0064         tau_phi,
0065         tau_mass,
0066         tau_E_over_pt,
0067         tau_charge,
0068         tau_n_charged_prongs,
0069         tau_n_neutral_prongs,
0070         chargedIsoPtSum,
0071         chargedIsoPtSumdR03_over_dR05,
0072         footprintCorrection,
0073         neutralIsoPtSum,
0074         neutralIsoPtSumWeight_over_neutralIsoPtSum,
0075         neutralIsoPtSumWeightdR03_over_neutralIsoPtSum,
0076         neutralIsoPtSumdR03_over_dR05,
0077         photonPtSumOutsideSignalCone,
0078         puCorrPtSum,
0079         tau_dxy_pca_x,
0080         tau_dxy_pca_y,
0081         tau_dxy_pca_z,
0082         tau_dxy_valid,
0083         tau_dxy,
0084         tau_dxy_sig,
0085         tau_ip3d_valid,
0086         tau_ip3d,
0087         tau_ip3d_sig,
0088         tau_dz,
0089         tau_dz_sig_valid,
0090         tau_dz_sig,
0091         tau_flightLength_x,
0092         tau_flightLength_y,
0093         tau_flightLength_z,
0094         tau_flightLength_sig,
0095         tau_pt_weighted_deta_strip,
0096         tau_pt_weighted_dphi_strip,
0097         tau_pt_weighted_dr_signal,
0098         tau_pt_weighted_dr_iso,
0099         tau_leadingTrackNormChi2,
0100         tau_e_ratio_valid,
0101         tau_e_ratio,
0102         tau_gj_angle_diff_valid,
0103         tau_gj_angle_diff,
0104         tau_n_photons,
0105         tau_emFraction,
0106         tau_inside_ecal_crack,
0107         leadChargedCand_etaAtEcalEntrance_minus_tau_eta,
0108         NumberOfInputs
0109       };
0110       inline std::vector<int> varsToDrop = {
0111           tau_phi, tau_dxy_pca_x, tau_dxy_pca_y, tau_dxy_pca_z};  // indices of vars to be dropped in the full var enum
0112     }  // namespace TauBlockInputs
0113 
0114     namespace EgammaBlockInputs {
0115       enum vars {
0116         rho = 0,
0117         tau_pt,
0118         tau_eta,
0119         tau_inside_ecal_crack,
0120         pfCand_ele_valid,
0121         pfCand_ele_rel_pt,
0122         pfCand_ele_deta,
0123         pfCand_ele_dphi,
0124         pfCand_ele_pvAssociationQuality,
0125         pfCand_ele_puppiWeight,
0126         pfCand_ele_charge,
0127         pfCand_ele_lostInnerHits,
0128         pfCand_ele_numberOfPixelHits,
0129         pfCand_ele_vertex_dx,
0130         pfCand_ele_vertex_dy,
0131         pfCand_ele_vertex_dz,
0132         pfCand_ele_vertex_dx_tauFL,
0133         pfCand_ele_vertex_dy_tauFL,
0134         pfCand_ele_vertex_dz_tauFL,
0135         pfCand_ele_hasTrackDetails,
0136         pfCand_ele_dxy,
0137         pfCand_ele_dxy_sig,
0138         pfCand_ele_dz,
0139         pfCand_ele_dz_sig,
0140         pfCand_ele_track_chi2_ndof,
0141         pfCand_ele_track_ndof,
0142         ele_valid,
0143         ele_rel_pt,
0144         ele_deta,
0145         ele_dphi,
0146         ele_cc_valid,
0147         ele_cc_ele_rel_energy,
0148         ele_cc_gamma_rel_energy,
0149         ele_cc_n_gamma,
0150         ele_rel_trackMomentumAtVtx,
0151         ele_rel_trackMomentumAtCalo,
0152         ele_rel_trackMomentumOut,
0153         ele_rel_trackMomentumAtEleClus,
0154         ele_rel_trackMomentumAtVtxWithConstraint,
0155         ele_rel_ecalEnergy,
0156         ele_ecalEnergy_sig,
0157         ele_eSuperClusterOverP,
0158         ele_eSeedClusterOverP,
0159         ele_eSeedClusterOverPout,
0160         ele_eEleClusterOverPout,
0161         ele_deltaEtaSuperClusterTrackAtVtx,
0162         ele_deltaEtaSeedClusterTrackAtCalo,
0163         ele_deltaEtaEleClusterTrackAtCalo,
0164         ele_deltaPhiEleClusterTrackAtCalo,
0165         ele_deltaPhiSuperClusterTrackAtVtx,
0166         ele_deltaPhiSeedClusterTrackAtCalo,
0167         ele_mvaInput_earlyBrem,
0168         ele_mvaInput_lateBrem,
0169         ele_mvaInput_sigmaEtaEta,
0170         ele_mvaInput_hadEnergy,
0171         ele_mvaInput_deltaEta,
0172         ele_gsfTrack_normalizedChi2,
0173         ele_gsfTrack_numberOfValidHits,
0174         ele_rel_gsfTrack_pt,
0175         ele_gsfTrack_pt_sig,
0176         ele_has_closestCtfTrack,
0177         ele_closestCtfTrack_normalizedChi2,
0178         ele_closestCtfTrack_numberOfValidHits,
0179         pfCand_gamma_valid,
0180         pfCand_gamma_rel_pt,
0181         pfCand_gamma_deta,
0182         pfCand_gamma_dphi,
0183         pfCand_gamma_pvAssociationQuality,
0184         pfCand_gamma_fromPV,
0185         pfCand_gamma_puppiWeight,
0186         pfCand_gamma_puppiWeightNoLep,
0187         pfCand_gamma_lostInnerHits,
0188         pfCand_gamma_numberOfPixelHits,
0189         pfCand_gamma_vertex_dx,
0190         pfCand_gamma_vertex_dy,
0191         pfCand_gamma_vertex_dz,
0192         pfCand_gamma_vertex_dx_tauFL,
0193         pfCand_gamma_vertex_dy_tauFL,
0194         pfCand_gamma_vertex_dz_tauFL,
0195         pfCand_gamma_hasTrackDetails,
0196         pfCand_gamma_dxy,
0197         pfCand_gamma_dxy_sig,
0198         pfCand_gamma_dz,
0199         pfCand_gamma_dz_sig,
0200         pfCand_gamma_track_chi2_ndof,
0201         pfCand_gamma_track_ndof,
0202         NumberOfInputs
0203       };
0204     }
0205 
0206     namespace MuonBlockInputs {
0207       enum vars {
0208         rho = 0,
0209         tau_pt,
0210         tau_eta,
0211         tau_inside_ecal_crack,
0212         pfCand_muon_valid,
0213         pfCand_muon_rel_pt,
0214         pfCand_muon_deta,
0215         pfCand_muon_dphi,
0216         pfCand_muon_pvAssociationQuality,
0217         pfCand_muon_fromPV,
0218         pfCand_muon_puppiWeight,
0219         pfCand_muon_charge,
0220         pfCand_muon_lostInnerHits,
0221         pfCand_muon_numberOfPixelHits,
0222         pfCand_muon_vertex_dx,
0223         pfCand_muon_vertex_dy,
0224         pfCand_muon_vertex_dz,
0225         pfCand_muon_vertex_dx_tauFL,
0226         pfCand_muon_vertex_dy_tauFL,
0227         pfCand_muon_vertex_dz_tauFL,
0228         pfCand_muon_hasTrackDetails,
0229         pfCand_muon_dxy,
0230         pfCand_muon_dxy_sig,
0231         pfCand_muon_dz,
0232         pfCand_muon_dz_sig,
0233         pfCand_muon_track_chi2_ndof,
0234         pfCand_muon_track_ndof,
0235         muon_valid,
0236         muon_rel_pt,
0237         muon_deta,
0238         muon_dphi,
0239         muon_dxy,
0240         muon_dxy_sig,
0241         muon_normalizedChi2_valid,
0242         muon_normalizedChi2,
0243         muon_numberOfValidHits,
0244         muon_segmentCompatibility,
0245         muon_caloCompatibility,
0246         muon_pfEcalEnergy_valid,
0247         muon_rel_pfEcalEnergy,
0248         muon_n_matches_DT_1,
0249         muon_n_matches_DT_2,
0250         muon_n_matches_DT_3,
0251         muon_n_matches_DT_4,
0252         muon_n_matches_CSC_1,
0253         muon_n_matches_CSC_2,
0254         muon_n_matches_CSC_3,
0255         muon_n_matches_CSC_4,
0256         muon_n_matches_RPC_1,
0257         muon_n_matches_RPC_2,
0258         muon_n_matches_RPC_3,
0259         muon_n_matches_RPC_4,
0260         muon_n_hits_DT_1,
0261         muon_n_hits_DT_2,
0262         muon_n_hits_DT_3,
0263         muon_n_hits_DT_4,
0264         muon_n_hits_CSC_1,
0265         muon_n_hits_CSC_2,
0266         muon_n_hits_CSC_3,
0267         muon_n_hits_CSC_4,
0268         muon_n_hits_RPC_1,
0269         muon_n_hits_RPC_2,
0270         muon_n_hits_RPC_3,
0271         muon_n_hits_RPC_4,
0272         NumberOfInputs
0273       };
0274     }
0275 
0276     namespace HadronBlockInputs {
0277       enum vars {
0278         rho = 0,
0279         tau_pt,
0280         tau_eta,
0281         tau_inside_ecal_crack,
0282         pfCand_chHad_valid,
0283         pfCand_chHad_rel_pt,
0284         pfCand_chHad_deta,
0285         pfCand_chHad_dphi,
0286         pfCand_chHad_leadChargedHadrCand,
0287         pfCand_chHad_pvAssociationQuality,
0288         pfCand_chHad_fromPV,
0289         pfCand_chHad_puppiWeight,
0290         pfCand_chHad_puppiWeightNoLep,
0291         pfCand_chHad_charge,
0292         pfCand_chHad_lostInnerHits,
0293         pfCand_chHad_numberOfPixelHits,
0294         pfCand_chHad_vertex_dx,
0295         pfCand_chHad_vertex_dy,
0296         pfCand_chHad_vertex_dz,
0297         pfCand_chHad_vertex_dx_tauFL,
0298         pfCand_chHad_vertex_dy_tauFL,
0299         pfCand_chHad_vertex_dz_tauFL,
0300         pfCand_chHad_hasTrackDetails,
0301         pfCand_chHad_dxy,
0302         pfCand_chHad_dxy_sig,
0303         pfCand_chHad_dz,
0304         pfCand_chHad_dz_sig,
0305         pfCand_chHad_track_chi2_ndof,
0306         pfCand_chHad_track_ndof,
0307         pfCand_chHad_hcalFraction,
0308         pfCand_chHad_rawCaloFraction,
0309         pfCand_nHad_valid,
0310         pfCand_nHad_rel_pt,
0311         pfCand_nHad_deta,
0312         pfCand_nHad_dphi,
0313         pfCand_nHad_puppiWeight,
0314         pfCand_nHad_puppiWeightNoLep,
0315         pfCand_nHad_hcalFraction,
0316         NumberOfInputs
0317       };
0318     }
0319   }  // namespace dnn_inputs_v2
0320 
0321   inline float getTauID(const pat::Tau& tau,
0322                         const std::string& tauID,
0323                         float default_value = -999.,
0324                         bool assert_input = true) {
0325     static tbb::concurrent_unordered_set<std::string> isFirstWarning;
0326     if (tau.isTauIDAvailable(tauID)) {
0327       return tau.tauID(tauID);
0328     } else {
0329       if (assert_input) {
0330         throw cms::Exception("DeepTauId")
0331             << "Exception in <getTauID>: No tauID '" << tauID << "' available in pat::Tau given as function argument.";
0332       }
0333       if (isFirstWarning.insert(tauID).second) {
0334         edm::LogWarning("DeepTauID") << "Warning in <getTauID>: No tauID '" << tauID
0335                                      << "' available in pat::Tau given as function argument."
0336                                      << " Using default_value = " << default_value << " instead." << std::endl;
0337       }
0338       return default_value;
0339     }
0340   }
0341 
0342   struct TauFunc {
0343     const reco::TauDiscriminatorContainer* basicTauDiscriminatorCollection;
0344     const reco::TauDiscriminatorContainer* basicTauDiscriminatordR03Collection;
0345     const edm::AssociationVector<reco::PFTauRefProd, std::vector<reco::PFTauTransverseImpactParameterRef>>*
0346         pfTauTransverseImpactParameters;
0347 
0348     using BasicDiscr = deep_tau::BasicDiscriminator;
0349     std::map<BasicDiscr, size_t> indexMap;
0350     std::map<BasicDiscr, size_t> indexMapdR03;
0351 
0352     const float getChargedIsoPtSum(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0353       return (*basicTauDiscriminatorCollection)[tau_ref].rawValues.at(indexMap.at(BasicDiscr::ChargedIsoPtSum));
0354     }
0355     const float getChargedIsoPtSum(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0356       return getTauID(tau, "chargedIsoPtSum");
0357     }
0358     const float getChargedIsoPtSumdR03(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0359       return (*basicTauDiscriminatordR03Collection)[tau_ref].rawValues.at(indexMapdR03.at(BasicDiscr::ChargedIsoPtSum));
0360     }
0361     const float getChargedIsoPtSumdR03(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0362       return getTauID(tau, "chargedIsoPtSumdR03");
0363     }
0364     const float getFootprintCorrectiondR03(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0365       return (*basicTauDiscriminatordR03Collection)[tau_ref].rawValues.at(
0366           indexMapdR03.at(BasicDiscr::FootprintCorrection));
0367     }
0368     const float getFootprintCorrectiondR03(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0369       return getTauID(tau, "footprintCorrectiondR03");
0370     }
0371     const float getFootprintCorrection(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0372       return getTauID(tau, "footprintCorrection");
0373     }
0374     const float getNeutralIsoPtSum(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0375       return (*basicTauDiscriminatorCollection)[tau_ref].rawValues.at(indexMap.at(BasicDiscr::NeutralIsoPtSum));
0376     }
0377     const float getNeutralIsoPtSum(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0378       return getTauID(tau, "neutralIsoPtSum");
0379     }
0380     const float getNeutralIsoPtSumdR03(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0381       return (*basicTauDiscriminatordR03Collection)[tau_ref].rawValues.at(indexMapdR03.at(BasicDiscr::NeutralIsoPtSum));
0382     }
0383     const float getNeutralIsoPtSumdR03(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0384       return getTauID(tau, "neutralIsoPtSumdR03");
0385     }
0386     const float getNeutralIsoPtSumWeight(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0387       return (*basicTauDiscriminatorCollection)[tau_ref].rawValues.at(indexMap.at(BasicDiscr::NeutralIsoPtSumWeight));
0388     }
0389     const float getNeutralIsoPtSumWeight(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0390       return getTauID(tau, "neutralIsoPtSumWeight");
0391     }
0392     const float getNeutralIsoPtSumdR03Weight(const reco::PFTau& tau,
0393                                              const edm::RefToBase<reco::BaseTau> tau_ref) const {
0394       return (*basicTauDiscriminatordR03Collection)[tau_ref].rawValues.at(
0395           indexMapdR03.at(BasicDiscr::NeutralIsoPtSumWeight));
0396     }
0397     const float getNeutralIsoPtSumdR03Weight(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0398       return getTauID(tau, "neutralIsoPtSumWeightdR03");
0399     }
0400     const float getPhotonPtSumOutsideSignalCone(const reco::PFTau& tau,
0401                                                 const edm::RefToBase<reco::BaseTau> tau_ref) const {
0402       return (*basicTauDiscriminatorCollection)[tau_ref].rawValues.at(
0403           indexMap.at(BasicDiscr::PhotonPtSumOutsideSignalCone));
0404     }
0405     const float getPhotonPtSumOutsideSignalCone(const pat::Tau& tau,
0406                                                 const edm::RefToBase<reco::BaseTau> tau_ref) const {
0407       return getTauID(tau, "photonPtSumOutsideSignalCone");
0408     }
0409     const float getPhotonPtSumOutsideSignalConedR03(const reco::PFTau& tau,
0410                                                     const edm::RefToBase<reco::BaseTau> tau_ref) const {
0411       return (*basicTauDiscriminatordR03Collection)[tau_ref].rawValues.at(
0412           indexMapdR03.at(BasicDiscr::PhotonPtSumOutsideSignalCone));
0413     }
0414     const float getPhotonPtSumOutsideSignalConedR03(const pat::Tau& tau,
0415                                                     const edm::RefToBase<reco::BaseTau> tau_ref) const {
0416       return getTauID(tau, "photonPtSumOutsideSignalConedR03");
0417     }
0418     const float getPuCorrPtSum(const reco::PFTau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0419       return (*basicTauDiscriminatorCollection)[tau_ref].rawValues.at(indexMap.at(BasicDiscr::PUcorrPtSum));
0420     }
0421     const float getPuCorrPtSum(const pat::Tau& tau, const edm::RefToBase<reco::BaseTau> tau_ref) const {
0422       return getTauID(tau, "puCorrPtSum");
0423     }
0424 
0425     auto getdxyPCA(const reco::PFTau& tau, const size_t tau_index) const {
0426       return pfTauTransverseImpactParameters->value(tau_index)->dxy_PCA();
0427     }
0428     auto getdxyPCA(const pat::Tau& tau, const size_t tau_index) const { return tau.dxy_PCA(); }
0429     auto getdxy(const reco::PFTau& tau, const size_t tau_index) const {
0430       return pfTauTransverseImpactParameters->value(tau_index)->dxy();
0431     }
0432     auto getdxy(const pat::Tau& tau, const size_t tau_index) const { return tau.dxy(); }
0433     auto getdxyError(const reco::PFTau& tau, const size_t tau_index) const {
0434       return pfTauTransverseImpactParameters->value(tau_index)->dxy_error();
0435     }
0436     auto getdxyError(const pat::Tau& tau, const size_t tau_index) const { return tau.dxy_error(); }
0437     auto getdxySig(const reco::PFTau& tau, const size_t tau_index) const {
0438       return pfTauTransverseImpactParameters->value(tau_index)->dxy_Sig();
0439     }
0440     auto getdxySig(const pat::Tau& tau, const size_t tau_index) const { return tau.dxy_Sig(); }
0441     auto getip3d(const reco::PFTau& tau, const size_t tau_index) const {
0442       return pfTauTransverseImpactParameters->value(tau_index)->ip3d();
0443     }
0444     auto getip3d(const pat::Tau& tau, const size_t tau_index) const { return tau.ip3d(); }
0445     auto getip3dError(const reco::PFTau& tau, const size_t tau_index) const {
0446       return pfTauTransverseImpactParameters->value(tau_index)->ip3d_error();
0447     }
0448     auto getip3dError(const pat::Tau& tau, const size_t tau_index) const { return tau.ip3d_error(); }
0449     auto getip3dSig(const reco::PFTau& tau, const size_t tau_index) const {
0450       return pfTauTransverseImpactParameters->value(tau_index)->ip3d_Sig();
0451     }
0452     auto getip3dSig(const pat::Tau& tau, const size_t tau_index) const { return tau.ip3d_Sig(); }
0453     auto getHasSecondaryVertex(const reco::PFTau& tau, const size_t tau_index) const {
0454       return pfTauTransverseImpactParameters->value(tau_index)->hasSecondaryVertex();
0455     }
0456     auto getHasSecondaryVertex(const pat::Tau& tau, const size_t tau_index) const { return tau.hasSecondaryVertex(); }
0457     auto getFlightLength(const reco::PFTau& tau, const size_t tau_index) const {
0458       return pfTauTransverseImpactParameters->value(tau_index)->flightLength();
0459     }
0460     auto getFlightLength(const pat::Tau& tau, const size_t tau_index) const { return tau.flightLength(); }
0461     auto getFlightLengthSig(const reco::PFTau& tau, const size_t tau_index) const {
0462       return pfTauTransverseImpactParameters->value(tau_index)->flightLengthSig();
0463     }
0464     auto getFlightLengthSig(const pat::Tau& tau, const size_t tau_index) const { return tau.flightLengthSig(); }
0465 
0466     auto getLeadingTrackNormChi2(const reco::PFTau& tau) { return reco::tau::lead_track_chi2(tau); }
0467     auto getLeadingTrackNormChi2(const pat::Tau& tau) { return tau.leadingTrackNormChi2(); }
0468     auto getEmFraction(const pat::Tau& tau) { return tau.emFraction_MVA(); }
0469     auto getEmFraction(const reco::PFTau& tau) { return tau.emFraction(); }
0470     auto getEtaAtEcalEntrance(const pat::Tau& tau) { return tau.etaAtEcalEntranceLeadChargedCand(); }
0471     auto getEtaAtEcalEntrance(const reco::PFTau& tau) {
0472       return tau.leadPFChargedHadrCand()->positionAtECALEntrance().eta();
0473     }
0474     auto getEcalEnergyLeadingChargedHadr(const reco::PFTau& tau) { return tau.leadPFChargedHadrCand()->ecalEnergy(); }
0475     auto getEcalEnergyLeadingChargedHadr(const pat::Tau& tau) { return tau.ecalEnergyLeadChargedHadrCand(); }
0476     auto getHcalEnergyLeadingChargedHadr(const reco::PFTau& tau) { return tau.leadPFChargedHadrCand()->hcalEnergy(); }
0477     auto getHcalEnergyLeadingChargedHadr(const pat::Tau& tau) { return tau.hcalEnergyLeadChargedHadrCand(); }
0478 
0479     template <typename PreDiscrType>
0480     bool passPrediscriminants(const PreDiscrType prediscriminants,
0481                               const size_t andPrediscriminants,
0482                               const edm::RefToBase<reco::BaseTau> tau_ref) {
0483       bool passesPrediscriminants = (andPrediscriminants ? 1 : 0);
0484       // check tau passes prediscriminants
0485       size_t nPrediscriminants = prediscriminants.size();
0486       for (size_t iDisc = 0; iDisc < nPrediscriminants; ++iDisc) {
0487         // current discriminant result for this tau
0488         double discResult = (*prediscriminants[iDisc].handle)[tau_ref];
0489         uint8_t thisPasses = (discResult > prediscriminants[iDisc].cut) ? 1 : 0;
0490 
0491         // if we are using the AND option, as soon as one fails,
0492         // the result is FAIL and we can quit looping.
0493         // if we are using the OR option as soon as one passes,
0494         // the result is pass and we can quit looping
0495 
0496         // truth table
0497         //        |   result (thisPasses)
0498         //        |     F     |     T
0499         //-----------------------------------
0500         // AND(T) | res=fails |  continue
0501         //        |  break    |
0502         //-----------------------------------
0503         // OR (F) |  continue | res=passes
0504         //        |           |  break
0505 
0506         if (thisPasses ^ andPrediscriminants)  //XOR
0507         {
0508           passesPrediscriminants = (andPrediscriminants ? 0 : 1);  //NOR
0509           break;
0510         }
0511       }
0512       return passesPrediscriminants;
0513     }
0514   };
0515 
0516   namespace candFunc {
0517     inline auto getTauDz(const reco::PFCandidate& cand) { return cand.bestTrack()->dz(); }
0518     inline auto getTauDz(const pat::PackedCandidate& cand) { return cand.dz(); }
0519     inline auto getTauDZSigValid(const reco::PFCandidate& cand) {
0520       return cand.bestTrack() != nullptr && std::isnormal(cand.bestTrack()->dz()) && std::isnormal(cand.dzError()) &&
0521              cand.dzError() > 0;
0522     }
0523     inline auto getTauDZSigValid(const pat::PackedCandidate& cand) {
0524       return cand.hasTrackDetails() && std::isnormal(cand.dz()) && std::isnormal(cand.dzError()) && cand.dzError() > 0;
0525     }
0526     inline auto getTauDxy(const reco::PFCandidate& cand) { return cand.bestTrack()->dxy(); }
0527     inline auto getTauDxy(const pat::PackedCandidate& cand) { return cand.dxy(); }
0528     inline auto getPvAssocationQuality(const reco::PFCandidate& cand) { return 0.7013f; }
0529     inline auto getPvAssocationQuality(const pat::PackedCandidate& cand) { return cand.pvAssociationQuality(); }
0530     inline auto getPuppiWeight(const reco::PFCandidate& cand, const float aod_value) { return aod_value; }
0531     inline auto getPuppiWeight(const pat::PackedCandidate& cand, const float aod_value) { return cand.puppiWeight(); }
0532     inline auto getPuppiWeightNoLep(const reco::PFCandidate& cand, const float aod_value) { return aod_value; }
0533     inline auto getPuppiWeightNoLep(const pat::PackedCandidate& cand, const float aod_value) {
0534       return cand.puppiWeightNoLep();
0535     }
0536     inline auto getLostInnerHits(const reco::PFCandidate& cand, float default_value) {
0537       return cand.bestTrack() != nullptr
0538                  ? cand.bestTrack()->hitPattern().numberOfLostHits(reco::HitPattern::MISSING_INNER_HITS)
0539                  : default_value;
0540     }
0541     inline auto getLostInnerHits(const pat::PackedCandidate& cand, float default_value) { return cand.lostInnerHits(); }
0542     inline auto getNumberOfPixelHits(const reco::PFCandidate& cand, float default_value) {
0543       return cand.bestTrack() != nullptr
0544                  ? cand.bestTrack()->hitPattern().numberOfLostHits(reco::HitPattern::MISSING_INNER_HITS)
0545                  : default_value;
0546     }
0547     inline auto getNumberOfPixelHits(const pat::PackedCandidate& cand, float default_value) {
0548       return cand.numberOfPixelHits();
0549     }
0550     inline auto getHasTrackDetails(const reco::PFCandidate& cand) { return cand.bestTrack() != nullptr; }
0551     inline auto getHasTrackDetails(const pat::PackedCandidate& cand) { return cand.hasTrackDetails(); }
0552     inline auto getPseudoTrack(const reco::PFCandidate& cand) { return *cand.bestTrack(); }
0553     inline auto getPseudoTrack(const pat::PackedCandidate& cand) { return cand.pseudoTrack(); }
0554     inline auto getFromPV(const reco::PFCandidate& cand) { return 0.9994f; }
0555     inline auto getFromPV(const pat::PackedCandidate& cand) { return cand.fromPV(); }
0556     inline auto getHCalFraction(const reco::PFCandidate& cand, bool disable_hcalFraction_workaround) {
0557       return cand.rawHcalEnergy() / (cand.rawHcalEnergy() + cand.rawEcalEnergy());
0558     }
0559     inline auto getHCalFraction(const pat::PackedCandidate& cand, bool disable_hcalFraction_workaround) {
0560       float hcal_fraction = 0.;
0561       if (disable_hcalFraction_workaround) {
0562         // CV: use consistent definition for pfCand_chHad_hcalFraction
0563         //     in DeepTauId.cc code and in TauMLTools/Production/plugins/TauTupleProducer.cc
0564         hcal_fraction = cand.hcalFraction();
0565       } else {
0566         // CV: backwards compatibility with DeepTau training v2p1 used during Run 2
0567         if (cand.pdgId() == 1 || cand.pdgId() == 130) {
0568           hcal_fraction = cand.hcalFraction();
0569         } else if (cand.isIsolatedChargedHadron()) {
0570           hcal_fraction = cand.rawHcalFraction();
0571         }
0572       }
0573       return hcal_fraction;
0574     }
0575     inline auto getRawCaloFraction(const reco::PFCandidate& cand) {
0576       return (cand.rawEcalEnergy() + cand.rawHcalEnergy()) / cand.energy();
0577     }
0578     inline auto getRawCaloFraction(const pat::PackedCandidate& cand) { return cand.rawCaloFraction(); }
0579   };  // namespace candFunc
0580 
0581   template <typename LVector1, typename LVector2>
0582   float dEta(const LVector1& p4, const LVector2& tau_p4) {
0583     return static_cast<float>(p4.eta() - tau_p4.eta());
0584   }
0585 
0586   template <typename LVector1, typename LVector2>
0587   float dPhi(const LVector1& p4_1, const LVector2& p4_2) {
0588     return static_cast<float>(reco::deltaPhi(p4_2.phi(), p4_1.phi()));
0589   }
0590 
0591   struct MuonHitMatchV2 {
0592     static constexpr size_t n_muon_stations = 4;
0593     static constexpr int first_station_id = 1;
0594     static constexpr int last_station_id = first_station_id + n_muon_stations - 1;
0595     using CountArray = std::array<unsigned, n_muon_stations>;
0596     using CountMap = std::map<int, CountArray>;
0597 
0598     const std::vector<int>& consideredSubdets() {
0599       static const std::vector<int> subdets = {MuonSubdetId::DT, MuonSubdetId::CSC, MuonSubdetId::RPC};
0600       return subdets;
0601     }
0602 
0603     const std::string& subdetName(int subdet) {
0604       static const std::map<int, std::string> subdet_names = {
0605           {MuonSubdetId::DT, "DT"}, {MuonSubdetId::CSC, "CSC"}, {MuonSubdetId::RPC, "RPC"}};
0606       if (!subdet_names.count(subdet))
0607         throw cms::Exception("MuonHitMatch") << "Subdet name for subdet id " << subdet << " not found.";
0608       return subdet_names.at(subdet);
0609     }
0610 
0611     size_t getStationIndex(int station, bool throw_exception) const {
0612       if (station < first_station_id || station > last_station_id) {
0613         if (throw_exception)
0614           throw cms::Exception("MuonHitMatch") << "Station id is out of range";
0615         return std::numeric_limits<size_t>::max();
0616       }
0617       return static_cast<size_t>(station - 1);
0618     }
0619 
0620     MuonHitMatchV2(const pat::Muon& muon) {
0621       for (int subdet : consideredSubdets()) {
0622         n_matches[subdet].fill(0);
0623         n_hits[subdet].fill(0);
0624       }
0625 
0626       countMatches(muon, n_matches);
0627       countHits(muon, n_hits);
0628     }
0629 
0630     void countMatches(const pat::Muon& muon, CountMap& n_matches) {
0631       for (const auto& segment : muon.matches()) {
0632         if (segment.segmentMatches.empty() && segment.rpcMatches.empty())
0633           continue;
0634         if (n_matches.count(segment.detector())) {
0635           const size_t station_index = getStationIndex(segment.station(), true);
0636           ++n_matches.at(segment.detector()).at(station_index);
0637         }
0638       }
0639     }
0640 
0641     void countHits(const pat::Muon& muon, CountMap& n_hits) {
0642       if (muon.outerTrack().isNonnull()) {
0643         const auto& hit_pattern = muon.outerTrack()->hitPattern();
0644         for (int hit_index = 0; hit_index < hit_pattern.numberOfAllHits(reco::HitPattern::TRACK_HITS); ++hit_index) {
0645           auto hit_id = hit_pattern.getHitPattern(reco::HitPattern::TRACK_HITS, hit_index);
0646           if (hit_id == 0)
0647             break;
0648           if (hit_pattern.muonHitFilter(hit_id) && (hit_pattern.getHitType(hit_id) == TrackingRecHit::valid ||
0649                                                     hit_pattern.getHitType(hit_id) == TrackingRecHit::bad)) {
0650             const size_t station_index = getStationIndex(hit_pattern.getMuonStation(hit_id), false);
0651             if (station_index < n_muon_stations) {
0652               CountArray* muon_n_hits = nullptr;
0653               if (hit_pattern.muonDTHitFilter(hit_id))
0654                 muon_n_hits = &n_hits.at(MuonSubdetId::DT);
0655               else if (hit_pattern.muonCSCHitFilter(hit_id))
0656                 muon_n_hits = &n_hits.at(MuonSubdetId::CSC);
0657               else if (hit_pattern.muonRPCHitFilter(hit_id))
0658                 muon_n_hits = &n_hits.at(MuonSubdetId::RPC);
0659 
0660               if (muon_n_hits)
0661                 ++muon_n_hits->at(station_index);
0662             }
0663           }
0664         }
0665       }
0666     }
0667 
0668     unsigned nMatches(int subdet, int station) const {
0669       if (!n_matches.count(subdet))
0670         throw cms::Exception("MuonHitMatch") << "Subdet " << subdet << " not found.";
0671       const size_t station_index = getStationIndex(station, true);
0672       return n_matches.at(subdet).at(station_index);
0673     }
0674 
0675     unsigned nHits(int subdet, int station) const {
0676       if (!n_hits.count(subdet))
0677         throw cms::Exception("MuonHitMatch") << "Subdet " << subdet << " not found.";
0678       const size_t station_index = getStationIndex(station, true);
0679       return n_hits.at(subdet).at(station_index);
0680     }
0681 
0682     unsigned countMuonStationsWithMatches(int first_station, int last_station) const {
0683       static const std::map<int, std::vector<bool>> masks = {
0684           {MuonSubdetId::DT, {false, false, false, false}},
0685           {MuonSubdetId::CSC, {true, false, false, false}},
0686           {MuonSubdetId::RPC, {false, false, false, false}},
0687       };
0688       const size_t first_station_index = getStationIndex(first_station, true);
0689       const size_t last_station_index = getStationIndex(last_station, true);
0690       unsigned cnt = 0;
0691       for (size_t n = first_station_index; n <= last_station_index; ++n) {
0692         for (const auto& match : n_matches) {
0693           if (!masks.at(match.first).at(n) && match.second.at(n) > 0)
0694             ++cnt;
0695         }
0696       }
0697       return cnt;
0698     }
0699 
0700     unsigned countMuonStationsWithHits(int first_station, int last_station) const {
0701       static const std::map<int, std::vector<bool>> masks = {
0702           {MuonSubdetId::DT, {false, false, false, false}},
0703           {MuonSubdetId::CSC, {false, false, false, false}},
0704           {MuonSubdetId::RPC, {false, false, false, false}},
0705       };
0706 
0707       const size_t first_station_index = getStationIndex(first_station, true);
0708       const size_t last_station_index = getStationIndex(last_station, true);
0709       unsigned cnt = 0;
0710       for (size_t n = first_station_index; n <= last_station_index; ++n) {
0711         for (const auto& hit : n_hits) {
0712           if (!masks.at(hit.first).at(n) && hit.second.at(n) > 0)
0713             ++cnt;
0714         }
0715       }
0716       return cnt;
0717     }
0718 
0719   private:
0720     CountMap n_matches, n_hits;
0721   };
0722 
  // Category of an object stored in a CellGrid cell. Generic candidates are classified by |pdgId| in
  // GetCellObjectType(reco::Candidate); pat::Electron and pat::Muon objects map to Electron and Muon.
  // NOTE: the enumerator order is also the key order of Cell (std::map<CellObjectType, size_t>).
  enum class CellObjectType {
    PfCand_electron,
    PfCand_muon,
    PfCand_chargedHadron,
    PfCand_neutralHadron,
    PfCand_gamma,
    Electron,
    Muon,
    Other
  };
0733 
0734   template <typename Object>
0735   inline CellObjectType GetCellObjectType(const Object&);
0736   template <>
0737   inline CellObjectType GetCellObjectType(const pat::Electron&) {
0738     return CellObjectType::Electron;
0739   }
0740   template <>
0741   inline CellObjectType GetCellObjectType(const pat::Muon&) {
0742     return CellObjectType::Muon;
0743   }
0744 
0745   template <>
0746   inline CellObjectType GetCellObjectType(reco::Candidate const& cand) {
0747     static const std::map<int, CellObjectType> obj_types = {{11, CellObjectType::PfCand_electron},
0748                                                             {13, CellObjectType::PfCand_muon},
0749                                                             {22, CellObjectType::PfCand_gamma},
0750                                                             {130, CellObjectType::PfCand_neutralHadron},
0751                                                             {211, CellObjectType::PfCand_chargedHadron}};
0752 
0753     auto iter = obj_types.find(std::abs(cand.pdgId()));
0754     if (iter == obj_types.end())
0755       return CellObjectType::Other;
0756     return iter->second;
0757   }
0758 
0759   using Cell = std::map<CellObjectType, size_t>;
0760   struct CellIndex {
0761     int eta, phi;
0762 
0763     bool operator<(const CellIndex& other) const {
0764       if (eta != other.eta)
0765         return eta < other.eta;
0766       return phi < other.phi;
0767     }
0768   };
0769 
0770   class CellGrid {
0771   public:
0772     using Map = std::map<CellIndex, Cell>;
0773     using const_iterator = Map::const_iterator;
0774 
0775     CellGrid(unsigned n_cells_eta,
0776              unsigned n_cells_phi,
0777              double cell_size_eta,
0778              double cell_size_phi,
0779              bool disable_CellIndex_workaround)
0780         : nCellsEta(n_cells_eta),
0781           nCellsPhi(n_cells_phi),
0782           nTotal(nCellsEta * nCellsPhi),
0783           cellSizeEta(cell_size_eta),
0784           cellSizePhi(cell_size_phi),
0785           disable_CellIndex_workaround_(disable_CellIndex_workaround) {
0786       if (nCellsEta % 2 != 1 || nCellsEta < 1)
0787         throw cms::Exception("DeepTauId") << "Invalid number of eta cells.";
0788       if (nCellsPhi % 2 != 1 || nCellsPhi < 1)
0789         throw cms::Exception("DeepTauId") << "Invalid number of phi cells.";
0790       if (cellSizeEta <= 0 || cellSizePhi <= 0)
0791         throw cms::Exception("DeepTauId") << "Invalid cell size.";
0792     }
0793 
0794     int maxEtaIndex() const { return static_cast<int>((nCellsEta - 1) / 2); }
0795     int maxPhiIndex() const { return static_cast<int>((nCellsPhi - 1) / 2); }
0796     double maxDeltaEta() const { return cellSizeEta * (0.5 + maxEtaIndex()); }
0797     double maxDeltaPhi() const { return cellSizePhi * (0.5 + maxPhiIndex()); }
0798     int getEtaTensorIndex(const CellIndex& cellIndex) const { return cellIndex.eta + maxEtaIndex(); }
0799     int getPhiTensorIndex(const CellIndex& cellIndex) const { return cellIndex.phi + maxPhiIndex(); }
0800 
0801     bool tryGetCellIndex(double deltaEta, double deltaPhi, CellIndex& cellIndex) const {
0802       const auto getCellIndex = [this](double x, double maxX, double size, int& index) {
0803         const double absX = std::abs(x);
0804         if (absX > maxX)
0805           return false;
0806         double absIndex;
0807         if (disable_CellIndex_workaround_) {
0808           // CV: use consistent definition for CellIndex
0809           //     in DeepTauId.cc code and new DeepTau trainings
0810           absIndex = std::floor(absX / size + 0.5);
0811         } else {
0812           // CV: backwards compatibility with DeepTau training v2p1 used during Run 2
0813           absIndex = std::floor(std::abs(absX / size - 0.5));
0814         }
0815         index = static_cast<int>(std::copysign(absIndex, x));
0816         return true;
0817       };
0818 
0819       return getCellIndex(deltaEta, maxDeltaEta(), cellSizeEta, cellIndex.eta) &&
0820              getCellIndex(deltaPhi, maxDeltaPhi(), cellSizePhi, cellIndex.phi);
0821     }
0822 
0823     size_t num_valid_cells() const { return cells.size(); }
0824     Cell& operator[](const CellIndex& cellIndex) { return cells[cellIndex]; }
0825     const Cell& at(const CellIndex& cellIndex) const { return cells.at(cellIndex); }
0826     size_t count(const CellIndex& cellIndex) const { return cells.count(cellIndex); }
0827     const_iterator find(const CellIndex& cellIndex) const { return cells.find(cellIndex); }
0828     const_iterator begin() const { return cells.begin(); }
0829     const_iterator end() const { return cells.end(); }
0830 
0831   public:
0832     const unsigned nCellsEta, nCellsPhi, nTotal;
0833     const double cellSizeEta, cellSizePhi;
0834 
0835   private:
0836     std::map<CellIndex, Cell> cells;
0837     const bool disable_CellIndex_workaround_;
0838   };
0839 }  // anonymous namespace
0840 
0841 template <class Producer>
0842 class DeepTauIdBase : public Producer {
0843 public:
0844   using TauDiscriminator = reco::TauDiscriminatorContainer;
0845   using TauCollection = edm::View<reco::BaseTau>;
0846   using CandidateCollection = edm::View<reco::Candidate>;
0847   using TauRef = edm::Ref<TauCollection>;
0848   using TauRefProd = edm::RefProd<TauCollection>;
0849   using ElectronCollection = pat::ElectronCollection;
0850   using MuonCollection = pat::MuonCollection;
0851   using LorentzVectorXYZ = ROOT::Math::LorentzVector<ROOT::Math::PxPyPzE4D<double>>;
0852   using Cutter = tau::TauWPThreshold;
0853   using CutterPtr = std::unique_ptr<Cutter>;
0854   using WPList = std::vector<CutterPtr>;
0855 
0856   struct IDOutput {
0857     std::vector<size_t> num_, den_;
0858 
0859     IDOutput(const std::vector<size_t>& num, const std::vector<size_t>& den) : num_(num), den_(den) {}
0860 
0861     // for direct inference, read the output tensorflow::Tensor
0862     float read_value(const tensorflow::Tensor& pred, size_t tau_index, size_t elem) const {
0863       return pred.matrix<float>()(tau_index, elem);
0864     }
0865     // for SONIC, read the output vector
0866     float read_value(const std::vector<std::vector<float>>& pred, size_t tau_index, size_t elem) const {
0867       return pred.at(tau_index).at(elem);
0868     }
0869 
0870     template <typename PredType>
0871     std::unique_ptr<TauDiscriminator> get_value(const edm::Handle<TauCollection>& taus,
0872                                                 const PredType& pred,
0873                                                 const WPList* working_points,
0874                                                 bool is_online) const {
0875       std::vector<reco::SingleTauDiscriminatorContainer> outputbuffer(taus->size());
0876 
0877       for (size_t tau_index = 0; tau_index < taus->size(); ++tau_index) {
0878         float x = 0;
0879         for (size_t num_elem : num_)
0880           x += read_value(pred, tau_index, num_elem);
0881         if (x != 0 && !den_.empty()) {
0882           float den_val = 0;
0883           for (size_t den_elem : den_)
0884             den_val += read_value(pred, tau_index, den_elem);
0885           x = den_val != 0 ? x / den_val : std::numeric_limits<float>::max();
0886         }
0887         outputbuffer[tau_index].rawValues.push_back(x);
0888         if (working_points) {
0889           for (const auto& wp : *working_points) {
0890             const bool pass = x > (*wp)(taus->at(tau_index), is_online);
0891             outputbuffer[tau_index].workingPoints.push_back(pass);
0892           }
0893         }
0894       }
0895       std::unique_ptr<TauDiscriminator> output = std::make_unique<TauDiscriminator>();
0896       reco::TauDiscriminatorContainer::Filler filler(*output);
0897       filler.insert(taus, outputbuffer.begin(), outputbuffer.end());
0898       filler.fill();
0899       return output;
0900     }
0901   };
0902 
0903   using IDOutputCollection = std::map<std::string, IDOutput>;
0904 
0905   static constexpr float default_value = -999.;
0906 
0907   static const IDOutputCollection& GetIDOutputs() {
0908     static constexpr size_t e_index = 0, mu_index = 1, tau_index = 2, jet_index = 3;
0909     static const IDOutputCollection idoutputs_ = {
0910         {"VSe", IDOutput({tau_index}, {e_index, tau_index})},
0911         {"VSmu", IDOutput({tau_index}, {mu_index, tau_index})},
0912         {"VSjet", IDOutput({tau_index}, {jet_index, tau_index})},
0913     };
0914     return idoutputs_;
0915   }
0916 
0917   using BasicDiscriminator = deep_tau::BasicDiscriminator;
0918 
0919   const std::map<BasicDiscriminator, size_t> matchDiscriminatorIndices(
0920       edm::Event const& event,
0921       edm::EDGetTokenT<reco::TauDiscriminatorContainer> discriminatorContainerToken,
0922       std::vector<BasicDiscriminator> requiredDiscr) {
0923     std::map<std::string, size_t> discrIndexMapStr;
0924     auto const aHandle = event.getHandle(discriminatorContainerToken);
0925     auto const aProv = aHandle.provenance();
0926     if (aProv == nullptr)
0927       aHandle.whyFailed()->raise();
0928     const auto& psetsFromProvenance = edm::parameterSet(aProv->stable(), event.processHistory());
0929     auto const idlist = psetsFromProvenance.getParameter<std::vector<edm::ParameterSet>>("IDdefinitions");
0930     for (size_t j = 0; j < idlist.size(); ++j) {
0931       std::string idname = idlist[j].getParameter<std::string>("IDname");
0932       if (discrIndexMapStr.count(idname)) {
0933         throw cms::Exception("DeepTauId")
0934             << "basic discriminator " << idname << " appears more than once in the input.";
0935       }
0936       discrIndexMapStr[idname] = j;
0937     }
0938 
0939     //translate to a map of <BasicDiscriminator, index> and check if all discriminators are present
0940     std::map<BasicDiscriminator, size_t> discrIndexMap;
0941     for (size_t i = 0; i < requiredDiscr.size(); i++) {
0942       if (discrIndexMapStr.find(stringFromDiscriminator_.at(requiredDiscr[i])) == discrIndexMapStr.end())
0943         throw cms::Exception("DeepTauId") << "Basic Discriminator " << stringFromDiscriminator_.at(requiredDiscr[i])
0944                                           << " was not provided in the config file.";
0945       else
0946         discrIndexMap[requiredDiscr[i]] = discrIndexMapStr[stringFromDiscriminator_.at(requiredDiscr[i])];
0947     }
0948     return discrIndexMap;
0949   }
0950 
0951   static void fillDescriptionsHelper(edm::ParameterSetDescription& desc) {
0952     desc.add<edm::InputTag>("electrons", edm::InputTag("slimmedElectrons"));
0953     desc.add<edm::InputTag>("muons", edm::InputTag("slimmedMuons"));
0954     desc.add<edm::InputTag>("taus", edm::InputTag("slimmedTaus"));
0955     desc.add<edm::InputTag>("pfcands", edm::InputTag("packedPFCandidates"));
0956     desc.add<edm::InputTag>("vertices", edm::InputTag("offlineSlimmedPrimaryVertices"));
0957     desc.add<edm::InputTag>("rho", edm::InputTag("fixedGridRhoAll"));
0958     desc.add<bool>("mem_mapped", false);
0959     desc.add<unsigned>("year", 2017);
0960     desc.add<unsigned>("version", 2);
0961     desc.add<unsigned>("sub_version", 1);
0962     desc.add<int>("debug_level", 0);
0963     desc.add<bool>("disable_dxy_pca", false);
0964     desc.add<bool>("disable_hcalFraction_workaround", false);
0965     desc.add<bool>("disable_CellIndex_workaround", false);
0966     desc.add<bool>("save_inputs", false);
0967     desc.add<bool>("is_online", false);
0968 
0969     desc.add<std::vector<std::string>>("VSeWP", {"-1."});
0970     desc.add<std::vector<std::string>>("VSmuWP", {"-1."});
0971     desc.add<std::vector<std::string>>("VSjetWP", {"-1."});
0972 
0973     desc.addUntracked<edm::InputTag>("basicTauDiscriminators", edm::InputTag("basicTauDiscriminators"));
0974     desc.addUntracked<edm::InputTag>("basicTauDiscriminatorsdR03", edm::InputTag("basicTauDiscriminatorsdR03"));
0975     desc.add<edm::InputTag>("pfTauTransverseImpactParameters", edm::InputTag("hpsPFTauTransverseImpactParameters"));
0976 
0977     {
0978       edm::ParameterSetDescription pset_Prediscriminants;
0979       pset_Prediscriminants.add<std::string>("BooleanOperator", "and");
0980       {
0981         edm::ParameterSetDescription psd1;
0982         psd1.add<double>("cut");
0983         psd1.add<edm::InputTag>("Producer");
0984         pset_Prediscriminants.addOptional<edm::ParameterSetDescription>("decayMode", psd1);
0985       }
0986       desc.add<edm::ParameterSetDescription>("Prediscriminants", pset_Prediscriminants);
0987     }
0988   }
0989 
0990 public:
0991   DeepTauIdBase(const edm::ParameterSet& cfg)
0992       : Producer(cfg),
0993         tausToken_(this->template consumes<TauCollection>(cfg.getParameter<edm::InputTag>("taus"))),
0994         pfcandToken_(this->template consumes<CandidateCollection>(cfg.getParameter<edm::InputTag>("pfcands"))),
0995         vtxToken_(this->template consumes<reco::VertexCollection>(cfg.getParameter<edm::InputTag>("vertices"))),
0996         is_online_(cfg.getParameter<bool>("is_online")),
0997         idoutputs_(GetIDOutputs()),
0998         electrons_token_(
0999             this->template consumes<std::vector<pat::Electron>>(cfg.getParameter<edm::InputTag>("electrons"))),
1000         muons_token_(this->template consumes<std::vector<pat::Muon>>(cfg.getParameter<edm::InputTag>("muons"))),
1001         rho_token_(this->template consumes<double>(cfg.getParameter<edm::InputTag>("rho"))),
1002         basicTauDiscriminators_inputToken_(this->template consumes<reco::TauDiscriminatorContainer>(
1003             cfg.getUntrackedParameter<edm::InputTag>("basicTauDiscriminators"))),
1004         basicTauDiscriminatorsdR03_inputToken_(this->template consumes<reco::TauDiscriminatorContainer>(
1005             cfg.getUntrackedParameter<edm::InputTag>("basicTauDiscriminatorsdR03"))),
1006         pfTauTransverseImpactParameters_token_(
1007             this->template consumes<
1008                 edm::AssociationVector<reco::PFTauRefProd, std::vector<reco::PFTauTransverseImpactParameterRef>>>(
1009                 cfg.getParameter<edm::InputTag>("pfTauTransverseImpactParameters"))),
1010         year_(cfg.getParameter<unsigned>("year")),
1011         version_(cfg.getParameter<unsigned>("version")),
1012         sub_version_(cfg.getParameter<unsigned>("sub_version")),
1013         debug_level(cfg.getParameter<int>("debug_level")),
1014         disable_dxy_pca_(cfg.getParameter<bool>("disable_dxy_pca")),
1015         disable_hcalFraction_workaround_(cfg.getParameter<bool>("disable_hcalFraction_workaround")),
1016         disable_CellIndex_workaround_(cfg.getParameter<bool>("disable_CellIndex_workaround")),
1017         save_inputs_(cfg.getParameter<bool>("save_inputs")),
1018         json_file_(nullptr),
1019         file_counter_(0) {
1020     for (const auto& output_desc : idoutputs_) {
1021       this->template produces<TauDiscriminator>(output_desc.first);
1022       const auto& cut_list = cfg.getParameter<std::vector<std::string>>(output_desc.first + "WP");
1023       for (const std::string& cut_str : cut_list) {
1024         workingPoints_[output_desc.first].push_back(std::make_unique<Cutter>(cut_str));
1025       }
1026     }
1027 
1028     // prediscriminant operator
1029     // require the tau to pass the following prediscriminants
1030     const edm::ParameterSet& prediscriminantConfig = cfg.getParameter<edm::ParameterSet>("Prediscriminants");
1031 
1032     // determine boolean operator used on the prediscriminants
1033     std::string pdBoolOperator = prediscriminantConfig.getParameter<std::string>("BooleanOperator");
1034     // convert string to lowercase
1035     transform(pdBoolOperator.begin(), pdBoolOperator.end(), pdBoolOperator.begin(), ::tolower);
1036 
1037     if (pdBoolOperator == "and") {
1038       andPrediscriminants_ = 0x1;  //use chars instead of bools so we can do a bitwise trick later
1039     } else if (pdBoolOperator == "or") {
1040       andPrediscriminants_ = 0x0;
1041     } else {
1042       throw cms::Exception("TauDiscriminationProducerBase")
1043           << "PrediscriminantBooleanOperator defined incorrectly, options are: AND,OR";
1044     }
1045 
1046     // get the list of prediscriminants
1047     std::vector<std::string> prediscriminantsNames =
1048         prediscriminantConfig.getParameterNamesForType<edm::ParameterSet>();
1049 
1050     for (auto const& iDisc : prediscriminantsNames) {
1051       const edm::ParameterSet& iPredisc = prediscriminantConfig.getParameter<edm::ParameterSet>(iDisc);
1052       const edm::InputTag& label = iPredisc.getParameter<edm::InputTag>("Producer");
1053       double cut = iPredisc.getParameter<double>("cut");
1054 
1055       if (is_online_) {
1056         TauDiscInfo<reco::PFTauDiscriminator> thisDiscriminator;
1057         thisDiscriminator.label = label;
1058         thisDiscriminator.cut = cut;
1059         thisDiscriminator.disc_token = this->template consumes<reco::PFTauDiscriminator>(label);
1060         recoPrediscriminants_.push_back(thisDiscriminator);
1061       } else {
1062         TauDiscInfo<pat::PATTauDiscriminator> thisDiscriminator;
1063         thisDiscriminator.label = label;
1064         thisDiscriminator.cut = cut;
1065         thisDiscriminator.disc_token = this->template consumes<pat::PATTauDiscriminator>(label);
1066         patPrediscriminants_.push_back(thisDiscriminator);
1067       }
1068     }
1069     if (version_ == 2) {
1070       using namespace dnn_inputs_v2;
1071       namespace sc = deep_tau::Scaling;
1072       tauInputs_indices_.resize(TauBlockInputs::NumberOfInputs);
1073       std::iota(std::begin(tauInputs_indices_), std::end(tauInputs_indices_), 0);
1074 
1075       if (sub_version_ == 1) {
1076         scalingParamsMap_ = &sc::scalingParamsMap_v2p1;
1077       } else if ((sub_version_ == 5) || ((sub_version_ == 0) && (year_ == 20161718))) {
1078         std::sort(TauBlockInputs::varsToDrop.begin(), TauBlockInputs::varsToDrop.end());
1079         for (auto v : TauBlockInputs::varsToDrop) {
1080           tauInputs_indices_.at(v) = TauBlockInputs::NumberOfInputs - TauBlockInputs::varsToDrop.size();
1081           for (std::size_t i = v + 1; i < TauBlockInputs::NumberOfInputs; ++i)
1082             tauInputs_indices_.at(i) -= 1;  // shift all the following indices by 1
1083         }
1084         if (year_ == 2026) {
1085           scalingParamsMap_ = &sc::scalingParamsMap_PhaseIIv2p5;
1086         } else if ((sub_version_ == 0) && (year_ == 20161718)) {
1087           scalingParamsMap_ = &sc::scalingParamsMap_BoostedRun2_v2p0;
1088         } else {
1089           scalingParamsMap_ = &sc::scalingParamsMap_v2p5;
1090         }
1091       } else
1092         throw cms::Exception("DeepTauId") << "subversion " << sub_version_ << " is not supported.";
1093 
1094       std::map<std::vector<bool>, std::vector<sc::FeatureT>> GridFeatureTypes_map = {
1095           {{false}, {sc::FeatureT::TauFlat, sc::FeatureT::GridGlobal}},  // feature types without inner/outer grid split
1096           {{false, true},
1097            {sc::FeatureT::PfCand_electron,
1098             sc::FeatureT::PfCand_muon,  // feature types with inner/outer grid split
1099             sc::FeatureT::PfCand_chHad,
1100             sc::FeatureT::PfCand_nHad,
1101             sc::FeatureT::PfCand_gamma,
1102             sc::FeatureT::Electron,
1103             sc::FeatureT::Muon}}};
1104 
1105       // check that sizes of mean/std/lim_min/lim_max vectors are equal between each other
1106       for (const auto& p : GridFeatureTypes_map) {
1107         for (auto is_inner : p.first) {
1108           for (auto featureType : p.second) {
1109             const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(featureType, is_inner));
1110             if (!(sp.mean_.size() == sp.std_.size() && sp.mean_.size() == sp.lim_min_.size() &&
1111                   sp.mean_.size() == sp.lim_max_.size()))
1112               throw cms::Exception("DeepTauId") << "sizes of scaling parameter vectors do not match between each other";
1113           }
1114         }
1115       }
1116     } else {
1117       throw cms::Exception("DeepTauId") << "version " << version_ << " is not supported.";
1118     }
1119   }
1120 
1121   template <typename ConsumeType>
1122   struct TauDiscInfo {
1123     edm::InputTag label;
1124     edm::Handle<ConsumeType> handle;
1125     edm::EDGetTokenT<ConsumeType> disc_token;
1126     double cut;
1127     void fill(const edm::Event& evt) { evt.getByToken(disc_token, handle); }
1128   };
1129 
  // Boolean operation applied to the prediscriminants (and = 0x01, or = 0x00); set in the constructor from
  // the "BooleanOperator" parameter.
  uint8_t andPrediscriminants_;
  // Prediscriminants for offline running; the constructor fills this list only when is_online_ is false.
  std::vector<TauDiscInfo<pat::PATTauDiscriminator>> patPrediscriminants_;
  // Prediscriminants for online running; the constructor fills this list only when is_online_ is true.
  std::vector<TauDiscInfo<reco::PFTauDiscriminator>> recoPrediscriminants_;
1134 
1135 protected:
1136   static constexpr float pi = M_PI;
1137 
1138   template <typename PredType>
1139   void createOutputs(edm::Event& event, const PredType& pred, edm::Handle<TauCollection> taus) {
1140     for (const auto& output_desc : idoutputs_) {
1141       const WPList* working_points = nullptr;
1142       if (workingPoints_.find(output_desc.first) != workingPoints_.end()) {
1143         working_points = &workingPoints_.at(output_desc.first);
1144       }
1145       auto result = output_desc.second.get_value(taus, pred, working_points, is_online_);
1146       event.put(std::move(result), output_desc.first);
1147     }
1148   }
1149 
1150   template <typename T>
1151   static float getValue(T value) {
1152     return std::isnormal(value) ? static_cast<float>(value) : 0.f;
1153   }
1154 
1155   template <typename T>
1156   static float getValueLinear(T value, float min_value, float max_value, bool positive) {
1157     const float fixed_value = getValue(value);
1158     const float clamped_value = std::clamp(fixed_value, min_value, max_value);
1159     float transformed_value = (clamped_value - min_value) / (max_value - min_value);
1160     if (!positive)
1161       transformed_value = transformed_value * 2 - 1;
1162     return transformed_value;
1163   }
1164 
1165   template <typename T>
1166   static float getValueNorm(T value, float mean, float sigma, float n_sigmas_max = 5) {
1167     const float fixed_value = getValue(value);
1168     const float norm_value = (fixed_value - mean) / sigma;
1169     return std::clamp(norm_value, -n_sigmas_max, n_sigmas_max);
1170   }
1171 
1172   static bool isAbove(double value, double min) { return std::isnormal(value) && value > min; }
1173 
1174   static bool calculateElectronClusterVarsV2(const pat::Electron& ele,
1175                                              float& cc_ele_energy,
1176                                              float& cc_gamma_energy,
1177                                              int& cc_n_gamma) {
1178     cc_ele_energy = cc_gamma_energy = 0;
1179     cc_n_gamma = 0;
1180     const auto& superCluster = ele.superCluster();
1181     if (superCluster.isNonnull() && superCluster.isAvailable() && superCluster->clusters().isNonnull() &&
1182         superCluster->clusters().isAvailable()) {
1183       for (auto iter = superCluster->clustersBegin(); iter != superCluster->clustersEnd(); ++iter) {
1184         const float energy = static_cast<float>((*iter)->energy());
1185         if (iter == superCluster->clustersBegin())
1186           cc_ele_energy += energy;
1187         else {
1188           cc_gamma_energy += energy;
1189           ++cc_n_gamma;
1190         }
1191       }
1192       return true;
1193     } else
1194       return false;
1195   }
1196 
1197 protected:
1198   // load prediscriminators
1199   void loadPrediscriminants(edm::Event const& event, edm::Handle<TauCollection> const& taus) {
1200     edm::ProductID tauProductID = taus.id();
1201     size_t nPrediscriminants =
1202         patPrediscriminants_.empty() ? recoPrediscriminants_.size() : patPrediscriminants_.size();
1203     for (size_t iDisc = 0; iDisc < nPrediscriminants; ++iDisc) {
1204       edm::ProductID discKeyId;
1205       if (is_online_) {
1206         recoPrediscriminants_[iDisc].fill(event);
1207         discKeyId = recoPrediscriminants_[iDisc].handle->keyProduct().id();
1208       } else {
1209         patPrediscriminants_[iDisc].fill(event);
1210         discKeyId = patPrediscriminants_[iDisc].handle->keyProduct().id();
1211       }
1212 
1213       // Check to make sure the product is correct for the discriminator.
1214       // If not, throw a more informative exception.
1215       if (tauProductID != discKeyId) {
1216         throw cms::Exception("MisconfiguredPrediscriminant")
1217             << "The tau collection has product ID: " << tauProductID
1218             << " but the pre-discriminator is keyed with product ID: " << discKeyId << std::endl;
1219       }
1220     }
1221   }
1222 
1223   template <typename Collection, typename TauCastType>
1224   void fillGrids(const TauCastType& tau, const Collection& objects, CellGrid& inner_grid, CellGrid& outer_grid) {
1225     static constexpr double outer_dR2 = 0.25;  //0.5^2
1226     const double inner_radius = getInnerSignalConeRadius(tau.polarP4().pt());
1227     const double inner_dR2 = std::pow(inner_radius, 2);
1228 
1229     const auto addObject = [&](size_t n, double deta, double dphi, CellGrid& grid) {
1230       const auto& obj = objects.at(n);
1231       const CellObjectType obj_type = GetCellObjectType(obj);
1232       if (obj_type == CellObjectType::Other)
1233         return;
1234       CellIndex cell_index;
1235       if (grid.tryGetCellIndex(deta, dphi, cell_index)) {
1236         Cell& cell = grid[cell_index];
1237         auto iter = cell.find(obj_type);
1238         if (iter != cell.end()) {
1239           const auto& prev_obj = objects.at(iter->second);
1240           if (obj.polarP4().pt() > prev_obj.polarP4().pt())
1241             iter->second = n;
1242         } else {
1243           cell[obj_type] = n;
1244         }
1245       }
1246     };
1247 
1248     for (size_t n = 0; n < objects.size(); ++n) {
1249       const auto& obj = objects.at(n);
1250       const double deta = obj.polarP4().eta() - tau.polarP4().eta();
1251       const double dphi = reco::deltaPhi(obj.polarP4().phi(), tau.polarP4().phi());
1252       const double dR2 = std::pow(deta, 2) + std::pow(dphi, 2);
1253       if (dR2 < inner_dR2)
1254         addObject(n, deta, dphi, inner_grid);
1255       if (dR2 < outer_dR2)
1256         addObject(n, deta, dphi, outer_grid);
1257     }
1258   }
1259 
1260   template <typename CandidateCastType, typename TauCastType, typename TauBlockType>
1261   void createTauBlockInputs(const TauCastType& tau,
1262                             const size_t& tau_index,
1263                             const edm::RefToBase<reco::BaseTau> tau_ref,
1264                             const reco::Vertex& pv,
1265                             double rho,
1266                             TauFunc tau_funcs,
1267                             TauBlockType& tauBlockInputs) {
1268     namespace dnn = dnn_inputs_v2::TauBlockInputs;
1269     namespace sc = deep_tau::Scaling;
1270     namespace candFunc = candFunc;
1271     sc::FeatureT ft = sc::FeatureT::TauFlat;
1272     const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft, false));
1273 
1274     const auto& get = [&](int var_index) -> float& {
1275       if constexpr (std::is_same_v<TauBlockType, std::vector<float>::iterator>) {
1276         return *(tauBlockInputs + tauInputs_indices_.at(var_index));
1277       } else {
1278         return ((tensorflow::Tensor)tauBlockInputs).matrix<float>()(0, tauInputs_indices_.at(var_index));
1279       }
1280     };
1281 
1282     auto leadChargedHadrCand = dynamic_cast<const CandidateCastType*>(tau.leadChargedHadrCand().get());
1283 
1284     get(dnn::rho) = sp.scale(rho, tauInputs_indices_[dnn::rho]);
1285     get(dnn::tau_pt) = sp.scale(tau.polarP4().pt(), tauInputs_indices_[dnn::tau_pt]);
1286     get(dnn::tau_eta) = sp.scale(tau.polarP4().eta(), tauInputs_indices_[dnn::tau_eta]);
1287     if (sub_version_ == 1) {
1288       get(dnn::tau_phi) = getValueLinear(tau.polarP4().phi(), -pi, pi, false);
1289     }
1290     get(dnn::tau_mass) = sp.scale(tau.polarP4().mass(), tauInputs_indices_[dnn::tau_mass]);
1291     get(dnn::tau_E_over_pt) = sp.scale(tau.p4().energy() / tau.p4().pt(), tauInputs_indices_[dnn::tau_E_over_pt]);
1292     get(dnn::tau_charge) = sp.scale(tau.charge(), tauInputs_indices_[dnn::tau_charge]);
1293     get(dnn::tau_n_charged_prongs) = sp.scale(tau.decayMode() / 5 + 1, tauInputs_indices_[dnn::tau_n_charged_prongs]);
1294     get(dnn::tau_n_neutral_prongs) = sp.scale(tau.decayMode() % 5, tauInputs_indices_[dnn::tau_n_neutral_prongs]);
1295     get(dnn::chargedIsoPtSum) =
1296         sp.scale(tau_funcs.getChargedIsoPtSum(tau, tau_ref), tauInputs_indices_[dnn::chargedIsoPtSum]);
1297     get(dnn::chargedIsoPtSumdR03_over_dR05) =
1298         sp.scale(tau_funcs.getChargedIsoPtSumdR03(tau, tau_ref) / tau_funcs.getChargedIsoPtSum(tau, tau_ref),
1299                  tauInputs_indices_[dnn::chargedIsoPtSumdR03_over_dR05]);
1300     if (sub_version_ == 1)
1301       get(dnn::footprintCorrection) =
1302           sp.scale(tau_funcs.getFootprintCorrectiondR03(tau, tau_ref), tauInputs_indices_[dnn::footprintCorrection]);
1303     else if ((sub_version_ == 5) || ((sub_version_ == 0) && (year_ == 20161718))) {
1304       if (is_online_)
1305         get(dnn::footprintCorrection) =
1306             sp.scale(tau_funcs.getFootprintCorrectiondR03(tau, tau_ref), tauInputs_indices_[dnn::footprintCorrection]);
1307       else
1308         get(dnn::footprintCorrection) =
1309             sp.scale(tau_funcs.getFootprintCorrection(tau, tau_ref), tauInputs_indices_[dnn::footprintCorrection]);
1310     }
1311     get(dnn::neutralIsoPtSum) =
1312         sp.scale(tau_funcs.getNeutralIsoPtSum(tau, tau_ref), tauInputs_indices_[dnn::neutralIsoPtSum]);
1313     get(dnn::neutralIsoPtSumWeight_over_neutralIsoPtSum) =
1314         sp.scale(tau_funcs.getNeutralIsoPtSumWeight(tau, tau_ref) / tau_funcs.getNeutralIsoPtSum(tau, tau_ref),
1315                  tauInputs_indices_[dnn::neutralIsoPtSumWeight_over_neutralIsoPtSum]);
1316     get(dnn::neutralIsoPtSumWeightdR03_over_neutralIsoPtSum) =
1317         sp.scale(tau_funcs.getNeutralIsoPtSumdR03Weight(tau, tau_ref) / tau_funcs.getNeutralIsoPtSum(tau, tau_ref),
1318                  tauInputs_indices_[dnn::neutralIsoPtSumWeightdR03_over_neutralIsoPtSum]);
1319     get(dnn::neutralIsoPtSumdR03_over_dR05) =
1320         sp.scale(tau_funcs.getNeutralIsoPtSumdR03(tau, tau_ref) / tau_funcs.getNeutralIsoPtSum(tau, tau_ref),
1321                  tauInputs_indices_[dnn::neutralIsoPtSumdR03_over_dR05]);
1322     get(dnn::photonPtSumOutsideSignalCone) = sp.scale(tau_funcs.getPhotonPtSumOutsideSignalCone(tau, tau_ref),
1323                                                       tauInputs_indices_[dnn::photonPtSumOutsideSignalCone]);
1324     get(dnn::puCorrPtSum) = sp.scale(tau_funcs.getPuCorrPtSum(tau, tau_ref), tauInputs_indices_[dnn::puCorrPtSum]);
1325     // The global PCA coordinates were used as inputs during the NN training, but it was decided to disable
1326     // them for the inference, because modeling of dxy_PCA in MC poorly describes the data, and x and y coordinates
1327     // in data results outside of the expected 5 std. dev. input validity range. On the other hand,
1328     // these coordinates are strongly era-dependent. Kept as comment to document what NN expects.
1329     if (sub_version_ == 1) {
1330       if (!disable_dxy_pca_) {
1331         auto const pca = tau_funcs.getdxyPCA(tau, tau_index);
1332         get(dnn::tau_dxy_pca_x) = sp.scale(pca.x(), tauInputs_indices_[dnn::tau_dxy_pca_x]);
1333         get(dnn::tau_dxy_pca_y) = sp.scale(pca.y(), tauInputs_indices_[dnn::tau_dxy_pca_y]);
1334         get(dnn::tau_dxy_pca_z) = sp.scale(pca.z(), tauInputs_indices_[dnn::tau_dxy_pca_z]);
1335       } else {
1336         get(dnn::tau_dxy_pca_x) = 0;
1337         get(dnn::tau_dxy_pca_y) = 0;
1338         get(dnn::tau_dxy_pca_z) = 0;
1339       }
1340     }
1341 
1342     const bool tau_dxy_valid =
1343         isAbove(tau_funcs.getdxy(tau, tau_index), -10) && isAbove(tau_funcs.getdxyError(tau, tau_index), 0);
1344     if (tau_dxy_valid) {
1345       get(dnn::tau_dxy_valid) = sp.scale(tau_dxy_valid, tauInputs_indices_[dnn::tau_dxy_valid]);
1346       get(dnn::tau_dxy) = sp.scale(tau_funcs.getdxy(tau, tau_index), tauInputs_indices_[dnn::tau_dxy]);
1347       get(dnn::tau_dxy_sig) =
1348           sp.scale(std::abs(tau_funcs.getdxy(tau, tau_index)) / tau_funcs.getdxyError(tau, tau_index),
1349                    tauInputs_indices_[dnn::tau_dxy_sig]);
1350     }
1351     const bool tau_ip3d_valid =
1352         isAbove(tau_funcs.getip3d(tau, tau_index), -10) && isAbove(tau_funcs.getip3dError(tau, tau_index), 0);
1353     if (tau_ip3d_valid) {
1354       get(dnn::tau_ip3d_valid) = sp.scale(tau_ip3d_valid, tauInputs_indices_[dnn::tau_ip3d_valid]);
1355       get(dnn::tau_ip3d) = sp.scale(tau_funcs.getip3d(tau, tau_index), tauInputs_indices_[dnn::tau_ip3d]);
1356       get(dnn::tau_ip3d_sig) =
1357           sp.scale(std::abs(tau_funcs.getip3d(tau, tau_index)) / tau_funcs.getip3dError(tau, tau_index),
1358                    tauInputs_indices_[dnn::tau_ip3d_sig]);
1359     }
1360     if (leadChargedHadrCand) {
1361       const bool hasTrackDetails = candFunc::getHasTrackDetails(*leadChargedHadrCand);
1362       const float tau_dz = (is_online_ && !hasTrackDetails) ? 0 : candFunc::getTauDz(*leadChargedHadrCand);
1363       get(dnn::tau_dz) = sp.scale(tau_dz, tauInputs_indices_[dnn::tau_dz]);
1364       get(dnn::tau_dz_sig_valid) =
1365           sp.scale(candFunc::getTauDZSigValid(*leadChargedHadrCand), tauInputs_indices_[dnn::tau_dz_sig_valid]);
1366       const double dzError = hasTrackDetails ? leadChargedHadrCand->dzError() : -999.;
1367       get(dnn::tau_dz_sig) = sp.scale(std::abs(tau_dz) / dzError, tauInputs_indices_[dnn::tau_dz_sig]);
1368     }
1369     get(dnn::tau_flightLength_x) =
1370         sp.scale(tau_funcs.getFlightLength(tau, tau_index).x(), tauInputs_indices_[dnn::tau_flightLength_x]);
1371     get(dnn::tau_flightLength_y) =
1372         sp.scale(tau_funcs.getFlightLength(tau, tau_index).y(), tauInputs_indices_[dnn::tau_flightLength_y]);
1373     get(dnn::tau_flightLength_z) =
1374         sp.scale(tau_funcs.getFlightLength(tau, tau_index).z(), tauInputs_indices_[dnn::tau_flightLength_z]);
1375     if (sub_version_ == 1)
1376       get(dnn::tau_flightLength_sig) = 0.55756444;  //This value is set due to a bug in the training
1377     else if ((sub_version_ == 5) || ((sub_version_ == 0) && (year_ == 20161718)))
1378       get(dnn::tau_flightLength_sig) =
1379           sp.scale(tau_funcs.getFlightLengthSig(tau, tau_index), tauInputs_indices_[dnn::tau_flightLength_sig]);
1380 
1381     get(dnn::tau_pt_weighted_deta_strip) = sp.scale(reco::tau::pt_weighted_deta_strip(tau, tau.decayMode()),
1382                                                     tauInputs_indices_[dnn::tau_pt_weighted_deta_strip]);
1383 
1384     get(dnn::tau_pt_weighted_dphi_strip) = sp.scale(reco::tau::pt_weighted_dphi_strip(tau, tau.decayMode()),
1385                                                     tauInputs_indices_[dnn::tau_pt_weighted_dphi_strip]);
1386     get(dnn::tau_pt_weighted_dr_signal) = sp.scale(reco::tau::pt_weighted_dr_signal(tau, tau.decayMode()),
1387                                                    tauInputs_indices_[dnn::tau_pt_weighted_dr_signal]);
1388     get(dnn::tau_pt_weighted_dr_iso) =
1389         sp.scale(reco::tau::pt_weighted_dr_iso(tau, tau.decayMode()), tauInputs_indices_[dnn::tau_pt_weighted_dr_iso]);
1390     get(dnn::tau_leadingTrackNormChi2) =
1391         sp.scale(tau_funcs.getLeadingTrackNormChi2(tau), tauInputs_indices_[dnn::tau_leadingTrackNormChi2]);
1392     const auto eratio = reco::tau::eratio(tau);
1393     const bool tau_e_ratio_valid = std::isnormal(eratio) && eratio > 0.f;
1394     get(dnn::tau_e_ratio_valid) = sp.scale(tau_e_ratio_valid, tauInputs_indices_[dnn::tau_e_ratio_valid]);
1395     get(dnn::tau_e_ratio) = tau_e_ratio_valid ? sp.scale(eratio, tauInputs_indices_[dnn::tau_e_ratio]) : 0.f;
1396     const double gj_angle_diff = calculateGottfriedJacksonAngleDifference(tau, tau_index, tau_funcs);
1397     const bool tau_gj_angle_diff_valid = (std::isnormal(gj_angle_diff) || gj_angle_diff == 0) && gj_angle_diff >= 0;
1398     get(dnn::tau_gj_angle_diff_valid) =
1399         sp.scale(tau_gj_angle_diff_valid, tauInputs_indices_[dnn::tau_gj_angle_diff_valid]);
1400     get(dnn::tau_gj_angle_diff) =
1401         tau_gj_angle_diff_valid ? sp.scale(gj_angle_diff, tauInputs_indices_[dnn::tau_gj_angle_diff]) : 0;
1402     get(dnn::tau_n_photons) = sp.scale(reco::tau::n_photons_total(tau), tauInputs_indices_[dnn::tau_n_photons]);
1403     get(dnn::tau_emFraction) = sp.scale(tau_funcs.getEmFraction(tau), tauInputs_indices_[dnn::tau_emFraction]);
1404 
1405     get(dnn::tau_inside_ecal_crack) =
1406         sp.scale(isInEcalCrack(tau.p4().eta()), tauInputs_indices_[dnn::tau_inside_ecal_crack]);
1407     get(dnn::leadChargedCand_etaAtEcalEntrance_minus_tau_eta) =
1408         sp.scale(tau_funcs.getEtaAtEcalEntrance(tau) - tau.p4().eta(),
1409                  tauInputs_indices_[dnn::leadChargedCand_etaAtEcalEntrance_minus_tau_eta]);
1410   }
1411 
1412   template <typename CandidateCastType, typename TauCastType, typename EgammaBlockType>
1413   void createEgammaBlockInputs(unsigned idx,
1414                                const TauCastType& tau,
1415                                const size_t tau_index,
1416                                const edm::RefToBase<reco::BaseTau> tau_ref,
1417                                const reco::Vertex& pv,
1418                                double rho,
1419                                const std::vector<pat::Electron>* electrons,
1420                                const edm::View<reco::Candidate>& pfCands,
1421                                const Cell& cell_map,
1422                                TauFunc tau_funcs,
1423                                bool is_inner,
1424                                EgammaBlockType& egammaBlockInputs) {
1425     namespace dnn = dnn_inputs_v2::EgammaBlockInputs;
1426     namespace sc = deep_tau::Scaling;
1427     namespace candFunc = candFunc;
1428     sc::FeatureT ft_global = sc::FeatureT::GridGlobal;
1429     sc::FeatureT ft_PFe = sc::FeatureT::PfCand_electron;
1430     sc::FeatureT ft_PFg = sc::FeatureT::PfCand_gamma;
1431     sc::FeatureT ft_e = sc::FeatureT::Electron;
1432 
1433     // needed to remap indices from scaling vectors to those from dnn_inputs_v2::EgammaBlockInputs
1434     int PFe_index_offset = scalingParamsMap_->at(std::make_pair(ft_global, false)).mean_.size();
1435     int e_index_offset = PFe_index_offset + scalingParamsMap_->at(std::make_pair(ft_PFe, false)).mean_.size();
1436     int PFg_index_offset = e_index_offset + scalingParamsMap_->at(std::make_pair(ft_e, false)).mean_.size();
1437 
1438     // to account for swapped order of PfCand_gamma and Electron blocks for v2p5 training w.r.t. v2p1
1439     int fill_index_offset_e = 0;
1440     int fill_index_offset_PFg = 0;
1441     if ((sub_version_ == 5) || ((sub_version_ == 0) && (year_ == 20161718))) {
1442       fill_index_offset_e =
1443           scalingParamsMap_->at(std::make_pair(ft_PFg, false)).mean_.size();  // size of PF gamma features
1444       fill_index_offset_PFg =
1445           -scalingParamsMap_->at(std::make_pair(ft_e, false)).mean_.size();  // size of Electron features
1446     }
1447 
1448     const auto& get = [&](int var_index) -> float& {
1449       if constexpr (std::is_same_v<EgammaBlockType, std::vector<float>::iterator>) {
1450         return *(egammaBlockInputs + var_index);
1451       } else {
1452         return ((tensorflow::Tensor)egammaBlockInputs).tensor<float, 4>()(idx, 0, 0, var_index);
1453       }
1454     };
1455 
1456     const bool valid_index_pf_ele = cell_map.count(CellObjectType::PfCand_electron);
1457     const bool valid_index_pf_gamma = cell_map.count(CellObjectType::PfCand_gamma);
1458     const bool valid_index_ele = cell_map.count(CellObjectType::Electron);
1459 
1460     if (!cell_map.empty()) {
1461       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_global, false));
1462       get(dnn::rho) = sp.scale(rho, dnn::rho);
1463       get(dnn::tau_pt) = sp.scale(tau.polarP4().pt(), dnn::tau_pt);
1464       get(dnn::tau_eta) = sp.scale(tau.polarP4().eta(), dnn::tau_eta);
1465       get(dnn::tau_inside_ecal_crack) = sp.scale(isInEcalCrack(tau.polarP4().eta()), dnn::tau_inside_ecal_crack);
1466     }
1467     if (valid_index_pf_ele) {
1468       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_PFe, is_inner));
1469       size_t index_pf_ele = cell_map.at(CellObjectType::PfCand_electron);
1470       const auto& ele_cand = dynamic_cast<const CandidateCastType&>(pfCands.at(index_pf_ele));
1471 
1472       get(dnn::pfCand_ele_valid) = sp.scale(valid_index_pf_ele, dnn::pfCand_ele_valid - PFe_index_offset);
1473       get(dnn::pfCand_ele_rel_pt) =
1474           sp.scale(ele_cand.polarP4().pt() / tau.polarP4().pt(), dnn::pfCand_ele_rel_pt - PFe_index_offset);
1475       get(dnn::pfCand_ele_deta) =
1476           sp.scale(ele_cand.polarP4().eta() - tau.polarP4().eta(), dnn::pfCand_ele_deta - PFe_index_offset);
1477       get(dnn::pfCand_ele_dphi) =
1478           sp.scale(dPhi(tau.polarP4(), ele_cand.polarP4()), dnn::pfCand_ele_dphi - PFe_index_offset);
1479       get(dnn::pfCand_ele_pvAssociationQuality) = sp.scale<int>(
1480           candFunc::getPvAssocationQuality(ele_cand), dnn::pfCand_ele_pvAssociationQuality - PFe_index_offset);
1481       get(dnn::pfCand_ele_puppiWeight) = is_inner ? sp.scale(candFunc::getPuppiWeight(ele_cand, 0.9906834f),
1482                                                              dnn::pfCand_ele_puppiWeight - PFe_index_offset)
1483                                                   : sp.scale(candFunc::getPuppiWeight(ele_cand, 0.9669586f),
1484                                                              dnn::pfCand_ele_puppiWeight - PFe_index_offset);
1485       get(dnn::pfCand_ele_charge) = sp.scale(ele_cand.charge(), dnn::pfCand_ele_charge - PFe_index_offset);
1486       get(dnn::pfCand_ele_lostInnerHits) =
1487           sp.scale<int>(candFunc::getLostInnerHits(ele_cand, 0), dnn::pfCand_ele_lostInnerHits - PFe_index_offset);
1488       get(dnn::pfCand_ele_numberOfPixelHits) =
1489           sp.scale(candFunc::getNumberOfPixelHits(ele_cand, 0), dnn::pfCand_ele_numberOfPixelHits - PFe_index_offset);
1490       get(dnn::pfCand_ele_vertex_dx) =
1491           sp.scale(ele_cand.vertex().x() - pv.position().x(), dnn::pfCand_ele_vertex_dx - PFe_index_offset);
1492       get(dnn::pfCand_ele_vertex_dy) =
1493           sp.scale(ele_cand.vertex().y() - pv.position().y(), dnn::pfCand_ele_vertex_dy - PFe_index_offset);
1494       get(dnn::pfCand_ele_vertex_dz) =
1495           sp.scale(ele_cand.vertex().z() - pv.position().z(), dnn::pfCand_ele_vertex_dz - PFe_index_offset);
1496       get(dnn::pfCand_ele_vertex_dx_tauFL) =
1497           sp.scale(ele_cand.vertex().x() - pv.position().x() - tau_funcs.getFlightLength(tau, tau_index).x(),
1498                    dnn::pfCand_ele_vertex_dx_tauFL - PFe_index_offset);
1499       get(dnn::pfCand_ele_vertex_dy_tauFL) =
1500           sp.scale(ele_cand.vertex().y() - pv.position().y() - tau_funcs.getFlightLength(tau, tau_index).y(),
1501                    dnn::pfCand_ele_vertex_dy_tauFL - PFe_index_offset);
1502       get(dnn::pfCand_ele_vertex_dz_tauFL) =
1503           sp.scale(ele_cand.vertex().z() - pv.position().z() - tau_funcs.getFlightLength(tau, tau_index).z(),
1504                    dnn::pfCand_ele_vertex_dz_tauFL - PFe_index_offset);
1505 
1506       const bool hasTrackDetails = candFunc::getHasTrackDetails(ele_cand);
1507       if (hasTrackDetails) {
1508         get(dnn::pfCand_ele_hasTrackDetails) =
1509             sp.scale(hasTrackDetails, dnn::pfCand_ele_hasTrackDetails - PFe_index_offset);
1510         get(dnn::pfCand_ele_dxy) = sp.scale(candFunc::getTauDxy(ele_cand), dnn::pfCand_ele_dxy - PFe_index_offset);
1511         get(dnn::pfCand_ele_dxy_sig) = sp.scale(std::abs(candFunc::getTauDxy(ele_cand)) / ele_cand.dxyError(),
1512                                                 dnn::pfCand_ele_dxy_sig - PFe_index_offset);
1513         get(dnn::pfCand_ele_dz) = sp.scale(candFunc::getTauDz(ele_cand), dnn::pfCand_ele_dz - PFe_index_offset);
1514         get(dnn::pfCand_ele_dz_sig) = sp.scale(std::abs(candFunc::getTauDz(ele_cand)) / ele_cand.dzError(),
1515                                                dnn::pfCand_ele_dz_sig - PFe_index_offset);
1516         get(dnn::pfCand_ele_track_chi2_ndof) =
1517             candFunc::getPseudoTrack(ele_cand).ndof() > 0
1518                 ? sp.scale(candFunc::getPseudoTrack(ele_cand).chi2() / candFunc::getPseudoTrack(ele_cand).ndof(),
1519                            dnn::pfCand_ele_track_chi2_ndof - PFe_index_offset)
1520                 : 0;
1521         get(dnn::pfCand_ele_track_ndof) =
1522             candFunc::getPseudoTrack(ele_cand).ndof() > 0
1523                 ? sp.scale(candFunc::getPseudoTrack(ele_cand).ndof(), dnn::pfCand_ele_track_ndof - PFe_index_offset)
1524                 : 0;
1525       }
1526     }
1527     if (valid_index_pf_gamma) {
1528       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_PFg, is_inner));
1529       size_t index_pf_gamma = cell_map.at(CellObjectType::PfCand_gamma);
1530       const auto& gamma_cand = dynamic_cast<const CandidateCastType&>(pfCands.at(index_pf_gamma));
1531 
1532       get(dnn::pfCand_gamma_valid + fill_index_offset_PFg) =
1533           sp.scale(valid_index_pf_gamma, dnn::pfCand_gamma_valid - PFg_index_offset);
1534       get(dnn::pfCand_gamma_rel_pt + fill_index_offset_PFg) =
1535           sp.scale(gamma_cand.polarP4().pt() / tau.polarP4().pt(), dnn::pfCand_gamma_rel_pt - PFg_index_offset);
1536       get(dnn::pfCand_gamma_deta + fill_index_offset_PFg) =
1537           sp.scale(gamma_cand.polarP4().eta() - tau.polarP4().eta(), dnn::pfCand_gamma_deta - PFg_index_offset);
1538       get(dnn::pfCand_gamma_dphi + fill_index_offset_PFg) =
1539           sp.scale(dPhi(tau.polarP4(), gamma_cand.polarP4()), dnn::pfCand_gamma_dphi - PFg_index_offset);
1540       get(dnn::pfCand_gamma_pvAssociationQuality + fill_index_offset_PFg) = sp.scale<int>(
1541           candFunc::getPvAssocationQuality(gamma_cand), dnn::pfCand_gamma_pvAssociationQuality - PFg_index_offset);
1542       get(dnn::pfCand_gamma_fromPV + fill_index_offset_PFg) =
1543           sp.scale<int>(candFunc::getFromPV(gamma_cand), dnn::pfCand_gamma_fromPV - PFg_index_offset);
1544       get(dnn::pfCand_gamma_puppiWeight + fill_index_offset_PFg) =
1545           is_inner ? sp.scale(candFunc::getPuppiWeight(gamma_cand, 0.9084110f),
1546                               dnn::pfCand_gamma_puppiWeight - PFg_index_offset)
1547                    : sp.scale(candFunc::getPuppiWeight(gamma_cand, 0.4211567f),
1548                               dnn::pfCand_gamma_puppiWeight - PFg_index_offset);
1549       get(dnn::pfCand_gamma_puppiWeightNoLep + fill_index_offset_PFg) =
1550           is_inner ? sp.scale(candFunc::getPuppiWeightNoLep(gamma_cand, 0.8857716f),
1551                               dnn::pfCand_gamma_puppiWeightNoLep - PFg_index_offset)
1552                    : sp.scale(candFunc::getPuppiWeightNoLep(gamma_cand, 0.3822604f),
1553                               dnn::pfCand_gamma_puppiWeightNoLep - PFg_index_offset);
1554       get(dnn::pfCand_gamma_lostInnerHits + fill_index_offset_PFg) =
1555           sp.scale<int>(candFunc::getLostInnerHits(gamma_cand, 0), dnn::pfCand_gamma_lostInnerHits - PFg_index_offset);
1556       get(dnn::pfCand_gamma_numberOfPixelHits + fill_index_offset_PFg) = sp.scale(
1557           candFunc::getNumberOfPixelHits(gamma_cand, 0), dnn::pfCand_gamma_numberOfPixelHits - PFg_index_offset);
1558       get(dnn::pfCand_gamma_vertex_dx + fill_index_offset_PFg) =
1559           sp.scale(gamma_cand.vertex().x() - pv.position().x(), dnn::pfCand_gamma_vertex_dx - PFg_index_offset);
1560       get(dnn::pfCand_gamma_vertex_dy + fill_index_offset_PFg) =
1561           sp.scale(gamma_cand.vertex().y() - pv.position().y(), dnn::pfCand_gamma_vertex_dy - PFg_index_offset);
1562       get(dnn::pfCand_gamma_vertex_dz + fill_index_offset_PFg) =
1563           sp.scale(gamma_cand.vertex().z() - pv.position().z(), dnn::pfCand_gamma_vertex_dz - PFg_index_offset);
1564       get(dnn::pfCand_gamma_vertex_dx_tauFL + fill_index_offset_PFg) =
1565           sp.scale(gamma_cand.vertex().x() - pv.position().x() - tau_funcs.getFlightLength(tau, tau_index).x(),
1566                    dnn::pfCand_gamma_vertex_dx_tauFL - PFg_index_offset);
1567       get(dnn::pfCand_gamma_vertex_dy_tauFL + fill_index_offset_PFg) =
1568           sp.scale(gamma_cand.vertex().y() - pv.position().y() - tau_funcs.getFlightLength(tau, tau_index).y(),
1569                    dnn::pfCand_gamma_vertex_dy_tauFL - PFg_index_offset);
1570       get(dnn::pfCand_gamma_vertex_dz_tauFL + fill_index_offset_PFg) =
1571           sp.scale(gamma_cand.vertex().z() - pv.position().z() - tau_funcs.getFlightLength(tau, tau_index).z(),
1572                    dnn::pfCand_gamma_vertex_dz_tauFL - PFg_index_offset);
1573       const bool hasTrackDetails = candFunc::getHasTrackDetails(gamma_cand);
1574       if (hasTrackDetails) {
1575         get(dnn::pfCand_gamma_hasTrackDetails + fill_index_offset_PFg) =
1576             sp.scale(hasTrackDetails, dnn::pfCand_gamma_hasTrackDetails - PFg_index_offset);
1577         get(dnn::pfCand_gamma_dxy + fill_index_offset_PFg) =
1578             sp.scale(candFunc::getTauDxy(gamma_cand), dnn::pfCand_gamma_dxy - PFg_index_offset);
1579         get(dnn::pfCand_gamma_dxy_sig + fill_index_offset_PFg) =
1580             sp.scale(std::abs(candFunc::getTauDxy(gamma_cand)) / gamma_cand.dxyError(),
1581                      dnn::pfCand_gamma_dxy_sig - PFg_index_offset);
1582         get(dnn::pfCand_gamma_dz + fill_index_offset_PFg) =
1583             sp.scale(candFunc::getTauDz(gamma_cand), dnn::pfCand_gamma_dz - PFg_index_offset);
1584         get(dnn::pfCand_gamma_dz_sig + fill_index_offset_PFg) =
1585             sp.scale(std::abs(candFunc::getTauDz(gamma_cand)) / gamma_cand.dzError(),
1586                      dnn::pfCand_gamma_dz_sig - PFg_index_offset);
1587         get(dnn::pfCand_gamma_track_chi2_ndof + fill_index_offset_PFg) =
1588             candFunc::getPseudoTrack(gamma_cand).ndof() > 0
1589                 ? sp.scale(candFunc::getPseudoTrack(gamma_cand).chi2() / candFunc::getPseudoTrack(gamma_cand).ndof(),
1590                            dnn::pfCand_gamma_track_chi2_ndof - PFg_index_offset)
1591                 : 0;
1592         get(dnn::pfCand_gamma_track_ndof + fill_index_offset_PFg) =
1593             candFunc::getPseudoTrack(gamma_cand).ndof() > 0
1594                 ? sp.scale(candFunc::getPseudoTrack(gamma_cand).ndof(), dnn::pfCand_gamma_track_ndof - PFg_index_offset)
1595                 : 0;
1596       }
1597     }
1598     if (valid_index_ele) {
1599       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_e, is_inner));
1600       size_t index_ele = cell_map.at(CellObjectType::Electron);
1601       const auto& ele = electrons->at(index_ele);
1602 
1603       get(dnn::ele_valid + fill_index_offset_e) = sp.scale(valid_index_ele, dnn::ele_valid - e_index_offset);
1604       get(dnn::ele_rel_pt + fill_index_offset_e) =
1605           sp.scale(ele.polarP4().pt() / tau.polarP4().pt(), dnn::ele_rel_pt - e_index_offset);
1606       get(dnn::ele_deta + fill_index_offset_e) =
1607           sp.scale(ele.polarP4().eta() - tau.polarP4().eta(), dnn::ele_deta - e_index_offset);
1608       get(dnn::ele_dphi + fill_index_offset_e) =
1609           sp.scale(dPhi(tau.polarP4(), ele.polarP4()), dnn::ele_dphi - e_index_offset);
1610 
1611       float cc_ele_energy, cc_gamma_energy;
1612       int cc_n_gamma;
1613       const bool cc_valid = calculateElectronClusterVarsV2(ele, cc_ele_energy, cc_gamma_energy, cc_n_gamma);
1614       if (cc_valid) {
1615         get(dnn::ele_cc_valid + fill_index_offset_e) = sp.scale(cc_valid, dnn::ele_cc_valid - e_index_offset);
1616         get(dnn::ele_cc_ele_rel_energy + fill_index_offset_e) =
1617             sp.scale(cc_ele_energy / ele.polarP4().pt(), dnn::ele_cc_ele_rel_energy - e_index_offset);
1618         get(dnn::ele_cc_gamma_rel_energy + fill_index_offset_e) =
1619             sp.scale(cc_gamma_energy / cc_ele_energy, dnn::ele_cc_gamma_rel_energy - e_index_offset);
1620         get(dnn::ele_cc_n_gamma + fill_index_offset_e) = sp.scale(cc_n_gamma, dnn::ele_cc_n_gamma - e_index_offset);
1621       }
1622       get(dnn::ele_rel_trackMomentumAtVtx + fill_index_offset_e) =
1623           sp.scale(ele.trackMomentumAtVtx().R() / ele.polarP4().pt(), dnn::ele_rel_trackMomentumAtVtx - e_index_offset);
1624       get(dnn::ele_rel_trackMomentumAtCalo + fill_index_offset_e) = sp.scale(
1625           ele.trackMomentumAtCalo().R() / ele.polarP4().pt(), dnn::ele_rel_trackMomentumAtCalo - e_index_offset);
1626       get(dnn::ele_rel_trackMomentumOut + fill_index_offset_e) =
1627           sp.scale(ele.trackMomentumOut().R() / ele.polarP4().pt(), dnn::ele_rel_trackMomentumOut - e_index_offset);
1628       get(dnn::ele_rel_trackMomentumAtEleClus + fill_index_offset_e) = sp.scale(
1629           ele.trackMomentumAtEleClus().R() / ele.polarP4().pt(), dnn::ele_rel_trackMomentumAtEleClus - e_index_offset);
1630       get(dnn::ele_rel_trackMomentumAtVtxWithConstraint + fill_index_offset_e) =
1631           sp.scale(ele.trackMomentumAtVtxWithConstraint().R() / ele.polarP4().pt(),
1632                    dnn::ele_rel_trackMomentumAtVtxWithConstraint - e_index_offset);
1633       get(dnn::ele_rel_ecalEnergy + fill_index_offset_e) =
1634           sp.scale(ele.ecalEnergy() / ele.polarP4().pt(), dnn::ele_rel_ecalEnergy - e_index_offset);
1635       get(dnn::ele_ecalEnergy_sig + fill_index_offset_e) =
1636           sp.scale(ele.ecalEnergy() / ele.ecalEnergyError(), dnn::ele_ecalEnergy_sig - e_index_offset);
1637       get(dnn::ele_eSuperClusterOverP + fill_index_offset_e) =
1638           sp.scale(ele.eSuperClusterOverP(), dnn::ele_eSuperClusterOverP - e_index_offset);
1639       get(dnn::ele_eSeedClusterOverP + fill_index_offset_e) =
1640           sp.scale(ele.eSeedClusterOverP(), dnn::ele_eSeedClusterOverP - e_index_offset);
1641       get(dnn::ele_eSeedClusterOverPout + fill_index_offset_e) =
1642           sp.scale(ele.eSeedClusterOverPout(), dnn::ele_eSeedClusterOverPout - e_index_offset);
1643       get(dnn::ele_eEleClusterOverPout + fill_index_offset_e) =
1644           sp.scale(ele.eEleClusterOverPout(), dnn::ele_eEleClusterOverPout - e_index_offset);
1645       get(dnn::ele_deltaEtaSuperClusterTrackAtVtx + fill_index_offset_e) =
1646           sp.scale(ele.deltaEtaSuperClusterTrackAtVtx(), dnn::ele_deltaEtaSuperClusterTrackAtVtx - e_index_offset);
1647       get(dnn::ele_deltaEtaSeedClusterTrackAtCalo + fill_index_offset_e) =
1648           sp.scale(ele.deltaEtaSeedClusterTrackAtCalo(), dnn::ele_deltaEtaSeedClusterTrackAtCalo - e_index_offset);
1649       get(dnn::ele_deltaEtaEleClusterTrackAtCalo + fill_index_offset_e) =
1650           sp.scale(ele.deltaEtaEleClusterTrackAtCalo(), dnn::ele_deltaEtaEleClusterTrackAtCalo - e_index_offset);
1651       get(dnn::ele_deltaPhiEleClusterTrackAtCalo + fill_index_offset_e) =
1652           sp.scale(ele.deltaPhiEleClusterTrackAtCalo(), dnn::ele_deltaPhiEleClusterTrackAtCalo - e_index_offset);
1653       get(dnn::ele_deltaPhiSuperClusterTrackAtVtx + fill_index_offset_e) =
1654           sp.scale(ele.deltaPhiSuperClusterTrackAtVtx(), dnn::ele_deltaPhiSuperClusterTrackAtVtx - e_index_offset);
1655       get(dnn::ele_deltaPhiSeedClusterTrackAtCalo + fill_index_offset_e) =
1656           sp.scale(ele.deltaPhiSeedClusterTrackAtCalo(), dnn::ele_deltaPhiSeedClusterTrackAtCalo - e_index_offset);
1657       const bool mva_valid =
1658           (ele.mvaInput().earlyBrem > -2) ||
1659           (year_ !=
1660            2026);  // Known issue that input can be invalid in Phase2 samples (early/lateBrem==-2, hadEnergy==0, sigmaEtaEta/deltaEta==3.40282e+38). Unknown if also in Run2/3, so don't change there
1661       if (mva_valid) {
1662         get(dnn::ele_mvaInput_earlyBrem + fill_index_offset_e) =
1663             sp.scale(ele.mvaInput().earlyBrem, dnn::ele_mvaInput_earlyBrem - e_index_offset);
1664         get(dnn::ele_mvaInput_lateBrem + fill_index_offset_e) =
1665             sp.scale(ele.mvaInput().lateBrem, dnn::ele_mvaInput_lateBrem - e_index_offset);
1666         get(dnn::ele_mvaInput_sigmaEtaEta + fill_index_offset_e) =
1667             sp.scale(ele.mvaInput().sigmaEtaEta, dnn::ele_mvaInput_sigmaEtaEta - e_index_offset);
1668         get(dnn::ele_mvaInput_hadEnergy + fill_index_offset_e) =
1669             sp.scale(ele.mvaInput().hadEnergy, dnn::ele_mvaInput_hadEnergy - e_index_offset);
1670         get(dnn::ele_mvaInput_deltaEta + fill_index_offset_e) =
1671             sp.scale(ele.mvaInput().deltaEta, dnn::ele_mvaInput_deltaEta - e_index_offset);
1672       }
1673       const auto& gsfTrack = ele.gsfTrack();
1674       if (gsfTrack.isNonnull()) {
1675         get(dnn::ele_gsfTrack_normalizedChi2 + fill_index_offset_e) =
1676             sp.scale(gsfTrack->normalizedChi2(), dnn::ele_gsfTrack_normalizedChi2 - e_index_offset);
1677         get(dnn::ele_gsfTrack_numberOfValidHits + fill_index_offset_e) =
1678             sp.scale(gsfTrack->numberOfValidHits(), dnn::ele_gsfTrack_numberOfValidHits - e_index_offset);
1679         get(dnn::ele_rel_gsfTrack_pt + fill_index_offset_e) =
1680             sp.scale(gsfTrack->pt() / ele.polarP4().pt(), dnn::ele_rel_gsfTrack_pt - e_index_offset);
1681         get(dnn::ele_gsfTrack_pt_sig + fill_index_offset_e) =
1682             sp.scale(gsfTrack->pt() / gsfTrack->ptError(), dnn::ele_gsfTrack_pt_sig - e_index_offset);
1683       }
1684       const auto& closestCtfTrack = ele.closestCtfTrackRef();
1685       const bool has_closestCtfTrack = closestCtfTrack.isNonnull();
1686       if (has_closestCtfTrack) {
1687         get(dnn::ele_has_closestCtfTrack + fill_index_offset_e) =
1688             sp.scale(has_closestCtfTrack, dnn::ele_has_closestCtfTrack - e_index_offset);
1689         get(dnn::ele_closestCtfTrack_normalizedChi2 + fill_index_offset_e) =
1690             sp.scale(closestCtfTrack->normalizedChi2(), dnn::ele_closestCtfTrack_normalizedChi2 - e_index_offset);
1691         get(dnn::ele_closestCtfTrack_numberOfValidHits + fill_index_offset_e) =
1692             sp.scale(closestCtfTrack->numberOfValidHits(), dnn::ele_closestCtfTrack_numberOfValidHits - e_index_offset);
1693       }
1694     }
1695   }
1696 
1697   template <typename CandidateCastType, typename TauCastType, typename MuonBlockType>
1698   void createMuonBlockInputs(unsigned idx,
1699                              const TauCastType& tau,
1700                              const size_t tau_index,
1701                              const edm::RefToBase<reco::BaseTau> tau_ref,
1702                              const reco::Vertex& pv,
1703                              double rho,
1704                              const std::vector<pat::Muon>* muons,
1705                              const edm::View<reco::Candidate>& pfCands,
1706                              const Cell& cell_map,
1707                              TauFunc tau_funcs,
1708                              bool is_inner,
1709                              MuonBlockType& muonBlockInputs) {
1710     namespace dnn = dnn_inputs_v2::MuonBlockInputs;
1711     namespace sc = deep_tau::Scaling;
1712     namespace candFunc = candFunc;
1713     using MuonHitMatchV2 = MuonHitMatchV2;
1714     sc::FeatureT ft_global = sc::FeatureT::GridGlobal;
1715     sc::FeatureT ft_PFmu = sc::FeatureT::PfCand_muon;
1716     sc::FeatureT ft_mu = sc::FeatureT::Muon;
1717 
1718     // needed to remap indices from scaling vectors to those from dnn_inputs_v2::MuonBlockInputs
1719     int PFmu_index_offset = scalingParamsMap_->at(std::make_pair(ft_global, false)).mean_.size();
1720     int mu_index_offset = PFmu_index_offset + scalingParamsMap_->at(std::make_pair(ft_PFmu, false)).mean_.size();
1721 
1722     const auto& get = [&](int var_index) -> float& {
1723       if constexpr (std::is_same_v<MuonBlockType, std::vector<float>::iterator>) {
1724         return *(muonBlockInputs + var_index);
1725       } else {
1726         return ((tensorflow::Tensor)muonBlockInputs).tensor<float, 4>()(idx, 0, 0, var_index);
1727       }
1728     };
1729 
1730     const bool valid_index_pf_muon = cell_map.count(CellObjectType::PfCand_muon);
1731     const bool valid_index_muon = cell_map.count(CellObjectType::Muon);
1732 
1733     if (!cell_map.empty()) {
1734       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_global, false));
1735       get(dnn::rho) = sp.scale(rho, dnn::rho);
1736       get(dnn::tau_pt) = sp.scale(tau.polarP4().pt(), dnn::tau_pt);
1737       get(dnn::tau_eta) = sp.scale(tau.polarP4().eta(), dnn::tau_eta);
1738       get(dnn::tau_inside_ecal_crack) = sp.scale(isInEcalCrack(tau.polarP4().eta()), dnn::tau_inside_ecal_crack);
1739     }
1740     if (valid_index_pf_muon) {
1741       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_PFmu, is_inner));
1742       size_t index_pf_muon = cell_map.at(CellObjectType::PfCand_muon);
1743       const auto& muon_cand = dynamic_cast<const CandidateCastType&>(pfCands.at(index_pf_muon));
1744 
1745       get(dnn::pfCand_muon_valid) = sp.scale(valid_index_pf_muon, dnn::pfCand_muon_valid - PFmu_index_offset);
1746       get(dnn::pfCand_muon_rel_pt) =
1747           sp.scale(muon_cand.polarP4().pt() / tau.polarP4().pt(), dnn::pfCand_muon_rel_pt - PFmu_index_offset);
1748       get(dnn::pfCand_muon_deta) =
1749           sp.scale(muon_cand.polarP4().eta() - tau.polarP4().eta(), dnn::pfCand_muon_deta - PFmu_index_offset);
1750       get(dnn::pfCand_muon_dphi) =
1751           sp.scale(dPhi(tau.polarP4(), muon_cand.polarP4()), dnn::pfCand_muon_dphi - PFmu_index_offset);
1752       get(dnn::pfCand_muon_pvAssociationQuality) = sp.scale<int>(
1753           candFunc::getPvAssocationQuality(muon_cand), dnn::pfCand_muon_pvAssociationQuality - PFmu_index_offset);
1754       get(dnn::pfCand_muon_fromPV) =
1755           sp.scale<int>(candFunc::getFromPV(muon_cand), dnn::pfCand_muon_fromPV - PFmu_index_offset);
1756       get(dnn::pfCand_muon_puppiWeight) = is_inner ? sp.scale(candFunc::getPuppiWeight(muon_cand, 0.9786588f),
1757                                                               dnn::pfCand_muon_puppiWeight - PFmu_index_offset)
1758                                                    : sp.scale(candFunc::getPuppiWeight(muon_cand, 0.8132477f),
1759                                                               dnn::pfCand_muon_puppiWeight - PFmu_index_offset);
1760       get(dnn::pfCand_muon_charge) = sp.scale(muon_cand.charge(), dnn::pfCand_muon_charge - PFmu_index_offset);
1761       get(dnn::pfCand_muon_lostInnerHits) =
1762           sp.scale<int>(candFunc::getLostInnerHits(muon_cand, 0), dnn::pfCand_muon_lostInnerHits - PFmu_index_offset);
1763       get(dnn::pfCand_muon_numberOfPixelHits) = sp.scale(candFunc::getNumberOfPixelHits(muon_cand, 0),
1764                                                          dnn::pfCand_muon_numberOfPixelHits - PFmu_index_offset);
1765       get(dnn::pfCand_muon_vertex_dx) =
1766           sp.scale(muon_cand.vertex().x() - pv.position().x(), dnn::pfCand_muon_vertex_dx - PFmu_index_offset);
1767       get(dnn::pfCand_muon_vertex_dy) =
1768           sp.scale(muon_cand.vertex().y() - pv.position().y(), dnn::pfCand_muon_vertex_dy - PFmu_index_offset);
1769       get(dnn::pfCand_muon_vertex_dz) =
1770           sp.scale(muon_cand.vertex().z() - pv.position().z(), dnn::pfCand_muon_vertex_dz - PFmu_index_offset);
1771       get(dnn::pfCand_muon_vertex_dx_tauFL) =
1772           sp.scale(muon_cand.vertex().x() - pv.position().x() - tau_funcs.getFlightLength(tau, tau_index).x(),
1773                    dnn::pfCand_muon_vertex_dx_tauFL - PFmu_index_offset);
1774       get(dnn::pfCand_muon_vertex_dy_tauFL) =
1775           sp.scale(muon_cand.vertex().y() - pv.position().y() - tau_funcs.getFlightLength(tau, tau_index).y(),
1776                    dnn::pfCand_muon_vertex_dy_tauFL - PFmu_index_offset);
1777       get(dnn::pfCand_muon_vertex_dz_tauFL) =
1778           sp.scale(muon_cand.vertex().z() - pv.position().z() - tau_funcs.getFlightLength(tau, tau_index).z(),
1779                    dnn::pfCand_muon_vertex_dz_tauFL - PFmu_index_offset);
1780 
1781       const bool hasTrackDetails = candFunc::getHasTrackDetails(muon_cand);
1782       if (hasTrackDetails) {
1783         get(dnn::pfCand_muon_hasTrackDetails) =
1784             sp.scale(hasTrackDetails, dnn::pfCand_muon_hasTrackDetails - PFmu_index_offset);
1785         get(dnn::pfCand_muon_dxy) = sp.scale(candFunc::getTauDxy(muon_cand), dnn::pfCand_muon_dxy - PFmu_index_offset);
1786         get(dnn::pfCand_muon_dxy_sig) = sp.scale(std::abs(candFunc::getTauDxy(muon_cand)) / muon_cand.dxyError(),
1787                                                  dnn::pfCand_muon_dxy_sig - PFmu_index_offset);
1788         get(dnn::pfCand_muon_dz) = sp.scale(candFunc::getTauDz(muon_cand), dnn::pfCand_muon_dz - PFmu_index_offset);
1789         get(dnn::pfCand_muon_dz_sig) = sp.scale(std::abs(candFunc::getTauDz(muon_cand)) / muon_cand.dzError(),
1790                                                 dnn::pfCand_muon_dz_sig - PFmu_index_offset);
1791         get(dnn::pfCand_muon_track_chi2_ndof) =
1792             candFunc::getPseudoTrack(muon_cand).ndof() > 0
1793                 ? sp.scale(candFunc::getPseudoTrack(muon_cand).chi2() / candFunc::getPseudoTrack(muon_cand).ndof(),
1794                            dnn::pfCand_muon_track_chi2_ndof - PFmu_index_offset)
1795                 : 0;
1796         get(dnn::pfCand_muon_track_ndof) =
1797             candFunc::getPseudoTrack(muon_cand).ndof() > 0
1798                 ? sp.scale(candFunc::getPseudoTrack(muon_cand).ndof(), dnn::pfCand_muon_track_ndof - PFmu_index_offset)
1799                 : 0;
1800       }
1801     }
1802     if (valid_index_muon) {
1803       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_mu, is_inner));
1804       size_t index_muon = cell_map.at(CellObjectType::Muon);
1805       const auto& muon = muons->at(index_muon);
1806 
1807       get(dnn::muon_valid) = sp.scale(valid_index_muon, dnn::muon_valid - mu_index_offset);
1808       get(dnn::muon_rel_pt) = sp.scale(muon.polarP4().pt() / tau.polarP4().pt(), dnn::muon_rel_pt - mu_index_offset);
1809       get(dnn::muon_deta) = sp.scale(muon.polarP4().eta() - tau.polarP4().eta(), dnn::muon_deta - mu_index_offset);
1810       get(dnn::muon_dphi) = sp.scale(dPhi(tau.polarP4(), muon.polarP4()), dnn::muon_dphi - mu_index_offset);
1811       get(dnn::muon_dxy) = sp.scale(muon.dB(pat::Muon::PV2D), dnn::muon_dxy - mu_index_offset);
1812       get(dnn::muon_dxy_sig) =
1813           sp.scale(std::abs(muon.dB(pat::Muon::PV2D)) / muon.edB(pat::Muon::PV2D), dnn::muon_dxy_sig - mu_index_offset);
1814 
1815       const bool normalizedChi2_valid = muon.globalTrack().isNonnull() && muon.normChi2() >= 0;
1816       if (normalizedChi2_valid) {
1817         get(dnn::muon_normalizedChi2_valid) =
1818             sp.scale(normalizedChi2_valid, dnn::muon_normalizedChi2_valid - mu_index_offset);
1819         get(dnn::muon_normalizedChi2) = sp.scale(muon.normChi2(), dnn::muon_normalizedChi2 - mu_index_offset);
1820         if (muon.innerTrack().isNonnull())
1821           get(dnn::muon_numberOfValidHits) =
1822               sp.scale(muon.numberOfValidHits(), dnn::muon_numberOfValidHits - mu_index_offset);
1823       }
1824       get(dnn::muon_segmentCompatibility) =
1825           sp.scale(muon.segmentCompatibility(), dnn::muon_segmentCompatibility - mu_index_offset);
1826       get(dnn::muon_caloCompatibility) =
1827           sp.scale(muon.caloCompatibility(), dnn::muon_caloCompatibility - mu_index_offset);
1828 
1829       const bool pfEcalEnergy_valid = muon.pfEcalEnergy() >= 0;
1830       if (pfEcalEnergy_valid) {
1831         get(dnn::muon_pfEcalEnergy_valid) =
1832             sp.scale(pfEcalEnergy_valid, dnn::muon_pfEcalEnergy_valid - mu_index_offset);
1833         get(dnn::muon_rel_pfEcalEnergy) =
1834             sp.scale(muon.pfEcalEnergy() / muon.polarP4().pt(), dnn::muon_rel_pfEcalEnergy - mu_index_offset);
1835       }
1836 
1837       MuonHitMatchV2 hit_match(muon);
1838       static const std::map<int, std::pair<int, int>> muonMatchHitVars = {
1839           {MuonSubdetId::DT, {dnn::muon_n_matches_DT_1, dnn::muon_n_hits_DT_1}},
1840           {MuonSubdetId::CSC, {dnn::muon_n_matches_CSC_1, dnn::muon_n_hits_CSC_1}},
1841           {MuonSubdetId::RPC, {dnn::muon_n_matches_RPC_1, dnn::muon_n_hits_RPC_1}}};
1842 
1843       for (int subdet : hit_match.MuonHitMatchV2::consideredSubdets()) {
1844         const auto& matchHitVar = muonMatchHitVars.at(subdet);
1845         for (int station = MuonHitMatchV2::first_station_id; station <= MuonHitMatchV2::last_station_id; ++station) {
1846           const unsigned n_matches = hit_match.nMatches(subdet, station);
1847           const unsigned n_hits = hit_match.nHits(subdet, station);
1848           get(matchHitVar.first + station - 1) = sp.scale(n_matches, matchHitVar.first + station - 1 - mu_index_offset);
1849           get(matchHitVar.second + station - 1) = sp.scale(n_hits, matchHitVar.second + station - 1 - mu_index_offset);
1850         }
1851       }
1852     }
1853   }
1854 
1855   template <typename CandidateCastType, typename TauCastType, typename HadronBlockType>
1856   void createHadronsBlockInputs(unsigned idx,
1857                                 const TauCastType& tau,
1858                                 const size_t tau_index,
1859                                 const edm::RefToBase<reco::BaseTau> tau_ref,
1860                                 const reco::Vertex& pv,
1861                                 double rho,
1862                                 const edm::View<reco::Candidate>& pfCands,
1863                                 const Cell& cell_map,
1864                                 TauFunc tau_funcs,
1865                                 bool is_inner,
1866                                 HadronBlockType& hadronBlockInputs) {
1867     namespace dnn = dnn_inputs_v2::HadronBlockInputs;
1868     namespace sc = deep_tau::Scaling;
1869     namespace candFunc = candFunc;
1870     sc::FeatureT ft_global = sc::FeatureT::GridGlobal;
1871     sc::FeatureT ft_PFchH = sc::FeatureT::PfCand_chHad;
1872     sc::FeatureT ft_PFnH = sc::FeatureT::PfCand_nHad;
1873 
1874     // needed to remap indices from scaling vectors to those from dnn_inputs_v2::HadronBlockInputs
1875     int PFchH_index_offset = scalingParamsMap_->at(std::make_pair(ft_global, false)).mean_.size();
1876     int PFnH_index_offset = PFchH_index_offset + scalingParamsMap_->at(std::make_pair(ft_PFchH, false)).mean_.size();
1877 
1878     const auto& get = [&](int var_index) -> float& {
1879       if constexpr (std::is_same_v<HadronBlockType, std::vector<float>::iterator>) {
1880         return *(hadronBlockInputs + var_index);
1881       } else {
1882         return ((tensorflow::Tensor)hadronBlockInputs).tensor<float, 4>()(idx, 0, 0, var_index);
1883       }
1884     };
1885 
1886     const bool valid_chH = cell_map.count(CellObjectType::PfCand_chargedHadron);
1887     const bool valid_nH = cell_map.count(CellObjectType::PfCand_neutralHadron);
1888 
1889     if (!cell_map.empty()) {
1890       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_global, false));
1891       get(dnn::rho) = sp.scale(rho, dnn::rho);
1892       get(dnn::tau_pt) = sp.scale(tau.polarP4().pt(), dnn::tau_pt);
1893       get(dnn::tau_eta) = sp.scale(tau.polarP4().eta(), dnn::tau_eta);
1894       get(dnn::tau_inside_ecal_crack) = sp.scale(isInEcalCrack(tau.polarP4().eta()), dnn::tau_inside_ecal_crack);
1895     }
1896     if (valid_chH) {
1897       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_PFchH, is_inner));
1898       size_t index_chH = cell_map.at(CellObjectType::PfCand_chargedHadron);
1899       const auto& chH_cand = dynamic_cast<const CandidateCastType&>(pfCands.at(index_chH));
1900 
1901       get(dnn::pfCand_chHad_valid) = sp.scale(valid_chH, dnn::pfCand_chHad_valid - PFchH_index_offset);
1902       get(dnn::pfCand_chHad_rel_pt) =
1903           sp.scale(chH_cand.polarP4().pt() / tau.polarP4().pt(), dnn::pfCand_chHad_rel_pt - PFchH_index_offset);
1904       get(dnn::pfCand_chHad_deta) =
1905           sp.scale(chH_cand.polarP4().eta() - tau.polarP4().eta(), dnn::pfCand_chHad_deta - PFchH_index_offset);
1906       get(dnn::pfCand_chHad_dphi) =
1907           sp.scale(dPhi(tau.polarP4(), chH_cand.polarP4()), dnn::pfCand_chHad_dphi - PFchH_index_offset);
1908       get(dnn::pfCand_chHad_leadChargedHadrCand) =
1909           sp.scale(&chH_cand == dynamic_cast<const CandidateCastType*>(tau.leadChargedHadrCand().get()),
1910                    dnn::pfCand_chHad_leadChargedHadrCand - PFchH_index_offset);
1911       get(dnn::pfCand_chHad_pvAssociationQuality) = sp.scale<int>(
1912           candFunc::getPvAssocationQuality(chH_cand), dnn::pfCand_chHad_pvAssociationQuality - PFchH_index_offset);
1913       get(dnn::pfCand_chHad_fromPV) =
1914           sp.scale<int>(candFunc::getFromPV(chH_cand), dnn::pfCand_chHad_fromPV - PFchH_index_offset);
1915       const float default_chH_pw_inner = 0.7614090f;
1916       const float default_chH_pw_outer = 0.1974930f;
1917       get(dnn::pfCand_chHad_puppiWeight) = is_inner ? sp.scale(candFunc::getPuppiWeight(chH_cand, default_chH_pw_inner),
1918                                                                dnn::pfCand_chHad_puppiWeight - PFchH_index_offset)
1919                                                     : sp.scale(candFunc::getPuppiWeight(chH_cand, default_chH_pw_outer),
1920                                                                dnn::pfCand_chHad_puppiWeight - PFchH_index_offset);
1921       get(dnn::pfCand_chHad_puppiWeightNoLep) =
1922           is_inner ? sp.scale(candFunc::getPuppiWeightNoLep(chH_cand, default_chH_pw_inner),
1923                               dnn::pfCand_chHad_puppiWeightNoLep - PFchH_index_offset)
1924                    : sp.scale(candFunc::getPuppiWeightNoLep(chH_cand, default_chH_pw_outer),
1925                               dnn::pfCand_chHad_puppiWeightNoLep - PFchH_index_offset);
1926       get(dnn::pfCand_chHad_charge) = sp.scale(chH_cand.charge(), dnn::pfCand_chHad_charge - PFchH_index_offset);
1927       get(dnn::pfCand_chHad_lostInnerHits) =
1928           sp.scale<int>(candFunc::getLostInnerHits(chH_cand, 0), dnn::pfCand_chHad_lostInnerHits - PFchH_index_offset);
1929       get(dnn::pfCand_chHad_numberOfPixelHits) = sp.scale(candFunc::getNumberOfPixelHits(chH_cand, 0),
1930                                                           dnn::pfCand_chHad_numberOfPixelHits - PFchH_index_offset);
1931       get(dnn::pfCand_chHad_vertex_dx) =
1932           sp.scale(chH_cand.vertex().x() - pv.position().x(), dnn::pfCand_chHad_vertex_dx - PFchH_index_offset);
1933       get(dnn::pfCand_chHad_vertex_dy) =
1934           sp.scale(chH_cand.vertex().y() - pv.position().y(), dnn::pfCand_chHad_vertex_dy - PFchH_index_offset);
1935       get(dnn::pfCand_chHad_vertex_dz) =
1936           sp.scale(chH_cand.vertex().z() - pv.position().z(), dnn::pfCand_chHad_vertex_dz - PFchH_index_offset);
1937       get(dnn::pfCand_chHad_vertex_dx_tauFL) =
1938           sp.scale(chH_cand.vertex().x() - pv.position().x() - tau_funcs.getFlightLength(tau, tau_index).x(),
1939                    dnn::pfCand_chHad_vertex_dx_tauFL - PFchH_index_offset);
1940       get(dnn::pfCand_chHad_vertex_dy_tauFL) =
1941           sp.scale(chH_cand.vertex().y() - pv.position().y() - tau_funcs.getFlightLength(tau, tau_index).y(),
1942                    dnn::pfCand_chHad_vertex_dy_tauFL - PFchH_index_offset);
1943       get(dnn::pfCand_chHad_vertex_dz_tauFL) =
1944           sp.scale(chH_cand.vertex().z() - pv.position().z() - tau_funcs.getFlightLength(tau, tau_index).z(),
1945                    dnn::pfCand_chHad_vertex_dz_tauFL - PFchH_index_offset);
1946 
1947       const bool hasTrackDetails = candFunc::getHasTrackDetails(chH_cand);
1948       if (hasTrackDetails) {
1949         get(dnn::pfCand_chHad_hasTrackDetails) =
1950             sp.scale(hasTrackDetails, dnn::pfCand_chHad_hasTrackDetails - PFchH_index_offset);
1951         get(dnn::pfCand_chHad_dxy) =
1952             sp.scale(candFunc::getTauDxy(chH_cand), dnn::pfCand_chHad_dxy - PFchH_index_offset);
1953         get(dnn::pfCand_chHad_dxy_sig) = sp.scale(std::abs(candFunc::getTauDxy(chH_cand)) / chH_cand.dxyError(),
1954                                                   dnn::pfCand_chHad_dxy_sig - PFchH_index_offset);
1955         get(dnn::pfCand_chHad_dz) = sp.scale(candFunc::getTauDz(chH_cand), dnn::pfCand_chHad_dz - PFchH_index_offset);
1956         get(dnn::pfCand_chHad_dz_sig) = sp.scale(std::abs(candFunc::getTauDz(chH_cand)) / chH_cand.dzError(),
1957                                                  dnn::pfCand_chHad_dz_sig - PFchH_index_offset);
1958         get(dnn::pfCand_chHad_track_chi2_ndof) =
1959             candFunc::getPseudoTrack(chH_cand).ndof() > 0
1960                 ? sp.scale(candFunc::getPseudoTrack(chH_cand).chi2() / candFunc::getPseudoTrack(chH_cand).ndof(),
1961                            dnn::pfCand_chHad_track_chi2_ndof - PFchH_index_offset)
1962                 : 0;
1963         get(dnn::pfCand_chHad_track_ndof) =
1964             candFunc::getPseudoTrack(chH_cand).ndof() > 0
1965                 ? sp.scale(candFunc::getPseudoTrack(chH_cand).ndof(), dnn::pfCand_chHad_track_ndof - PFchH_index_offset)
1966                 : 0;
1967       }
1968       float hcal_fraction = candFunc::getHCalFraction(chH_cand, disable_hcalFraction_workaround_);
1969       get(dnn::pfCand_chHad_hcalFraction) =
1970           sp.scale(hcal_fraction, dnn::pfCand_chHad_hcalFraction - PFchH_index_offset);
1971       get(dnn::pfCand_chHad_rawCaloFraction) =
1972           sp.scale(candFunc::getRawCaloFraction(chH_cand), dnn::pfCand_chHad_rawCaloFraction - PFchH_index_offset);
1973     }
1974     if (valid_nH) {
1975       const sc::ScalingParams& sp = scalingParamsMap_->at(std::make_pair(ft_PFnH, is_inner));
1976       size_t index_nH = cell_map.at(CellObjectType::PfCand_neutralHadron);
1977       const auto& nH_cand = dynamic_cast<const CandidateCastType&>(pfCands.at(index_nH));
1978 
1979       get(dnn::pfCand_nHad_valid) = sp.scale(valid_nH, dnn::pfCand_nHad_valid - PFnH_index_offset);
1980       get(dnn::pfCand_nHad_rel_pt) =
1981           sp.scale(nH_cand.polarP4().pt() / tau.polarP4().pt(), dnn::pfCand_nHad_rel_pt - PFnH_index_offset);
1982       get(dnn::pfCand_nHad_deta) =
1983           sp.scale(nH_cand.polarP4().eta() - tau.polarP4().eta(), dnn::pfCand_nHad_deta - PFnH_index_offset);
1984       get(dnn::pfCand_nHad_dphi) =
1985           sp.scale(dPhi(tau.polarP4(), nH_cand.polarP4()), dnn::pfCand_nHad_dphi - PFnH_index_offset);
1986       get(dnn::pfCand_nHad_puppiWeight) = is_inner ? sp.scale(candFunc::getPuppiWeight(nH_cand, 0.9798355f),
1987                                                               dnn::pfCand_nHad_puppiWeight - PFnH_index_offset)
1988                                                    : sp.scale(candFunc::getPuppiWeight(nH_cand, 0.7813260f),
1989                                                               dnn::pfCand_nHad_puppiWeight - PFnH_index_offset);
1990       get(dnn::pfCand_nHad_puppiWeightNoLep) = is_inner
1991                                                    ? sp.scale(candFunc::getPuppiWeightNoLep(nH_cand, 0.9046796f),
1992                                                               dnn::pfCand_nHad_puppiWeightNoLep - PFnH_index_offset)
1993                                                    : sp.scale(candFunc::getPuppiWeightNoLep(nH_cand, 0.6554860f),
1994                                                               dnn::pfCand_nHad_puppiWeightNoLep - PFnH_index_offset);
1995       float hcal_fraction = candFunc::getHCalFraction(nH_cand, disable_hcalFraction_workaround_);
1996       get(dnn::pfCand_nHad_hcalFraction) = sp.scale(hcal_fraction, dnn::pfCand_nHad_hcalFraction - PFnH_index_offset);
1997     }
1998   }
1999 
2000   static void calculateElectronClusterVars(const pat::Electron* ele, float& elecEe, float& elecEgamma) {
2001     if (ele) {
2002       elecEe = elecEgamma = 0;
2003       auto superCluster = ele->superCluster();
2004       if (superCluster.isNonnull() && superCluster.isAvailable() && superCluster->clusters().isNonnull() &&
2005           superCluster->clusters().isAvailable()) {
2006         for (auto iter = superCluster->clustersBegin(); iter != superCluster->clustersEnd(); ++iter) {
2007           const double energy = (*iter)->energy();
2008           if (iter == superCluster->clustersBegin())
2009             elecEe += energy;
2010           else
2011             elecEgamma += energy;
2012         }
2013       }
2014     } else {
2015       elecEe = elecEgamma = default_value;
2016     }
2017   }
2018 
2019   template <typename CandidateCollection, typename TauCastType>
2020   static void processSignalPFComponents(const TauCastType& tau,
2021                                         const CandidateCollection& candidates,
2022                                         LorentzVectorXYZ& p4_inner,
2023                                         LorentzVectorXYZ& p4_outer,
2024                                         float& pt_inner,
2025                                         float& dEta_inner,
2026                                         float& dPhi_inner,
2027                                         float& m_inner,
2028                                         float& pt_outer,
2029                                         float& dEta_outer,
2030                                         float& dPhi_outer,
2031                                         float& m_outer,
2032                                         float& n_inner,
2033                                         float& n_outer) {
2034     p4_inner = LorentzVectorXYZ(0, 0, 0, 0);
2035     p4_outer = LorentzVectorXYZ(0, 0, 0, 0);
2036     n_inner = 0;
2037     n_outer = 0;
2038 
2039     const double innerSigCone_radius = getInnerSignalConeRadius(tau.pt());
2040     for (const auto& cand : candidates) {
2041       const double dR = reco::deltaR(cand->p4(), tau.leadChargedHadrCand()->p4());
2042       const bool isInside_innerSigCone = dR < innerSigCone_radius;
2043       if (isInside_innerSigCone) {
2044         p4_inner += cand->p4();
2045         ++n_inner;
2046       } else {
2047         p4_outer += cand->p4();
2048         ++n_outer;
2049       }
2050     }
2051 
2052     pt_inner = n_inner != 0 ? p4_inner.Pt() : default_value;
2053     dEta_inner = n_inner != 0 ? dEta(p4_inner, tau.p4()) : default_value;
2054     dPhi_inner = n_inner != 0 ? dPhi(p4_inner, tau.p4()) : default_value;
2055     m_inner = n_inner != 0 ? p4_inner.mass() : default_value;
2056 
2057     pt_outer = n_outer != 0 ? p4_outer.Pt() : default_value;
2058     dEta_outer = n_outer != 0 ? dEta(p4_outer, tau.p4()) : default_value;
2059     dPhi_outer = n_outer != 0 ? dPhi(p4_outer, tau.p4()) : default_value;
2060     m_outer = n_outer != 0 ? p4_outer.mass() : default_value;
2061   }
2062 
2063   template <typename CandidateCollection, typename TauCastType>
2064   static void processIsolationPFComponents(const TauCastType& tau,
2065                                            const CandidateCollection& candidates,
2066                                            LorentzVectorXYZ& p4,
2067                                            float& pt,
2068                                            float& d_eta,
2069                                            float& d_phi,
2070                                            float& m,
2071                                            float& n) {
2072     p4 = LorentzVectorXYZ(0, 0, 0, 0);
2073     n = 0;
2074 
2075     for (const auto& cand : candidates) {
2076       p4 += cand->p4();
2077       ++n;
2078     }
2079 
2080     pt = n != 0 ? p4.Pt() : default_value;
2081     d_eta = n != 0 ? dEta(p4, tau.p4()) : default_value;
2082     d_phi = n != 0 ? dPhi(p4, tau.p4()) : default_value;
2083     m = n != 0 ? p4.mass() : default_value;
2084   }
2085 
2086   static double getInnerSignalConeRadius(double pt) {
2087     static constexpr double min_pt = 30., min_radius = 0.05, cone_opening_coef = 3.;
2088     // This is equivalent of the original formula (std::max(std::min(0.1, 3.0/pt), 0.05)
2089     return std::max(cone_opening_coef / std::max(pt, min_pt), min_radius);
2090   }
2091 
2092   // Copied from https://github.com/cms-sw/cmssw/blob/CMSSW_9_4_X/RecoTauTag/RecoTau/plugins/PATTauDiscriminationByMVAIsolationRun2.cc#L218
2093   template <typename TauCastType>
2094   static bool calculateGottfriedJacksonAngleDifference(const TauCastType& tau,
2095                                                        const size_t tau_index,
2096                                                        double& gj_diff,
2097                                                        TauFunc tau_funcs) {
2098     if (tau_funcs.getHasSecondaryVertex(tau, tau_index)) {
2099       static constexpr double mTau = 1.77682;
2100       const double mAOne = tau.p4().M();
2101       const double pAOneMag = tau.p();
2102       const double argumentThetaGJmax = (std::pow(mTau, 2) - std::pow(mAOne, 2)) / (2 * mTau * pAOneMag);
2103       const double argumentThetaGJmeasured = tau.p4().Vect().Dot(tau_funcs.getFlightLength(tau, tau_index)) /
2104                                              (pAOneMag * tau_funcs.getFlightLength(tau, tau_index).R());
2105       if (std::abs(argumentThetaGJmax) <= 1. && std::abs(argumentThetaGJmeasured) <= 1.) {
2106         double thetaGJmax = std::asin(argumentThetaGJmax);
2107         double thetaGJmeasured = std::acos(argumentThetaGJmeasured);
2108         gj_diff = thetaGJmeasured - thetaGJmax;
2109         return true;
2110       }
2111     }
2112     return false;
2113   }
2114 
2115   template <typename TauCastType>
2116   static float calculateGottfriedJacksonAngleDifference(const TauCastType& tau,
2117                                                         const size_t tau_index,
2118                                                         TauFunc tau_funcs) {
2119     double gj_diff;
2120     if (calculateGottfriedJacksonAngleDifference(tau, tau_index, gj_diff, tau_funcs))
2121       return static_cast<float>(gj_diff);
2122     return default_value;
2123   }
2124 
2125   static bool isInEcalCrack(double eta) {
2126     const double abs_eta = std::abs(eta);
2127     return abs_eta > 1.46 && abs_eta < 1.558;
2128   }
2129 
2130   template <typename TauCastType>
2131   static const pat::Electron* findMatchedElectron(const TauCastType& tau,
2132                                                   const std::vector<pat::Electron>* electrons,
2133                                                   double deltaR) {
2134     const double dR2 = deltaR * deltaR;
2135     const pat::Electron* matched_ele = nullptr;
2136     for (const auto& ele : *electrons) {
2137       if (reco::deltaR2(tau.p4(), ele.p4()) < dR2 && (!matched_ele || matched_ele->pt() < ele.pt())) {
2138         matched_ele = &ele;
2139       }
2140     }
2141     return matched_ele;
2142   }
2143 
protected:
  // Tokens for the input taus, candidates and vertex collection.
  edm::EDGetTokenT<TauCollection> tausToken_;
  edm::EDGetTokenT<CandidateCollection> pfcandToken_;
  edm::EDGetTokenT<reco::VertexCollection> vtxToken_;
  // Working-point lists keyed by name (presumably the output/discriminator name — TODO confirm).
  std::map<std::string, WPList> workingPoints_;
  // NOTE(review): online/offline mode flag; its effect is not visible in this section — confirm.
  const bool is_online_;
  IDOutputCollection idoutputs_;

  // String name associated with each BasicDiscriminator (presumably used to look up
  // entries in the input discriminator containers — confirm).
  const std::map<BasicDiscriminator, std::string> stringFromDiscriminator_{
      {BasicDiscriminator::ChargedIsoPtSum, "ChargedIsoPtSum"},
      {BasicDiscriminator::NeutralIsoPtSum, "NeutralIsoPtSum"},
      {BasicDiscriminator::NeutralIsoPtSumWeight, "NeutralIsoPtSumWeight"},
      {BasicDiscriminator::FootprintCorrection, "TauFootprintCorrection"},
      {BasicDiscriminator::PhotonPtSumOutsideSignalCone, "PhotonPtSumOutsideSignalCone"},
      {BasicDiscriminator::PUcorrPtSum, "PUcorrPtSum"}};
  // Basic discriminators required from the non-dR03 input container.
  const std::vector<BasicDiscriminator> requiredBasicDiscriminators_{BasicDiscriminator::ChargedIsoPtSum,
                                                                     BasicDiscriminator::NeutralIsoPtSum,
                                                                     BasicDiscriminator::NeutralIsoPtSumWeight,
                                                                     BasicDiscriminator::PhotonPtSumOutsideSignalCone,
                                                                     BasicDiscriminator::PUcorrPtSum};
  // Basic discriminators required from the dR03 input container; compared with the list above,
  // FootprintCorrection replaces PUcorrPtSum.
  const std::vector<BasicDiscriminator> requiredBasicDiscriminatorsdR03_{
      BasicDiscriminator::ChargedIsoPtSum,
      BasicDiscriminator::NeutralIsoPtSum,
      BasicDiscriminator::NeutralIsoPtSumWeight,
      BasicDiscriminator::PhotonPtSumOutsideSignalCone,
      BasicDiscriminator::FootprintCorrection};

  // Tokens for additional inputs: electrons, muons, rho, basic tau discriminators
  // (default and dR03 variants) and per-tau transverse impact parameters.
  edm::EDGetTokenT<std::vector<pat::Electron>> electrons_token_;
  edm::EDGetTokenT<std::vector<pat::Muon>> muons_token_;
  edm::EDGetTokenT<double> rho_token_;
  edm::EDGetTokenT<reco::TauDiscriminatorContainer> basicTauDiscriminators_inputToken_;
  edm::EDGetTokenT<reco::TauDiscriminatorContainer> basicTauDiscriminatorsdR03_inputToken_;
  edm::EDGetTokenT<edm::AssociationVector<reco::PFTauRefProd, std::vector<reco::PFTauTransverseImpactParameterRef>>>
      pfTauTransverseImpactParameters_token_;
  // Presumably the names of the model's input/output layers — TODO confirm.
  std::string input_layer_, output_layer_;
  // Model version identifiers (year / version / sub-version); exact usage not visible here.
  const unsigned year_;
  const unsigned version_;
  const unsigned sub_version_;
  const int debug_level;
  // NOTE(review): switches off the dxy PCA handling; consumer not visible in this section — confirm.
  const bool disable_dxy_pca_;
  // Forwarded to candFunc::getHCalFraction when filling the hadron-block inputs.
  const bool disable_hcalFraction_workaround_;
  // NOTE(review): switches off a cell-index workaround; consumer not visible in this section — confirm.
  const bool disable_CellIndex_workaround_;
  // Scaling parameters keyed by (feature type, is_inner). Global features are looked up with
  // is_inner = false; mean_.size() of an entry gives the number of features of that type.
  // Ownership of the pointed-to map is not visible here.
  const std::map<std::pair<deep_tau::Scaling::FeatureT, bool>, deep_tau::Scaling::ScalingParams>* scalingParamsMap_;
  // State for optionally saving the DNN inputs (presumably as JSON to json_file_ — confirm).
  const bool save_inputs_;
  std::ofstream* json_file_;
  bool is_first_block_;
  int file_counter_;
  std::vector<int> tauInputs_indices_;

  // Whether basicDiscrIndexMap_ / basicDiscrdR03IndexMap_ have already been filled.
  bool discrIndicesMapped_ = false;
  // Index of each BasicDiscriminator inside the corresponding input container (default and dR03).
  std::map<BasicDiscriminator, size_t> basicDiscrIndexMap_;
  std::map<BasicDiscriminator, size_t> basicDiscrdR03IndexMap_;
2197 };
2198 
2199 #endif