Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2025-01-31 02:19:51

0001 #include <algorithm>
0002 #include <iostream>
0003 #include <sstream>
0004 #include <string>
0005 #include <memory>
0006 #include <vector>
0007 #include <map>
0008 
0009 #include "FWCore/Utilities/interface/Exception.h"
0010 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0011 #include "DataFormats/Common/interface/RefToBase.h"
0012 #include "DataFormats/JetReco/interface/Jet.h"
0013 #include "RecoBTau/JetTagComputer/interface/JetTagComputerRecord.h"
0014 #include "RecoBTag/Combined/interface/CombinedMVAV2JetTagComputer.h"
0015 
0016 using namespace reco;
0017 
0018 CombinedMVAV2JetTagComputer::Tokens::Tokens(const edm::ParameterSet &params, edm::ESConsumesCollector &&cc) {
0019   if (params.getParameter<bool>("useCondDB")) {
0020     gbrForest_ = cc.consumes(edm::ESInputTag{"", params.getParameter<std::string>("gbrForestLabel")});
0021   }
0022   const auto &inputComputerNames = params.getParameter<std::vector<std::string> >("jetTagComputers");
0023   computers_.resize(inputComputerNames.size());
0024   for (size_t i = 0; i < inputComputerNames.size(); ++i) {
0025     computers_[i] = cc.consumes(edm::ESInputTag{"", inputComputerNames[i]});
0026   }
0027 }
0028 
0029 CombinedMVAV2JetTagComputer::CombinedMVAV2JetTagComputer(const edm::ParameterSet &params, Tokens tokens)
0030     : mvaName(params.getParameter<std::string>("mvaName")),
0031       variables(params.getParameter<std::vector<std::string> >("variables")),
0032       spectators(params.getParameter<std::vector<std::string> >("spectators")),
0033       weightFile(params.getParameter<edm::FileInPath>("weightFile")),
0034       useGBRForest(params.getParameter<bool>("useGBRForest")),
0035       useAdaBoost(params.getParameter<bool>("useAdaBoost")),
0036       tokens(std::move(tokens))
0037 
0038 {
0039   uses(0, "ipTagInfos");
0040   uses(1, "svAVRTagInfos");
0041   uses(2, "svIVFTagInfos");
0042   uses(3, "smTagInfos");
0043   uses(4, "seTagInfos");
0044 }
0045 
0046 CombinedMVAV2JetTagComputer::~CombinedMVAV2JetTagComputer() {}
0047 
0048 void CombinedMVAV2JetTagComputer::initialize(const JetTagComputerRecord &record) {
0049   mvaID = std::make_unique<TMVAEvaluator>();
0050 
0051   if (tokens.gbrForest_.isInitialized()) {
0052     mvaID->initializeGBRForest(&record.get(tokens.gbrForest_), variables, spectators, useAdaBoost);
0053   } else {
0054     mvaID->initialize(
0055         "Color:Silent:Error", mvaName, weightFile.fullPath(), variables, spectators, useGBRForest, useAdaBoost);
0056   }
0057   computers.reserve(tokens.computers_.size());
0058   for (const auto &token : tokens.computers_) {
0059     computers.push_back(&record.get(token));
0060   }
0061 }
0062 
0063 float CombinedMVAV2JetTagComputer::discriminator(const JetTagComputer::TagInfoHelper &info) const {
0064   // default discriminator value
0065   float value = -10.;
0066 
0067   // TagInfos for JP taggers
0068   std::vector<const BaseTagInfo *> jpTagInfos({&info.getBase(0)});
0069 
0070   // TagInfos for the CSVv2AVR tagger
0071   std::vector<const BaseTagInfo *> avrTagInfos({&info.getBase(0), &info.getBase(1)});
0072 
0073   // TagInfos for the CSVv2IVF tagger
0074   std::vector<const BaseTagInfo *> ivfTagInfos({&info.getBase(0), &info.getBase(2)});
0075 
0076   // TagInfos for the SoftMuon tagger
0077   std::vector<const BaseTagInfo *> smTagInfos({&info.getBase(3)});
0078 
0079   // TagInfos for the SoftElectron tagger
0080   std::vector<const BaseTagInfo *> seTagInfos({&info.getBase(4)});
0081 
0082   std::map<std::string, float> inputs;
0083   inputs["Jet_JP"] = (*(computers[0]))(TagInfoHelper(jpTagInfos));
0084   inputs["Jet_JBP"] = (*(computers[1]))(TagInfoHelper(jpTagInfos));
0085   inputs["Jet_CSV"] = (*(computers[2]))(TagInfoHelper(avrTagInfos));
0086   inputs["Jet_CSVIVF"] = (*(computers[2]))(TagInfoHelper(ivfTagInfos));
0087   inputs["Jet_SoftMu"] = (*(computers[3]))(TagInfoHelper(smTagInfos));
0088   inputs["Jet_SoftEl"] = (*(computers[4]))(TagInfoHelper(seTagInfos));
0089 
0090   if (inputs["Jet_JP"] <= 0) {
0091     inputs["Jet_JP"] = 0;
0092   }
0093   if (inputs["Jet_JBP"] <= 0) {
0094     inputs["Jet_JBP"] = 0;
0095   }
0096   if (inputs["Jet_CSV"] <= 0) {
0097     inputs["Jet_CSV"] = 0;
0098   }
0099   if (inputs["Jet_CSVIVF"] <= 0) {
0100     inputs["Jet_CSVIVF"] = 0;
0101   }
0102   if (inputs["Jet_SoftMu"] <= 0) {
0103     inputs["Jet_SoftMu"] = 0;
0104   }
0105   if (inputs["Jet_SoftEl"] <= 0) {
0106     inputs["Jet_SoftEl"] = 0;
0107   }
0108 
0109   if (inputs["Jet_CSV"] >= 1) {
0110     inputs["Jet_CSV"] = 1;
0111   }
0112   if (inputs["Jet_CSVIVF"] >= 1) {
0113     inputs["Jet_CSVIVF"] = 1;
0114   }
0115   if (inputs["Jet_SoftMu"] >= 1) {
0116     inputs["Jet_SoftMu"] = 1;
0117   }
0118   if (inputs["Jet_SoftEl"] >= 1) {
0119     inputs["Jet_SoftEl"] = 1;
0120   }
0121 
0122   // evaluate the MVA
0123   value = mvaID->evaluate(inputs);
0124 
0125   // return the final discriminator value
0126   return value;
0127 }
0128 
0129 void CombinedMVAV2JetTagComputer::fillPSetDescription(edm::ParameterSetDescription &desc) {
0130   desc.add<bool>("useCondDB", false);
0131   desc.add<std::string>("gbrForestLabel", "");
0132   desc.add<std::vector<std::string> >("jetTagComputers", {});
0133   desc.add<std::string>("mvaName", "");
0134   desc.add<std::vector<std::string> >("variables", {});
0135   desc.add<std::vector<std::string> >("spectators", {});
0136   desc.add<edm::FileInPath>("weightFile", edm::FileInPath());
0137   desc.add<bool>("useGBRForest", false);
0138   desc.add<bool>("useAdaBoost", false);
0139 }