Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:27:47

0001 
0002 /*
0003  * \class PATTauDiscriminationByMVAIsolationRun2
0004  * 
0005  * MVA based discriminator against jet -> tau fakes
0006  * 
0007  * Adapted from RecoTauTag/RecoTau/plugins/PFRecoTauDiscriminationByMVAIsolationRun2.cc
0008  * to enable computation of MVA isolation on MiniAOD
0009  * 
0010  * \author Alexander Nehrkorn, RWTH Aachen
0011  */
0012 
0013 // todo 1: remove leadingTrackChi2 as input variable from:
0014 //           - here
0015 //           - TauPFEssential
0016 //           - PFRecoTauDiscriminationByMVAIsolationRun2
0017 //           - Training of BDT
0018 
0019 #include "RecoTauTag/RecoTau/interface/TauDiscriminationProducerBase.h"
0020 
0021 #include "FWCore/Framework/interface/Event.h"
0022 #include "FWCore/Framework/interface/EventSetup.h"
0023 #include "FWCore/Utilities/interface/InputTag.h"
0024 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0025 #include "FWCore/ParameterSet/interface/FileInPath.h"
0026 
0027 #include "FWCore/Utilities/interface/Exception.h"
0028 
0029 #include <FWCore/ParameterSet/interface/ConfigurationDescriptions.h>
0030 #include <FWCore/ParameterSet/interface/ParameterSetDescription.h>
0031 
0032 #include "DataFormats/Candidate/interface/Candidate.h"
0033 #include "DataFormats/PatCandidates/interface/Tau.h"
0034 #include "DataFormats/PatCandidates/interface/PATTauDiscriminator.h"
0035 #include "DataFormats/Math/interface/deltaR.h"
0036 #include "RecoTauTag/RecoTau/interface/PFRecoTauClusterVariables.h"
0037 
0038 #include "CondFormats/GBRForest/interface/GBRForest.h"
0039 #include "CondFormats/DataRecord/interface/GBRWrapperRcd.h"
0040 
0041 #include <TFile.h>
0042 
0043 #include <iostream>
0044 
0045 using namespace pat;
0046 
0047 namespace {
0048   const GBRForest* loadMVAfromFile(const edm::FileInPath& inputFileName,
0049                                    const std::string& mvaName,
0050                                    std::vector<TFile*>& inputFilesToDelete) {
0051     if (inputFileName.location() == edm::FileInPath::Unknown)
0052       throw cms::Exception("PATTauDiscriminationByIsolationMVARun2::loadMVA")
0053           << " Failed to find File = " << inputFileName << " !!\n";
0054     TFile* inputFile = new TFile(inputFileName.fullPath().data());
0055 
0056     //const GBRForest* mva = dynamic_cast<GBRForest*>(inputFile->Get(mvaName.data())); // CV: dynamic_cast<GBRForest*> fails for some reason ?!
0057     const GBRForest* mva = (GBRForest*)inputFile->Get(mvaName.data());
0058     if (!mva)
0059       throw cms::Exception("PATTauDiscriminationByIsolationMVARun2::loadMVA")
0060           << " Failed to load MVA = " << mvaName.data() << " from file = " << inputFileName.fullPath().data()
0061           << " !!\n";
0062 
0063     inputFilesToDelete.push_back(inputFile);
0064 
0065     return mva;
0066   }
0067 }  // namespace
0068 
0069 namespace reco {
0070   namespace tau {
0071 
0072     class PATTauDiscriminationByMVAIsolationRun2 : public PATTauDiscriminationContainerProducerBase {
0073     public:
0074       explicit PATTauDiscriminationByMVAIsolationRun2(const edm::ParameterSet& cfg)
0075           : PATTauDiscriminationContainerProducerBase(cfg),
0076             moduleLabel_(cfg.getParameter<std::string>("@module_label")),
0077             mvaReader_(nullptr),
0078             mvaInput_(nullptr) {
0079         mvaName_ = cfg.getParameter<std::string>("mvaName");
0080         loadMVAfromDB_ = cfg.getParameter<bool>("loadMVAfromDB");
0081         if (!loadMVAfromDB_) {
0082           inputFileName_ = cfg.getParameter<edm::FileInPath>("inputFileName");
0083         } else {
0084           mvaToken_ = esConsumes(edm::ESInputTag{"", mvaName_});
0085         }
0086         std::string mvaOpt_string = cfg.getParameter<std::string>("mvaOpt");
0087         if (mvaOpt_string == "oldDMwoLT")
0088           mvaOpt_ = kOldDMwoLT;
0089         else if (mvaOpt_string == "oldDMwLT")
0090           mvaOpt_ = kOldDMwLT;
0091         else if (mvaOpt_string == "newDMwoLT")
0092           mvaOpt_ = kNewDMwoLT;
0093         else if (mvaOpt_string == "newDMwLT")
0094           mvaOpt_ = kNewDMwLT;
0095         else if (mvaOpt_string == "DBoldDMwLT")
0096           mvaOpt_ = kDBoldDMwLT;
0097         else if (mvaOpt_string == "DBnewDMwLT")
0098           mvaOpt_ = kDBnewDMwLT;
0099         else if (mvaOpt_string == "PWoldDMwLT")
0100           mvaOpt_ = kPWoldDMwLT;
0101         else if (mvaOpt_string == "PWnewDMwLT")
0102           mvaOpt_ = kPWnewDMwLT;
0103         else if (mvaOpt_string == "DBoldDMwLTwGJ")
0104           mvaOpt_ = kDBoldDMwLTwGJ;
0105         else if (mvaOpt_string == "DBnewDMwLTwGJ")
0106           mvaOpt_ = kDBnewDMwLTwGJ;
0107         else if (mvaOpt_string == "DBnewDMwLTwGJPhase2")
0108           mvaOpt_ = kDBnewDMwLTwGJPhase2;
0109         else
0110           throw cms::Exception("PATTauDiscriminationByMVAIsolationRun2")
0111               << " Invalid Configuration Parameter 'mvaOpt' = " << mvaOpt_string << " !!\n";
0112 
0113         if (mvaOpt_ == kOldDMwoLT || mvaOpt_ == kNewDMwoLT)
0114           mvaInput_ = new float[6];
0115         else if (mvaOpt_ == kOldDMwLT || mvaOpt_ == kNewDMwLT)
0116           mvaInput_ = new float[12];
0117         else if (mvaOpt_ == kDBoldDMwLT || mvaOpt_ == kDBnewDMwLT || mvaOpt_ == kPWoldDMwLT || mvaOpt_ == kPWnewDMwLT ||
0118                  mvaOpt_ == kDBoldDMwLTwGJ || mvaOpt_ == kDBnewDMwLTwGJ)
0119           mvaInput_ = new float[23];
0120         else if (mvaOpt_ == kDBnewDMwLTwGJPhase2)
0121           mvaInput_ = new float[30];
0122         else
0123           assert(0);
0124 
0125         chargedIsoPtSums_ = cfg.getParameter<std::string>("srcChargedIsoPtSum");
0126         neutralIsoPtSums_ = cfg.getParameter<std::string>("srcNeutralIsoPtSum");
0127         puCorrPtSums_ = cfg.getParameter<std::string>("srcPUcorrPtSum");
0128         photonPtSumOutsideSignalCone_ = cfg.getParameter<std::string>("srcPhotonPtSumOutsideSignalCone");
0129         footprintCorrection_ = cfg.getParameter<std::string>("srcFootprintCorrection");
0130 
0131         verbosity_ = cfg.getParameter<int>("verbosity");
0132       }
0133 
0134       void beginEvent(const edm::Event&, const edm::EventSetup&) override;
0135 
0136       reco::SingleTauDiscriminatorContainer discriminate(const TauRef&) const override;
0137 
0138       ~PATTauDiscriminationByMVAIsolationRun2() override {
0139         if (!loadMVAfromDB_)
0140           delete mvaReader_;
0141         delete[] mvaInput_;
0142         for (std::vector<TFile*>::iterator it = inputFilesToDelete_.begin(); it != inputFilesToDelete_.end(); ++it) {
0143           delete (*it);
0144         }
0145       }
0146 
0147       static void fillDescriptions(edm::ConfigurationDescriptions& descriptions);
0148 
0149     private:
0150       std::string moduleLabel_;
0151 
0152       std::string mvaName_;
0153       edm::ESGetToken<GBRForest, GBRWrapperRcd> mvaToken_;
0154       bool loadMVAfromDB_;
0155       edm::FileInPath inputFileName_;
0156       const GBRForest* mvaReader_;
0157       int mvaOpt_;
0158       float* mvaInput_;
0159 
0160       std::string chargedIsoPtSums_;
0161       std::string neutralIsoPtSums_;
0162       std::string puCorrPtSums_;
0163       std::string photonPtSumOutsideSignalCone_;
0164       std::string footprintCorrection_;
0165 
0166       edm::Handle<TauCollection> taus_;
0167       std::vector<TFile*> inputFilesToDelete_;
0168 
0169       int verbosity_;
0170     };
0171 
0172     void PATTauDiscriminationByMVAIsolationRun2::beginEvent(const edm::Event& evt, const edm::EventSetup& es) {
0173       if (!mvaReader_) {
0174         if (loadMVAfromDB_) {
0175           mvaReader_ = &es.getData(mvaToken_);
0176         } else {
0177           mvaReader_ = loadMVAfromFile(inputFileName_, mvaName_, inputFilesToDelete_);
0178         }
0179       }
0180 
0181       evt.getByToken(Tau_token, taus_);
0182     }
0183 
0184     reco::SingleTauDiscriminatorContainer PATTauDiscriminationByMVAIsolationRun2::discriminate(const TauRef& tau) const {
0185       // CV: define dummy category index in order to use RecoTauDiscriminantCutMultiplexer module to appy WP cuts
0186       reco::SingleTauDiscriminatorContainer result;
0187       result.rawValues = {-1.};
0188 
0189       // CV: computation of MVA value requires presence of leading charged hadron
0190       if (tau->leadChargedHadrCand().isNull()) {
0191         result.rawValues.at(0) = 0.;
0192         return result;
0193       }
0194 
0195       if (reco::tau::fillIsoMVARun2Inputs(mvaInput_,
0196                                           *tau,
0197                                           mvaOpt_,
0198                                           chargedIsoPtSums_,
0199                                           neutralIsoPtSums_,
0200                                           puCorrPtSums_,
0201                                           photonPtSumOutsideSignalCone_,
0202                                           footprintCorrection_)) {
0203         double mvaValue = mvaReader_->GetClassifier(mvaInput_);
0204         if (verbosity_) {
0205           edm::LogPrint("PATTauDiscByMVAIsolRun2") << "<PATTauDiscriminationByMVAIsolationRun2::discriminate>:";
0206           edm::LogPrint("PATTauDiscByMVAIsolRun2") << " tau: Pt = " << tau->pt() << ", eta = " << tau->eta();
0207           edm::LogPrint("PATTauDiscByMVAIsolRun2")
0208               << " isolation: charged = " << tau->tauID(chargedIsoPtSums_)
0209               << ", neutral = " << tau->tauID(neutralIsoPtSums_) << ", PUcorr = " << tau->tauID(puCorrPtSums_);
0210           edm::LogPrint("PATTauDiscByMVAIsolRun2") << " decay mode = " << tau->decayMode();
0211           edm::LogPrint("PATTauDiscByMVAIsolRun2")
0212               << " impact parameter: distance = " << tau->dxy() << ", significance = " << tau->dxy_Sig();
0213           edm::LogPrint("PATTauDiscByMVAIsolRun2") << " has decay vertex = " << tau->hasSecondaryVertex() << ":"
0214                                                    << ", significance = " << tau->flightLengthSig();
0215           edm::LogPrint("PATTauDiscByMVAIsolRun2") << "--> mvaValue = " << mvaValue;
0216         }
0217         result.rawValues.at(0) = mvaValue;
0218       }
0219       return result;
0220     }
0221 
0222     void PATTauDiscriminationByMVAIsolationRun2::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0223       // patTauDiscriminationByMVAIsolationRun2
0224       edm::ParameterSetDescription desc;
0225 
0226       desc.add<std::string>("mvaName");
0227       desc.add<bool>("loadMVAfromDB");
0228       desc.addOptional<edm::FileInPath>("inputFileName");
0229       desc.add<std::string>("mvaOpt");
0230 
0231       desc.add<std::string>("srcChargedIsoPtSum");
0232       desc.add<std::string>("srcNeutralIsoPtSum");
0233       desc.add<std::string>("srcPUcorrPtSum");
0234       desc.add<std::string>("srcPhotonPtSumOutsideSignalCone");
0235       desc.add<std::string>("srcFootprintCorrection");
0236       desc.add<int>("verbosity", 0);
0237 
0238       fillProducerDescriptions(desc);  // inherited from the base
0239 
0240       descriptions.add("patTauDiscriminationByMVAIsolationRun2", desc);
0241     }
0242 
0243     DEFINE_FWK_MODULE(PATTauDiscriminationByMVAIsolationRun2);  // macro defined outside this file; presumably registers the class as a framework plugin - confirm
0244 
0245   }  // namespace tau
0246 }  // namespace reco