File indexing completed on 2024-08-02 05:16:40
0001
0002
0003
0004
0005
0006 #ifndef PHYSICSTOOLS_PYTORCH_TEST_TESTBASE_H
0007 #define PHYSICSTOOLS_PYTORCH_TEST_TESTBASE_H
0008
#include <array>
#include <cstdio>
#include <cstdlib>
#include <filesystem>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>

#include <boost/filesystem.hpp>
#include <cppunit/extensions/HelperMacros.h>
0013
0014 class testBasePyTorch : public CppUnit::TestFixture {
0015 public:
0016 std::string dataPath_;
0017
0018 void setUp();
0019 void tearDown();
0020 std::string cmsswPath(std::string path);
0021
0022 virtual void test() = 0;
0023
0024 virtual std::string pyScript() const = 0;
0025 };
0026
0027 void testBasePyTorch::setUp() {
0028 dataPath_ =
0029 cmsswPath("/test/" + std::string(std::getenv("SCRAM_ARCH")) + "/" + boost::filesystem::unique_path().string());
0030
0031
0032 std::string testPath = cmsswPath("/src/PhysicsTools/PyTorch/test");
0033 std::string cmd = "apptainer exec -B " + cmsswPath("") +
0034 " /cvmfs/unpacked.cern.ch/registry.hub.docker.com/cmsml/cmsml:3.11 python " + testPath + "/" +
0035 pyScript() + " " + dataPath_;
0036 std::cout << "cmd: " << cmd << std::endl;
0037 std::array<char, 128> buffer;
0038 std::string result;
0039 std::shared_ptr<FILE> pipe(popen(cmd.c_str(), "r"), pclose);
0040 if (!pipe) {
0041 throw std::runtime_error("Failed to run apptainer to prepare the PyTorch test model: " + cmd);
0042 }
0043 while (!feof(pipe.get())) {
0044 if (fgets(buffer.data(), 128, pipe.get()) != NULL) {
0045 result += buffer.data();
0046 }
0047 }
0048 std::cout << std::endl << result << std::endl;
0049 }
0050
0051 void testBasePyTorch::tearDown() {
0052 if (std::filesystem::exists(dataPath_)) {
0053 std::filesystem::remove_all(dataPath_);
0054 }
0055 }
0056
0057 std::string testBasePyTorch::cmsswPath(std::string path) {
0058 if (path.size() > 0 && path.substr(0, 1) != "/") {
0059 path = "/" + path;
0060 }
0061
0062 std::string base = std::string(std::getenv("CMSSW_BASE"));
0063 std::string releaseBase = std::string(std::getenv("CMSSW_RELEASE_BASE"));
0064
0065 return (std::filesystem::exists(base.c_str()) ? base : releaseBase) + path;
0066 }
0067
0068 #endif