File indexing completed on 2023-03-17 11:17:03
0001 #ifndef ImpactParameter_TemplatedJetProbabilityComputer_h
0002 #define ImpactParameter_TemplatedJetProbabilityComputer_h
0003
0004 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0005 #include "DataFormats/TrackReco/interface/Track.h"
0006 #include "DataFormats/BTauReco/interface/TrackProbabilityTagInfo.h"
0007 #include "DataFormats/BTauReco/interface/TrackIPTagInfo.h"
0008 #include "DataFormats/VertexReco/interface/Vertex.h"
0009 #include "Math/GenVector/VectorUtil.h"
0010 #include "RecoBTau/JetTagComputer/interface/JetTagComputer.h"
0011
0012 #include <iostream>
0013
0014 template <class Container, class Base>
0015 class TemplatedJetProbabilityComputer : public JetTagComputer {
0016 public:
0017 using Tokens = void;
0018
0019 typedef reco::IPTagInfo<Container, Base> TagInfo;
0020
0021 TemplatedJetProbabilityComputer(const edm::ParameterSet& parameters) {
0022 m_ipType = parameters.getParameter<int>("impactParameterType");
0023 m_minTrackProb = parameters.getParameter<double>("minimumProbability");
0024 m_deltaR = parameters.getParameter<double>("deltaR");
0025 m_trackSign = parameters.getParameter<int>("trackIpSign");
0026 m_cutMaxDecayLen = parameters.getParameter<double>("maximumDecayLength");
0027 m_cutMaxDistToAxis = parameters.getParameter<double>("maximumDistanceToJetAxis");
0028
0029
0030
0031 std::string trackQualityType = parameters.getParameter<std::string>("trackQualityClass");
0032 m_trackQuality = reco::TrackBase::qualityByName(trackQualityType);
0033 m_useAllQualities = false;
0034 if (trackQualityType == "any" || trackQualityType == "Any" || trackQualityType == "ANY")
0035 m_useAllQualities = true;
0036
0037 useVariableJTA_ = parameters.getParameter<bool>("useVariableJTA");
0038 if (useVariableJTA_)
0039 varJTApars = {parameters.getParameter<double>("a_dR"),
0040 parameters.getParameter<double>("b_dR"),
0041 parameters.getParameter<double>("a_pT"),
0042 parameters.getParameter<double>("b_pT"),
0043 parameters.getParameter<double>("min_pT"),
0044 parameters.getParameter<double>("max_pT"),
0045 parameters.getParameter<double>("min_pT_dRcut"),
0046 parameters.getParameter<double>("max_pT_dRcut"),
0047 parameters.getParameter<double>("max_pT_trackPTcut")};
0048
0049 uses("ipTagInfos");
0050 }
0051
0052 float discriminator(const TagInfoHelper& ti) const override {
0053 const TagInfo& tkip = ti.get<TagInfo>();
0054 const Container& tracks(tkip.selectedTracks());
0055 const std::vector<float>& allProbabilities((tkip.probabilities(m_ipType)));
0056 const std::vector<reco::btag::TrackIPData>& impactParameters((tkip.impactParameterData()));
0057
0058 if (tkip.primaryVertex().isNull())
0059 return 0;
0060
0061 GlobalPoint pv(tkip.primaryVertex()->position().x(),
0062 tkip.primaryVertex()->position().y(),
0063 tkip.primaryVertex()->position().z());
0064
0065 std::vector<float> probabilities;
0066 int i = 0;
0067 for (std::vector<float>::const_iterator it = allProbabilities.begin(); it != allProbabilities.end(); ++it, i++) {
0068 if (fabs(impactParameters[i].distanceToJetAxis.value()) < m_cutMaxDistToAxis &&
0069 (impactParameters[i].closestToJetAxis - pv).mag() < m_cutMaxDecayLen &&
0070 (m_useAllQualities == true ||
0071 reco::btag::toTrack(tracks[i])->quality(m_trackQuality))
0072 ) {
0073 float p;
0074 if (m_trackSign == 0) {
0075 if (*it >= 0) {
0076 p = *it / 2.;
0077 } else {
0078 p = 1. + *it / 2.;
0079 }
0080 } else if (m_trackSign > 0) {
0081 if (*it >= 0)
0082 p = *it;
0083 else
0084 continue;
0085 } else {
0086 if (*it <= 0)
0087 p = -*it;
0088 else
0089 continue;
0090 }
0091 if (useVariableJTA_) {
0092 if (tkip.variableJTA(varJTApars)[i])
0093 probabilities.push_back(p);
0094 } else {
0095 if (m_deltaR <= 0 ||
0096 ROOT::Math::VectorUtil::DeltaR((*tkip.jet()).p4().Vect(), (*tracks[i]).momentum()) < m_deltaR)
0097 probabilities.push_back(p);
0098 }
0099 }
0100 }
0101 return jetProbability(probabilities);
0102 }
0103
0104 double jetProbability(const std::vector<float>& v) const {
0105 int ngoodtracks = v.size();
0106 double SumJet = 0.;
0107
0108 for (std::vector<float>::const_iterator q = v.begin(); q != v.end(); q++) {
0109 SumJet += (*q > m_minTrackProb) ? log(*q) : log(m_minTrackProb);
0110 }
0111
0112 double ProbJet;
0113 double Loginvlog = 0;
0114
0115 if (SumJet < 0.) {
0116 if (ngoodtracks >= 2) {
0117 Loginvlog = log(-SumJet);
0118 }
0119 double Prob = 1.;
0120 double lfact = 1.;
0121 for (int l = 1; l != ngoodtracks; l++) {
0122 lfact *= l;
0123 Prob += exp(l * Loginvlog - log(1. * lfact));
0124 }
0125 double LogProb = log(Prob);
0126 ProbJet = std::min(exp(std::max(LogProb + SumJet, -30.)), 1.);
0127 } else {
0128 ProbJet = 1.;
0129 }
0130 if (ProbJet > 1)
0131 std::cout << "ProbJet too high: " << ProbJet << std::endl;
0132
0133
0134
0135 return -log10(ProbJet) / 4.;
0136 }
0137
0138 private:
0139 bool useVariableJTA_;
0140 reco::btag::variableJTAParameters varJTApars;
0141 double m_minTrackProb;
0142 int m_ipType;
0143 double m_deltaR;
0144 int m_trackSign;
0145 double m_cutMaxDecayLen;
0146 double m_cutMaxDistToAxis;
0147 reco::TrackBase::TrackQuality m_trackQuality;
0148 bool m_useAllQualities;
0149 };
0150
0151 #endif