Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:23:37

#include <algorithm>
#include <iterator>
#include <iostream>
#include <fstream>
#include <memory>
#include <string>
#include <vector>
#include <cstddef>
#include <cstring>
#include <exception>
0010 
0011 #include <TString.h>
0012 
0013 #include "FWCore/Utilities/interface/Exception.h"
0014 
0015 #include "PhysicsTools/MVAComputer/interface/zstream.h"
0016 #include "PhysicsTools/MVAComputer/interface/memstream.h"
0017 
0018 #include "PhysicsTools/MVAComputer/interface/Calibration.h"
0019 #include "PhysicsTools/MVAComputer/interface/BitSet.h"
0020 #include "PhysicsTools/MVAComputer/interface/MVAComputer.h"
0021 
0022 using namespace PhysicsTools;
0023 
/// Returns the number of bytes between the start and the end of `in`.
///
/// Side effect: the read position is left at the beginning of the stream,
/// so a subsequent read starts from the first byte.
static std::size_t getStreamSize(std::ifstream &in) {
  in.seekg(0, std::ios::end);
  const std::ifstream::pos_type last = in.tellg();

  in.seekg(0, std::ios::beg);
  const std::ifstream::pos_type first = in.tellg();

  return static_cast<std::size_t>(last - first);
}
0032 
0033 static Calibration::VarProcessor *getCalibration(const std::string &file, const std::vector<std::string> &names) {
0034   std::unique_ptr<Calibration::ProcExternal> calib(new Calibration::ProcExternal);
0035 
0036   std::ifstream in(file.c_str(), std::ios::binary | std::ios::in);
0037   if (!in.good())
0038     throw cms::Exception("mvaWeightsToCalibration") << "Weights file \"" << file
0039                                                     << "\" "
0040                                                        "cannot be opened for reading."
0041                                                     << std::endl;
0042 
0043   char buf[512];
0044 
0045   while (in.good() && !TString(buf).BeginsWith("Method"))
0046     in.getline(buf, 512);
0047   if (!in.good())
0048     throw cms::Exception("mvaWeightsToCalibration") << "Weights file \"" << file
0049                                                     << "\" "
0050                                                        "is not a TMVA weights file."
0051                                                     << std::endl;
0052 
0053   TString ls(buf);
0054   Int_t idx1 = ls.First(':') + 2;
0055   Int_t idx2 = ls.Index(' ', idx1) - idx1;
0056   if (idx2 < 0)
0057     idx2 = ls.Length();
0058   TString fullname = ls(idx1, idx2);
0059   idx1 = fullname.First(':');
0060   Int_t idxtit = (idx1 < 0 ? fullname.Length() : idx1);
0061   TString methodName = fullname(0, idxtit);
0062 
0063   std::size_t size = getStreamSize(in) + methodName.Length();
0064   for (std::vector<std::string>::const_iterator iter = names.begin(); iter != names.end(); ++iter)
0065     size += iter->size() + 1;
0066   size += (size / 32) + 128;
0067 
0068   char *buffer = nullptr;
0069   try {
0070     buffer = new char[size];
0071     ext::omemstream os(buffer, size);
0072     /* call dtor of ozs at end */ {
0073       ext::ozstream ozs(&os);
0074       ozs << methodName << "\n";
0075       ozs << names.size() << "\n";
0076       for (std::vector<std::string>::const_iterator iter = names.begin(); iter != names.end(); ++iter)
0077         ozs << *iter << "\n";
0078       ozs << in.rdbuf();
0079       ozs.flush();
0080     }
0081     size = os.end() - os.begin();
0082     calib->store.resize(size);
0083     std::memcpy(&calib->store.front(), os.begin(), size);
0084   } catch (...) {
0085     delete[] buffer;
0086     throw;
0087   }
0088   delete[] buffer;
0089   in.close();
0090 
0091   calib->method = "ProcTMVA";
0092 
0093   return calib.release();
0094 }
0095 
0096 int main(int argc, char **argv) {
0097   if (argc < 4) {
0098     std::cerr << "Syntax: " << argv[0] << " <input> "
0099               << "<output> <var1> [<var2>...]" << std::endl;
0100     return 1;
0101   }
0102 
0103   std::vector<std::string> names;
0104   for (int i = 3; i < argc; i++)
0105     names.push_back(argv[i]);
0106 
0107   try {
0108     std::unique_ptr<Calibration::VarProcessor> proc(getCalibration(argv[1], names));
0109 
0110     BitSet inputVars(names.size());
0111     for (std::size_t i = 0; i < names.size(); i++)
0112       inputVars[i] = true;
0113     proc->inputVars = Calibration::convert(inputVars);
0114 
0115     Calibration::MVAComputer mva;
0116     std::copy(names.begin(), names.end(), std::back_inserter(mva.inputSet));
0117     mva.addProcessor(proc.get());
0118     mva.output = names.size();
0119 
0120     MVAComputer::writeCalibration(argv[2], &mva);
0121   } catch (cms::Exception const &e) {
0122     std::cerr << e.what() << std::endl;
0123   }
0124 
0125   return 0;
0126 }