File indexing completed on 2025-01-31 02:19:53
0001 #include <algorithm>
0002 #include <iostream>
0003 #include <memory>
0004 #include <string>
0005
0006 #include "FWCore/Utilities/interface/Exception.h"
0007 #include "FWCore/Framework/interface/EventSetup.h"
0008 #include "CondFormats/PhysicsToolsObjects/interface/MVAComputer.h"
0009 #include "DataFormats/Common/interface/RefToBase.h"
0010 #include "DataFormats/JetReco/interface/Jet.h"
0011 #include "DataFormats/BTauReco/interface/TaggingVariable.h"
0012 #include "RecoBTau/JetTagComputer/interface/GenericMVAComputer.h"
0013 #include "RecoBTau/JetTagComputer/interface/GenericMVAJetTagComputer.h"
0014 #include "RecoBTau/JetTagComputer/interface/JetTagComputerRecord.h"
0015
0016 using namespace reco;
0017 using namespace PhysicsTools;
0018
0019 static std::vector<std::string> getCalibrationLabels(const edm::ParameterSet ¶ms,
0020 std::unique_ptr<TagInfoMVACategorySelector> &selector) {
0021 if (params.getParameter<bool>("useCategories")) {
0022 selector = std::make_unique<TagInfoMVACategorySelector>(params);
0023
0024 return selector->getCategoryLabels();
0025 } else {
0026 std::string calibrationRecord = params.getParameter<std::string>("calibrationRecord");
0027
0028 std::vector<std::string> calibrationLabels;
0029 calibrationLabels.push_back(calibrationRecord);
0030 return calibrationLabels;
0031 }
0032 }
0033
0034 GenericMVAJetTagComputer::Tokens::Tokens(const edm::ParameterSet ¶ms, edm::ESConsumesCollector &&cc)
0035 : calib_(cc.consumes(edm::ESInputTag{"", params.getParameter<std::string>("recordLabel")})) {}
0036
0037 GenericMVAJetTagComputer::GenericMVAJetTagComputer(const edm::ParameterSet ¶ms, Tokens tokens)
0038 : computerCache_(getCalibrationLabels(params, categorySelector_)), tokens_{tokens} {}
0039
0040 GenericMVAJetTagComputer::~GenericMVAJetTagComputer() {}
0041
0042 void GenericMVAJetTagComputer::initialize(const JetTagComputerRecord &record) {
0043
0044 computerCache_.update(&record.get(tokens_.calib_));
0045 }
0046
0047 float GenericMVAJetTagComputer::discriminator(const TagInfoHelper &info) const {
0048 TaggingVariableList variables = taggingVariables(info);
0049
0050
0051 int index = 0;
0052 if (categorySelector_.get()) {
0053 index = categorySelector_->findCategory(variables);
0054 if (index < 0)
0055 return -10.0;
0056 }
0057
0058 GenericMVAComputer const *computer = computerCache_.getComputer(index);
0059
0060 if (!computer)
0061 return -10.0;
0062
0063 return computer->eval(variables);
0064 }
0065
0066 TaggingVariableList GenericMVAJetTagComputer::taggingVariables(const BaseTagInfo &baseTag) const {
0067 TaggingVariableList variables = baseTag.taggingVariables();
0068
0069
0070 edm::RefToBase<Jet> jet = baseTag.jet();
0071 variables.push_back(TaggingVariable(btau::jetPt, jet->pt()));
0072 variables.push_back(TaggingVariable(btau::jetEta, jet->eta()));
0073
0074 return variables;
0075 }
0076
0077 TaggingVariableList GenericMVAJetTagComputer::taggingVariables(const TagInfoHelper &info) const {
0078 return taggingVariables(info.getBase(0));
0079 }
0080
0081 void GenericMVAJetTagComputer::fillPSetDescription(edm::ParameterSetDescription &desc) {
0082 desc.add<bool>("useCategories", false);
0083 TagInfoMVACategorySelector::fillPSetDescription(desc);
0084 desc.add<std::string>("calibrationRecord", "");
0085 desc.add<std::string>("recordLabel", "");
0086 }