Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-08-02 05:16:40

0001 #include <torch/script.h>
0002 #include "testBase.h"
0003 #include <iostream>
0004 #include <memory>
0005 #include <vector>
0006 
0007 class testSimpleDNN : public testBasePyTorch {
0008   CPPUNIT_TEST_SUITE(testSimpleDNN);
0009   CPPUNIT_TEST(test);
0010   CPPUNIT_TEST_SUITE_END();
0011 
0012 public:
0013   std::string pyScript() const override;
0014   void test() override;
0015 };
0016 
0017 CPPUNIT_TEST_SUITE_REGISTRATION(testSimpleDNN);
0018 
0019 std::string testSimpleDNN::pyScript() const { return "create_simple_dnn.py"; }
0020 
0021 void testSimpleDNN::test() {
0022   std::string model_path = dataPath_ + "/simple_dnn.pt";
0023   torch::Device device(torch::kCPU);
0024   torch::jit::script::Module module;
0025   try {
0026     // Deserialize the ScriptModule from a file using torch::jit::load().
0027     module = torch::jit::load(model_path);
0028     module.to(device);
0029   } catch (const c10::Error& e) {
0030     std::cerr << "error loading the model\n" << e.what() << std::endl;
0031     CPPUNIT_ASSERT(false);
0032   }
0033   // Create a vector of inputs.
0034   std::vector<torch::jit::IValue> inputs;
0035   inputs.push_back(torch::ones(10, device));
0036 
0037   // Execute the model and turn its output into a tensor.
0038   at::Tensor output = module.forward(inputs).toTensor();
0039   std::cout << "output: " << output << '\n';
0040   CPPUNIT_ASSERT(output.item<float_t>() == 110.);
0041   std::cout << "ok\n";
0042 }