File indexing completed on 2024-08-02 05:16:40
0001 #include <torch/script.h>
0002 #include "testBase.h"
0003 #include <iostream>
0004 #include <memory>
0005 #include <vector>
0006
// CppUnit fixture for a small TorchScript model ("simple DNN").
// pyScript() names the Python helper script associated with this test, and
// test() loads "simple_dnn.pt" from dataPath_ and checks one forward pass.
// NOTE(review): presumably the base class runs the helper script to produce the
// .pt file before test() executes — confirm in testBase.h.
class testSimpleDNN : public testBasePyTorch {
  // Suite declaration with its single test case. CPPUNIT_TEST_SUITE opens the
  // suite, CPPUNIT_TEST adds test(), and CPPUNIT_TEST_SUITE_END closes it.
  // They must stay in this order.
  CPPUNIT_TEST_SUITE(testSimpleDNN);
  CPPUNIT_TEST(test);
  CPPUNIT_TEST_SUITE_END();

  // CPPUNIT_TEST_SUITE_END leaves the access level private, so re-open public.
public:
  // Returns the name of the Python helper script for this test (overrides a base-class virtual).
  std::string pyScript() const override;
  // Loads the TorchScript model, runs one forward pass and checks the result.
  void test() override;
};

// Register the fixture with CppUnit's default test-factory registry so the runner finds it.
CPPUNIT_TEST_SUITE_REGISTRATION(testSimpleDNN);
0018
0019 std::string testSimpleDNN::pyScript() const { return "create_simple_dnn.py"; }
0020
0021 void testSimpleDNN::test() {
0022 std::string model_path = dataPath_ + "/simple_dnn.pt";
0023 torch::Device device(torch::kCPU);
0024 torch::jit::script::Module module;
0025 try {
0026
0027 module = torch::jit::load(model_path);
0028 module.to(device);
0029 } catch (const c10::Error& e) {
0030 std::cerr << "error loading the model\n" << e.what() << std::endl;
0031 CPPUNIT_ASSERT(false);
0032 }
0033
0034 std::vector<torch::jit::IValue> inputs;
0035 inputs.push_back(torch::ones(10, device));
0036
0037
0038 at::Tensor output = module.forward(inputs).toTensor();
0039 std::cout << "output: " << output << '\n';
0040 CPPUNIT_ASSERT(output.item<float_t>() == 110.);
0041 std::cout << "ok\n";
0042 }