Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2022-05-18 03:27:30

0001 #ifndef RecoTauTag_RecoTau_DeepTauBase_h
0002 #define RecoTauTag_RecoTau_DeepTauBase_h
0003 
0004 /*
0005  * \class DeepTauBase
0006  *
0007  * Definition of the base class for tau identification using Deep NN.
0008  *
0009  * \author Konstantin Androsov, INFN Pisa
0010  * \author Maria Rosaria Di Domenico, University of Siena & INFN Pisa
0011  */
0012 
#include <Math/VectorUtil.h>
#include "FWCore/Framework/interface/stream/EDProducer.h"
#include "FWCore/ParameterSet/interface/ParameterSet.h"
#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
#include "tensorflow/core/util/memmapped_file_system.h"
#include "DataFormats/PatCandidates/interface/Electron.h"
#include "DataFormats/PatCandidates/interface/Muon.h"
#include "DataFormats/PatCandidates/interface/Tau.h"
#include "DataFormats/TauReco/interface/TauDiscriminatorContainer.h"
#include "DataFormats/TauReco/interface/PFTauDiscriminator.h"
#include "DataFormats/PatCandidates/interface/PATTauDiscriminator.h"
#include "CommonTools/Utils/interface/StringObjectFunction.h"
#include "RecoTauTag/RecoTau/interface/PFRecoTauClusterVariables.h"
#include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
#include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
#include "DataFormats/Common/interface/View.h"
#include "DataFormats/Common/interface/RefToBase.h"
#include "DataFormats/Provenance/interface/ProductProvenance.h"
#include "DataFormats/Provenance/interface/ProcessHistoryID.h"
#include "FWCore/Common/interface/Provenance.h"
#include <TF1.h>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
0035 
0036 namespace deep_tau {
0037 
0038   class TauWPThreshold {
0039   public:
0040     explicit TauWPThreshold(const std::string& cut_str);
0041     double operator()(const reco::BaseTau& tau, bool isPFTau) const;
0042 
0043   private:
0044     std::unique_ptr<TF1> fn_;
0045     double value_;
0046   };
0047 
0048   class DeepTauCache {
0049   public:
0050     using GraphPtr = std::shared_ptr<tensorflow::GraphDef>;
0051 
0052     DeepTauCache(const std::map<std::string, std::string>& graph_names, bool mem_mapped);
0053     ~DeepTauCache();
0054 
0055     // A Session allows concurrent calls to Run(), though a Session must
0056     // be created / extended by a single thread.
0057     tensorflow::Session& getSession(const std::string& name = "") const { return *sessions_.at(name); }
0058     const tensorflow::GraphDef& getGraph(const std::string& name = "") const { return *graphs_.at(name); }
0059 
0060   private:
0061     std::map<std::string, GraphPtr> graphs_;
0062     std::map<std::string, tensorflow::Session*> sessions_;
0063     std::map<std::string, std::unique_ptr<tensorflow::MemmappedEnv>> memmappedEnv_;
0064   };
0065 
0066   class DeepTauBase : public edm::stream::EDProducer<edm::GlobalCache<DeepTauCache>> {
0067   public:
0068     using TauDiscriminator = reco::TauDiscriminatorContainer;
0069     using TauCollection = edm::View<reco::BaseTau>;
0070     using CandidateCollection = edm::View<reco::Candidate>;
0071     using TauRef = edm::Ref<TauCollection>;
0072     using TauRefProd = edm::RefProd<TauCollection>;
0073     using ElectronCollection = pat::ElectronCollection;
0074     using MuonCollection = pat::MuonCollection;
0075     using LorentzVectorXYZ = ROOT::Math::LorentzVector<ROOT::Math::PxPyPzE4D<double>>;
0076     using Cutter = TauWPThreshold;
0077     using CutterPtr = std::unique_ptr<Cutter>;
0078     using WPList = std::vector<CutterPtr>;
0079 
0080     struct Output {
0081       std::vector<size_t> num_, den_;
0082 
0083       Output(const std::vector<size_t>& num, const std::vector<size_t>& den) : num_(num), den_(den) {}
0084 
0085       std::unique_ptr<TauDiscriminator> get_value(const edm::Handle<TauCollection>& taus,
0086                                                   const tensorflow::Tensor& pred,
0087                                                   const WPList* working_points,
0088                                                   bool is_online) const;
0089     };
0090 
0091     using OutputCollection = std::map<std::string, Output>;
0092 
0093     DeepTauBase(const edm::ParameterSet& cfg, const OutputCollection& outputs, const DeepTauCache* cache);
0094     ~DeepTauBase() override {}
0095 
0096     void produce(edm::Event& event, const edm::EventSetup& es) override;
0097 
0098     static std::unique_ptr<DeepTauCache> initializeGlobalCache(const edm::ParameterSet& cfg);
0099     static void globalEndJob(const DeepTauCache* cache) {}
0100 
0101     template <typename ConsumeType>
0102     struct TauDiscInfo {
0103       edm::InputTag label;
0104       edm::Handle<ConsumeType> handle;
0105       edm::EDGetTokenT<ConsumeType> disc_token;
0106       double cut;
0107       void fill(const edm::Event& evt) { evt.getByToken(disc_token, handle); }
0108     };
0109 
0110     // select boolean operation on prediscriminants (and = 0x01, or = 0x00)
0111     uint8_t andPrediscriminants_;
0112     std::vector<TauDiscInfo<pat::PATTauDiscriminator>> patPrediscriminants_;
0113     std::vector<TauDiscInfo<reco::PFTauDiscriminator>> recoPrediscriminants_;
0114 
0115     enum BasicDiscriminator {
0116       ChargedIsoPtSum,
0117       NeutralIsoPtSum,
0118       NeutralIsoPtSumWeight,
0119       FootprintCorrection,
0120       PhotonPtSumOutsideSignalCone,
0121       PUcorrPtSum
0122     };
0123 
0124   private:
0125     virtual tensorflow::Tensor getPredictions(edm::Event& event, edm::Handle<TauCollection> taus) = 0;
0126     virtual void createOutputs(edm::Event& event, const tensorflow::Tensor& pred, edm::Handle<TauCollection> taus);
0127 
0128   protected:
0129     edm::EDGetTokenT<TauCollection> tausToken_;
0130     edm::EDGetTokenT<CandidateCollection> pfcandToken_;
0131     edm::EDGetTokenT<reco::VertexCollection> vtxToken_;
0132     std::map<std::string, WPList> workingPoints_;
0133     const bool is_online_;
0134     OutputCollection outputs_;
0135     const DeepTauCache* cache_;
0136 
0137     static const std::map<BasicDiscriminator, std::string> stringFromDiscriminator_;
0138     static const std::vector<BasicDiscriminator> requiredBasicDiscriminators_;
0139     static const std::vector<BasicDiscriminator> requiredBasicDiscriminatorsdR03_;
0140   };
0141 
0142 }  // namespace deep_tau
0143 
0144 #endif