File indexing completed on 2024-04-06 12:24:24
0001 #include <algorithm>
0002 #include <iostream>
0003 #include <sstream>
0004 #include <string>
0005 #include <memory>
0006 #include <vector>
0007 #include <map>
0008
0009 #include "FWCore/Utilities/interface/Exception.h"
0010 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0011 #include "DataFormats/Common/interface/RefToBase.h"
0012 #include "DataFormats/JetReco/interface/Jet.h"
0013 #include "RecoBTau/JetTagComputer/interface/JetTagComputerRecord.h"
0014 #include "RecoBTag/Combined/interface/CombinedMVAV2JetTagComputer.h"
0015
0016 using namespace reco;
0017
0018 CombinedMVAV2JetTagComputer::Tokens::Tokens(const edm::ParameterSet ¶ms, edm::ESConsumesCollector &&cc) {
0019 if (params.getParameter<bool>("useCondDB")) {
0020 gbrForest_ = cc.consumes(edm::ESInputTag{
0021 "", params.existsAs<std::string>("gbrForestLabel") ? params.getParameter<std::string>("gbrForestLabel") : ""});
0022 }
0023 const auto &inputComputerNames = params.getParameter<std::vector<std::string> >("jetTagComputers");
0024 computers_.resize(inputComputerNames.size());
0025 for (size_t i = 0; i < inputComputerNames.size(); ++i) {
0026 computers_[i] = cc.consumes(edm::ESInputTag{"", inputComputerNames[i]});
0027 }
0028 }
0029
0030 CombinedMVAV2JetTagComputer::CombinedMVAV2JetTagComputer(const edm::ParameterSet ¶ms, Tokens tokens)
0031 : mvaName(params.getParameter<std::string>("mvaName")),
0032 variables(params.getParameter<std::vector<std::string> >("variables")),
0033 spectators(params.getParameter<std::vector<std::string> >("spectators")),
0034 weightFile(params.existsAs<edm::FileInPath>("weightFile") ? params.getParameter<edm::FileInPath>("weightFile")
0035 : edm::FileInPath()),
0036 useGBRForest(params.existsAs<bool>("useGBRForest") ? params.getParameter<bool>("useGBRForest") : false),
0037 useAdaBoost(params.existsAs<bool>("useAdaBoost") ? params.getParameter<bool>("useAdaBoost") : false),
0038 tokens(std::move(tokens))
0039
0040 {
0041 uses(0, "ipTagInfos");
0042 uses(1, "svAVRTagInfos");
0043 uses(2, "svIVFTagInfos");
0044 uses(3, "smTagInfos");
0045 uses(4, "seTagInfos");
0046 }
0047
0048 CombinedMVAV2JetTagComputer::~CombinedMVAV2JetTagComputer() {}
0049
0050 void CombinedMVAV2JetTagComputer::initialize(const JetTagComputerRecord &record) {
0051 mvaID = std::make_unique<TMVAEvaluator>();
0052
0053 if (tokens.gbrForest_.isInitialized()) {
0054 mvaID->initializeGBRForest(&record.get(tokens.gbrForest_), variables, spectators, useAdaBoost);
0055 } else {
0056 mvaID->initialize(
0057 "Color:Silent:Error", mvaName, weightFile.fullPath(), variables, spectators, useGBRForest, useAdaBoost);
0058 }
0059 computers.reserve(tokens.computers_.size());
0060 for (const auto &token : tokens.computers_) {
0061 computers.push_back(&record.get(token));
0062 }
0063 }
0064
0065 float CombinedMVAV2JetTagComputer::discriminator(const JetTagComputer::TagInfoHelper &info) const {
0066
0067 float value = -10.;
0068
0069
0070 std::vector<const BaseTagInfo *> jpTagInfos({&info.getBase(0)});
0071
0072
0073 std::vector<const BaseTagInfo *> avrTagInfos({&info.getBase(0), &info.getBase(1)});
0074
0075
0076 std::vector<const BaseTagInfo *> ivfTagInfos({&info.getBase(0), &info.getBase(2)});
0077
0078
0079 std::vector<const BaseTagInfo *> smTagInfos({&info.getBase(3)});
0080
0081
0082 std::vector<const BaseTagInfo *> seTagInfos({&info.getBase(4)});
0083
0084 std::map<std::string, float> inputs;
0085 inputs["Jet_JP"] = (*(computers[0]))(TagInfoHelper(jpTagInfos));
0086 inputs["Jet_JBP"] = (*(computers[1]))(TagInfoHelper(jpTagInfos));
0087 inputs["Jet_CSV"] = (*(computers[2]))(TagInfoHelper(avrTagInfos));
0088 inputs["Jet_CSVIVF"] = (*(computers[2]))(TagInfoHelper(ivfTagInfos));
0089 inputs["Jet_SoftMu"] = (*(computers[3]))(TagInfoHelper(smTagInfos));
0090 inputs["Jet_SoftEl"] = (*(computers[4]))(TagInfoHelper(seTagInfos));
0091
0092 if (inputs["Jet_JP"] <= 0) {
0093 inputs["Jet_JP"] = 0;
0094 }
0095 if (inputs["Jet_JBP"] <= 0) {
0096 inputs["Jet_JBP"] = 0;
0097 }
0098 if (inputs["Jet_CSV"] <= 0) {
0099 inputs["Jet_CSV"] = 0;
0100 }
0101 if (inputs["Jet_CSVIVF"] <= 0) {
0102 inputs["Jet_CSVIVF"] = 0;
0103 }
0104 if (inputs["Jet_SoftMu"] <= 0) {
0105 inputs["Jet_SoftMu"] = 0;
0106 }
0107 if (inputs["Jet_SoftEl"] <= 0) {
0108 inputs["Jet_SoftEl"] = 0;
0109 }
0110
0111 if (inputs["Jet_CSV"] >= 1) {
0112 inputs["Jet_CSV"] = 1;
0113 }
0114 if (inputs["Jet_CSVIVF"] >= 1) {
0115 inputs["Jet_CSVIVF"] = 1;
0116 }
0117 if (inputs["Jet_SoftMu"] >= 1) {
0118 inputs["Jet_SoftMu"] = 1;
0119 }
0120 if (inputs["Jet_SoftEl"] >= 1) {
0121 inputs["Jet_SoftEl"] = 1;
0122 }
0123
0124
0125 value = mvaID->evaluate(inputs);
0126
0127
0128 return value;
0129 }