File indexing completed on 2024-04-06 12:23:38
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
#include <cstdio>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>


#include <RVersion.h>

#include <TMVA/Types.h>
#include <TMVA/MethodBase.h>
#include "TMVA/Reader.h"

#include "PhysicsTools/MVAComputer/interface/memstream.h"
#include "PhysicsTools/MVAComputer/interface/zstream.h"

#include "PhysicsTools/MVAComputer/interface/VarProcessor.h"
#include "PhysicsTools/MVAComputer/interface/Calibration.h"
#include "PhysicsTools/MVAComputer/interface/mva_computer_define_plugin.h"
#include "FWCore/MessageLogger/interface/MessageLogger.h"

#include <boost/filesystem.hpp>
0042
0043 using namespace PhysicsTools;
0044
0045 namespace {
0046
0047 class ProcTMVA : public VarProcessor {
0048 public:
0049 typedef VarProcessor::Registry::Registry<ProcTMVA, Calibration::ProcExternal> Registry;
0050
0051 ProcTMVA(const char *name, const Calibration::ProcExternal *calib, const MVAComputer *computer);
0052 ~ProcTMVA() override {}
0053
0054 void configure(ConfIterator iter, unsigned int n) override;
0055 void eval(ValueIterator iter, unsigned int n) const override;
0056
0057 private:
0058 std::unique_ptr<TMVA::Reader> reader;
0059 TMVA::MethodBase *method;
0060 std::string methodName;
0061 unsigned int nVars;
0062
0063
0064 TString methodName_t;
0065 };
0066
0067 ProcTMVA::Registry registry("ProcTMVA");
0068
0069 ProcTMVA::ProcTMVA(const char *name, const Calibration::ProcExternal *calib, const MVAComputer *computer)
0070 : VarProcessor(name, calib, computer) {
0071 reader = std::make_unique<TMVA::Reader>("!Color:Silent");
0072
0073 ext::imemstream is(reinterpret_cast<const char *>(&calib->store.front()), calib->store.size());
0074 ext::izstream izs(&is);
0075
0076 std::getline(izs, methodName);
0077
0078 std::string tmp;
0079 std::getline(izs, tmp);
0080 std::istringstream iss(tmp);
0081 iss >> nVars;
0082 for (unsigned int i = 0; i < nVars; i++) {
0083 std::getline(izs, tmp);
0084 reader->DataInfo().AddVariable(tmp.c_str());
0085 }
0086
0087
0088 std::string weight_text;
0089 std::string line;
0090 while (std::getline(izs, line)) {
0091 weight_text += line;
0092 weight_text += "\n";
0093 }
0094
0095
0096 TMVA::Types::EMVA methodType = TMVA::Types::Instance().GetMethodType(methodName);
0097
0098 if (weight_text.find("<?xml") != std::string::npos) {
0099 method = dynamic_cast<TMVA::MethodBase *>(reader->BookMVA(methodType, weight_text.c_str()));
0100 } else {
0101
0102 TString weight_file_name(boost::filesystem::unique_path().c_str());
0103 std::ofstream weight_file;
0104 weight_file.open(weight_file_name.Data());
0105 weight_file << weight_text;
0106 weight_file.close();
0107 edm::LogInfo("LegacyMVA") << "Building legacy TMVA plugin - "
0108 << "the weights are being stored in " << weight_file_name << std::endl;
0109 methodName_t.Append(methodName.c_str());
0110 method = dynamic_cast<TMVA::MethodBase *>(reader->BookMVA(methodName_t, weight_file_name));
0111 remove(weight_file_name.Data());
0112 }
0113
0114
0115
0116
0117
0118
0119
0120
0121
0122
0123
0124
0125
0126
0127
0128
0129
0130
0131
0132
0133
0134
0135
0136
0137
0138
0139
0140
0141
0142
0143
0144
0145
0146
0147
0148
0149
0150
0151
0152
0153
0154
0155
0156
0157
0158
0159
0160
0161
0162
0163
0164
0165
0166
0167
0168
0169
0170
0171
0172
0173
0174
0175
0176
0177 }
0178
0179 void ProcTMVA::configure(ConfIterator iter, unsigned int n) {
0180 if (n != nVars)
0181 return;
0182
0183 for (unsigned int i = 0; i < n; i++)
0184 iter++(Variable::FLAG_NONE);
0185
0186 iter << Variable::FLAG_NONE;
0187 }
0188
0189 void ProcTMVA::eval(ValueIterator iter, unsigned int n) const {
0190 std::vector<Float_t> inputs;
0191 inputs.reserve(n);
0192 for (unsigned int i = 0; i < n; i++)
0193 inputs.push_back(*iter++);
0194 std::unique_ptr<TMVA::Event> evt(new TMVA::Event(inputs, 2));
0195
0196 double result = method->GetMvaValue(evt.get());
0197
0198 iter(result);
0199 }
0200
0201 }
// Plugin definition macro from mva_computer_define_plugin.h; presumably
// registers ProcTMVA with the MVAComputer plugin factory (TODO confirm).
MVA_COMPUTER_DEFINE_PLUGIN(ProcTMVA);