Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:24:24

0001 #include "RecoBTag/Combined/interface/HeavyIonCSVTagger.h"
0002 #include "DataFormats/BTauReco/interface/CandIPTagInfo.h"
0003 #include "DataFormats/BTauReco/interface/SecondaryVertexTagInfo.h"
0004 #include "FWCore/Utilities/interface/ESInputTag.h"
0005 #include "CondFormats/DataRecord/interface/GBRWrapperRcd.h"
0006 #include <memory>
0007 #include <algorithm>
0008 #include <map>
0009 #include <vector>
0010 
0011 HeavyIonCSVTagger::Tokens::Tokens(const edm::ParameterSet &configuration, edm::ESConsumesCollector &&cc) {
0012   if (configuration.getParameter<bool>("useCondDB")) {
0013     gbrForest_ = cc.consumes(edm::ESInputTag{"",
0014                                              configuration.existsAs<std::string>("gbrForestLabel")
0015                                                  ? configuration.getParameter<std::string>("gbrForestLabel")
0016                                                  : ""});
0017   }
0018 }
0019 HeavyIonCSVTagger::HeavyIonCSVTagger(const edm::ParameterSet &configuration, Tokens tokens)
0020     : sv_computer_(configuration.getParameter<edm::ParameterSet>("sv_cfg")),
0021       mva_name_(configuration.getParameter<std::string>("mvaName")),
0022       weight_file_(configuration.getParameter<edm::FileInPath>("weightFile")),
0023       use_GBRForest_(configuration.getParameter<bool>("useGBRForest")),
0024       use_adaBoost_(configuration.getParameter<bool>("useAdaBoost")),
0025       tokens_{tokens} {
0026   vpset vars_definition = configuration.getParameter<vpset>("variables");
0027 
0028   for (auto &var : vars_definition) {
0029     MVAVar mva_var;
0030     mva_var.name = var.getParameter<std::string>("name");
0031     mva_var.id = reco::getTaggingVariableName(var.getParameter<std::string>("taggingVarName"));
0032 
0033     mva_var.has_index = var.existsAs<int>("idx");
0034     mva_var.index = mva_var.has_index ? var.getParameter<int>("idx") : 0;
0035     mva_var.default_value = var.getParameter<double>("default");
0036 
0037     variables_.push_back(mva_var);
0038   }
0039 
0040   uses(0, "impactParameterTagInfos");
0041   uses(1, "secondaryVertexTagInfos");
0042 }
0043 
0044 void HeavyIonCSVTagger::initialize(const JetTagComputerRecord &record) {
0045   mvaID_ = std::make_unique<TMVAEvaluator>();
0046 
0047   std::vector<std::string> variable_names;
0048   variable_names.reserve(variables_.size());
0049 
0050   for (auto &var : variables_) {
0051     variable_names.push_back(var.name);
0052   }
0053   std::vector<std::string> spectators;
0054 
0055   if (tokens_.gbrForest_.isInitialized()) {
0056     mvaID_->initializeGBRForest(&record.get(tokens_.gbrForest_), variable_names, spectators, use_adaBoost_);
0057   } else {
0058     mvaID_->initialize("Color:Silent:Error",
0059                        mva_name_,
0060                        weight_file_.fullPath(),
0061                        variable_names,
0062                        spectators,
0063                        use_GBRForest_,
0064                        use_adaBoost_);
0065   }
0066 }
0067 
0068 HeavyIonCSVTagger::~HeavyIonCSVTagger() {}
0069 
0070 /// b-tag a jet based on track-to-jet parameters in the extened info collection
0071 float HeavyIonCSVTagger::discriminator(const TagInfoHelper &tagInfo) const {
0072   // default value, used if there are no leptons associated to this jet
0073   const reco::TrackIPTagInfo &ip_info = tagInfo.get<reco::TrackIPTagInfo>(0);
0074   const reco::SecondaryVertexTagInfo &sv_info = tagInfo.get<reco::SecondaryVertexTagInfo>(1);
0075   reco::TaggingVariableList vars = sv_computer_(ip_info, sv_info);
0076 
0077   // Loop over input variables
0078   std::map<std::string, float> inputs;
0079   std::vector<float> tagValList = vars.getList(reco::btau::trackSip3dSig, false);
0080   bool noTrack = (tagValList.empty());
0081 
0082   for (auto &mva_var : variables_) {
0083     //vectorial tagging variable
0084     if (mva_var.has_index) {
0085       std::vector<float> vals = vars.getList(mva_var.id, false);
0086       inputs[mva_var.name] = (vals.size() > mva_var.index) ? vals[mva_var.index] : mva_var.default_value;
0087     }
0088     //single value tagging var
0089     else {
0090       inputs[mva_var.name] = vars.get(mva_var.id, mva_var.default_value);
0091       if (noTrack) {
0092         if (mva_var.name == "TagVarCSV_vertexMass") {
0093           if (inputs[mva_var.name] < 0)
0094             return -1;
0095           noTrack = false;
0096         }
0097       }
0098     }
0099   }
0100 
0101   return (mvaID_->evaluate(inputs) + 1.) / 2.;
0102 }