Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:24:24

0001 #include <algorithm>
0002 #include <iostream>
0003 #include <sstream>
0004 #include <string>
0005 #include <memory>
0006 #include <vector>
0007 #include <map>
0008 
0009 #include "FWCore/Utilities/interface/Exception.h"
0010 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0011 #include "DataFormats/Common/interface/RefToBase.h"
0012 #include "DataFormats/JetReco/interface/Jet.h"
0013 #include "RecoBTau/JetTagComputer/interface/JetTagComputerRecord.h"
0014 #include "RecoBTag/Combined/interface/CombinedMVAV2JetTagComputer.h"
0015 
0016 using namespace reco;
0017 
0018 CombinedMVAV2JetTagComputer::Tokens::Tokens(const edm::ParameterSet &params, edm::ESConsumesCollector &&cc) {
0019   if (params.getParameter<bool>("useCondDB")) {
0020     gbrForest_ = cc.consumes(edm::ESInputTag{
0021         "", params.existsAs<std::string>("gbrForestLabel") ? params.getParameter<std::string>("gbrForestLabel") : ""});
0022   }
0023   const auto &inputComputerNames = params.getParameter<std::vector<std::string> >("jetTagComputers");
0024   computers_.resize(inputComputerNames.size());
0025   for (size_t i = 0; i < inputComputerNames.size(); ++i) {
0026     computers_[i] = cc.consumes(edm::ESInputTag{"", inputComputerNames[i]});
0027   }
0028 }
0029 
0030 CombinedMVAV2JetTagComputer::CombinedMVAV2JetTagComputer(const edm::ParameterSet &params, Tokens tokens)
0031     : mvaName(params.getParameter<std::string>("mvaName")),
0032       variables(params.getParameter<std::vector<std::string> >("variables")),
0033       spectators(params.getParameter<std::vector<std::string> >("spectators")),
0034       weightFile(params.existsAs<edm::FileInPath>("weightFile") ? params.getParameter<edm::FileInPath>("weightFile")
0035                                                                 : edm::FileInPath()),
0036       useGBRForest(params.existsAs<bool>("useGBRForest") ? params.getParameter<bool>("useGBRForest") : false),
0037       useAdaBoost(params.existsAs<bool>("useAdaBoost") ? params.getParameter<bool>("useAdaBoost") : false),
0038       tokens(std::move(tokens))
0039 
0040 {
0041   uses(0, "ipTagInfos");
0042   uses(1, "svAVRTagInfos");
0043   uses(2, "svIVFTagInfos");
0044   uses(3, "smTagInfos");
0045   uses(4, "seTagInfos");
0046 }
0047 
0048 CombinedMVAV2JetTagComputer::~CombinedMVAV2JetTagComputer() {}
0049 
0050 void CombinedMVAV2JetTagComputer::initialize(const JetTagComputerRecord &record) {
0051   mvaID = std::make_unique<TMVAEvaluator>();
0052 
0053   if (tokens.gbrForest_.isInitialized()) {
0054     mvaID->initializeGBRForest(&record.get(tokens.gbrForest_), variables, spectators, useAdaBoost);
0055   } else {
0056     mvaID->initialize(
0057         "Color:Silent:Error", mvaName, weightFile.fullPath(), variables, spectators, useGBRForest, useAdaBoost);
0058   }
0059   computers.reserve(tokens.computers_.size());
0060   for (const auto &token : tokens.computers_) {
0061     computers.push_back(&record.get(token));
0062   }
0063 }
0064 
0065 float CombinedMVAV2JetTagComputer::discriminator(const JetTagComputer::TagInfoHelper &info) const {
0066   // default discriminator value
0067   float value = -10.;
0068 
0069   // TagInfos for JP taggers
0070   std::vector<const BaseTagInfo *> jpTagInfos({&info.getBase(0)});
0071 
0072   // TagInfos for the CSVv2AVR tagger
0073   std::vector<const BaseTagInfo *> avrTagInfos({&info.getBase(0), &info.getBase(1)});
0074 
0075   // TagInfos for the CSVv2IVF tagger
0076   std::vector<const BaseTagInfo *> ivfTagInfos({&info.getBase(0), &info.getBase(2)});
0077 
0078   // TagInfos for the SoftMuon tagger
0079   std::vector<const BaseTagInfo *> smTagInfos({&info.getBase(3)});
0080 
0081   // TagInfos for the SoftElectron tagger
0082   std::vector<const BaseTagInfo *> seTagInfos({&info.getBase(4)});
0083 
0084   std::map<std::string, float> inputs;
0085   inputs["Jet_JP"] = (*(computers[0]))(TagInfoHelper(jpTagInfos));
0086   inputs["Jet_JBP"] = (*(computers[1]))(TagInfoHelper(jpTagInfos));
0087   inputs["Jet_CSV"] = (*(computers[2]))(TagInfoHelper(avrTagInfos));
0088   inputs["Jet_CSVIVF"] = (*(computers[2]))(TagInfoHelper(ivfTagInfos));
0089   inputs["Jet_SoftMu"] = (*(computers[3]))(TagInfoHelper(smTagInfos));
0090   inputs["Jet_SoftEl"] = (*(computers[4]))(TagInfoHelper(seTagInfos));
0091 
0092   if (inputs["Jet_JP"] <= 0) {
0093     inputs["Jet_JP"] = 0;
0094   }
0095   if (inputs["Jet_JBP"] <= 0) {
0096     inputs["Jet_JBP"] = 0;
0097   }
0098   if (inputs["Jet_CSV"] <= 0) {
0099     inputs["Jet_CSV"] = 0;
0100   }
0101   if (inputs["Jet_CSVIVF"] <= 0) {
0102     inputs["Jet_CSVIVF"] = 0;
0103   }
0104   if (inputs["Jet_SoftMu"] <= 0) {
0105     inputs["Jet_SoftMu"] = 0;
0106   }
0107   if (inputs["Jet_SoftEl"] <= 0) {
0108     inputs["Jet_SoftEl"] = 0;
0109   }
0110 
0111   if (inputs["Jet_CSV"] >= 1) {
0112     inputs["Jet_CSV"] = 1;
0113   }
0114   if (inputs["Jet_CSVIVF"] >= 1) {
0115     inputs["Jet_CSVIVF"] = 1;
0116   }
0117   if (inputs["Jet_SoftMu"] >= 1) {
0118     inputs["Jet_SoftMu"] = 1;
0119   }
0120   if (inputs["Jet_SoftEl"] >= 1) {
0121     inputs["Jet_SoftEl"] = 1;
0122   }
0123 
0124   // evaluate the MVA
0125   value = mvaID->evaluate(inputs);
0126 
0127   // return the final discriminator value
0128   return value;
0129 }