Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:23:38

0001 // -*- C++ -*-
0002 //
0003 // Package:     MVAComputer
0004 // Class  :     ProcTMVA
0005 //
0006 
0007 // Implementation:
0008 //     TMVA wrapper, needs n non-optional, non-multiple input variables
0009 //     and outputs one result variable. All TMVA algorithms can be used,
0010 //     calibration data is passed via stream and extracted from a zipped
0011 //     buffer.
0012 //
0013 // Author:      Christophe Saout
0014 // Created:     Sat Apr 24 15:18 CEST 2007
0015 //
0016 
#include <cstdio>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// ROOT version magic to support TMVA interface changes in newer ROOT
#include <RVersion.h>

#include <TMVA/Types.h>
#include <TMVA/MethodBase.h>
#include "TMVA/Reader.h"

#include "PhysicsTools/MVAComputer/interface/memstream.h"
#include "PhysicsTools/MVAComputer/interface/zstream.h"

#include "PhysicsTools/MVAComputer/interface/VarProcessor.h"
#include "PhysicsTools/MVAComputer/interface/Calibration.h"
#include "PhysicsTools/MVAComputer/interface/mva_computer_define_plugin.h"
#include "FWCore/MessageLogger/interface/MessageLogger.h"

#include <boost/filesystem.hpp>
0042 
0043 using namespace PhysicsTools;
0044 
0045 namespace {  // anonymous
0046 
0047   class ProcTMVA : public VarProcessor {
0048   public:
0049     typedef VarProcessor::Registry::Registry<ProcTMVA, Calibration::ProcExternal> Registry;
0050 
0051     ProcTMVA(const char *name, const Calibration::ProcExternal *calib, const MVAComputer *computer);
0052     ~ProcTMVA() override {}
0053 
0054     void configure(ConfIterator iter, unsigned int n) override;
0055     void eval(ValueIterator iter, unsigned int n) const override;
0056 
0057   private:
0058     std::unique_ptr<TMVA::Reader> reader;
0059     TMVA::MethodBase *method;
0060     std::string methodName;
0061     unsigned int nVars;
0062 
0063     // FIXME: Gena
0064     TString methodName_t;
0065   };
0066 
0067   ProcTMVA::Registry registry("ProcTMVA");
0068 
0069   ProcTMVA::ProcTMVA(const char *name, const Calibration::ProcExternal *calib, const MVAComputer *computer)
0070       : VarProcessor(name, calib, computer) {
0071     reader = std::make_unique<TMVA::Reader>("!Color:Silent");
0072 
0073     ext::imemstream is(reinterpret_cast<const char *>(&calib->store.front()), calib->store.size());
0074     ext::izstream izs(&is);
0075 
0076     std::getline(izs, methodName);
0077 
0078     std::string tmp;
0079     std::getline(izs, tmp);
0080     std::istringstream iss(tmp);
0081     iss >> nVars;
0082     for (unsigned int i = 0; i < nVars; i++) {
0083       std::getline(izs, tmp);
0084       reader->DataInfo().AddVariable(tmp.c_str());
0085     }
0086 
0087     // The rest of the gzip blob is the weights file
0088     std::string weight_text;
0089     std::string line;
0090     while (std::getline(izs, line)) {
0091       weight_text += line;
0092       weight_text += "\n";
0093     }
0094 
0095     // Build our reader
0096     TMVA::Types::EMVA methodType = TMVA::Types::Instance().GetMethodType(methodName);
0097     // Check if xml format
0098     if (weight_text.find("<?xml") != std::string::npos) {
0099       method = dynamic_cast<TMVA::MethodBase *>(reader->BookMVA(methodType, weight_text.c_str()));
0100     } else {
0101       // Write to a temporary file
0102       TString weight_file_name(boost::filesystem::unique_path().c_str());
0103       std::ofstream weight_file;
0104       weight_file.open(weight_file_name.Data());
0105       weight_file << weight_text;
0106       weight_file.close();
0107       edm::LogInfo("LegacyMVA") << "Building legacy TMVA plugin - "
0108                                 << "the weights are being stored in " << weight_file_name << std::endl;
0109       methodName_t.Append(methodName.c_str());
0110       method = dynamic_cast<TMVA::MethodBase *>(reader->BookMVA(methodName_t, weight_file_name));
0111       remove(weight_file_name.Data());
0112     }
0113 
0114     /*
0115   bool isXml = false; // weights in XML (TMVA 4) or plain text
0116   bool isFirstPass = true;
0117   TString weight_file_name(tmpnam(0));
0118   std:: ofstream weight_file;
0119   //
0120 
0121   std::string weights;
0122   while (izs.good()) {
0123     std::string tmp;
0124 
0125     if (isFirstPass){
0126       std::getline(izs, tmp);
0127       isFirstPass = false;
0128       if ( tmp.find("<?xml") != std::string::npos ){ //xml
0129     isXml = true;
0130     weights += tmp + " "; 
0131       }
0132       else{
0133     std::cout << std::endl;
0134     std::cout << "ProcTMVA::ProcTMVA(): *** WARNING! ***" << std::endl;
0135     std::cout << "ProcTMVA::ProcTMVA(): Old pre-TMVA 4 plain text weights are being loaded" << std::endl;
0136     std::cout << "ProcTMVA::ProcTMVA(): It may work but backwards compatibility is not guaranteed" << std::endl;
0137     std::cout << "ProcTMVA::ProcTMVA(): TMVA 4 weight file format is XML" << std::endl;
0138     std::cout << "ProcTMVA::ProcTMVA(): Retrain your networks as soon as possible!" << std::endl;
0139     std::cout << "ProcTMVA::ProcTMVA(): Creating temporary weight file " << weight_file_name << std::endl;
0140     weight_file.open(weight_file_name.Data());
0141     weight_file << tmp << std::endl;
0142       }
0143     } // end first pass
0144     else{
0145       if (isXml){ // xml
0146     izs >> tmp;
0147     weights += tmp + " "; 
0148       }
0149       else{       // plain text
0150     weight_file << tmp << std::endl;
0151       }
0152     } // end not first pass
0153     
0154   }
0155   if (weight_file.is_open()){
0156     std::cout << "ProcTMVA::ProcTMVA(): Deleting temporary weight file " << weight_file_name << std::endl;
0157     weight_file.close();
0158   }
0159 
0160   TMVA::Types::EMVA methodType =
0161               TMVA::Types::Instance().GetMethodType(methodName);
0162 
0163  if (isXml){
0164    method = std::unique_ptr<TMVA::MethodBase>
0165      ( dynamic_cast<TMVA::MethodBase*>
0166        ( reader->BookMVA( methodType, weights.c_str() ) ) ); 
0167  }
0168  else{
0169    methodName_t.Clear();
0170    methodName_t.Append(methodName.c_str());
0171    method = std::unique_ptr<TMVA::MethodBase>
0172      ( dynamic_cast<TMVA::MethodBase*>
0173        ( reader->BookMVA( methodName_t, weight_file_name ) ) );
0174  }
0175 
0176   */
0177   }
0178 
0179   void ProcTMVA::configure(ConfIterator iter, unsigned int n) {
0180     if (n != nVars)
0181       return;
0182 
0183     for (unsigned int i = 0; i < n; i++)
0184       iter++(Variable::FLAG_NONE);
0185 
0186     iter << Variable::FLAG_NONE;
0187   }
0188 
0189   void ProcTMVA::eval(ValueIterator iter, unsigned int n) const {
0190     std::vector<Float_t> inputs;
0191     inputs.reserve(n);
0192     for (unsigned int i = 0; i < n; i++)
0193       inputs.push_back(*iter++);
0194     std::unique_ptr<TMVA::Event> evt(new TMVA::Event(inputs, 2));
0195 
0196     double result = method->GetMvaValue(evt.get());
0197 
0198     iter(result);
0199   }
0200 
0201 }  // anonymous namespace
// Invoke the plugin-definition macro from mva_computer_define_plugin.h for ProcTMVA.
// NOTE(review): presumably this makes the processor loadable as a plugin -- confirm against the macro.
MVA_COMPUTER_DEFINE_PLUGIN(ProcTMVA);