Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-08-02 05:16:40

0001 #include <torch/script.h>
0002 #include "testBaseCUDA.h"
0003 #include <iostream>
0004 #include <memory>
0005 #include <vector>
0006 #include "HeterogeneousCore/CUDAServices/interface/CUDAInterface.h"
0007 
0008 class testSimpleDNNCUDA : public testBasePyTorchCUDA {
0009   CPPUNIT_TEST_SUITE(testSimpleDNNCUDA);
0010   CPPUNIT_TEST(test);
0011   CPPUNIT_TEST_SUITE_END();
0012 
0013 public:
0014   std::string pyScript() const override;
0015   void test() override;
0016 };
0017 
0018 CPPUNIT_TEST_SUITE_REGISTRATION(testSimpleDNNCUDA);
0019 
0020 std::string testSimpleDNNCUDA::pyScript() const { return "create_simple_dnn.py"; }
0021 
0022 void testSimpleDNNCUDA::test() {
0023   std::vector<edm::ParameterSet> psets;
0024   edm::ServiceToken serviceToken = edm::ServiceRegistry::createSet(psets);
0025   edm::ServiceRegistry::Operate operate(serviceToken);
0026 
0027   // Setup the CUDA Service
0028   edmplugin::PluginManager::configure(edmplugin::standard::config());
0029 
0030   std::string const config = R"_(import FWCore.ParameterSet.Config as cms
0031 process = cms.Process('Test')
0032 process.add_(cms.Service('ResourceInformationService'))
0033 process.add_(cms.Service('CUDAService'))
0034 )_";
0035   std::unique_ptr<edm::ParameterSet> params;
0036   edm::makeParameterSets(config, params);
0037   edm::ServiceToken tempToken(edm::ServiceRegistry::createServicesFromConfig(std::move(params)));
0038   edm::ServiceRegistry::Operate operate2(tempToken);
0039   edm::Service<CUDAInterface> cuda;
0040   std::cout << "CUDA service enabled: " << cuda->enabled() << std::endl;
0041 
0042   std::cout << "Testing CUDA backend" << std::endl;
0043 
0044   std::string model_path = dataPath_ + "/simple_dnn.pt";
0045   torch::Device device(torch::kCUDA);
0046   torch::jit::script::Module module;
0047   try {
0048     // Deserialize the ScriptModule from a file using torch::jit::load().
0049     module = torch::jit::load(model_path);
0050     module.to(device);
0051   } catch (const c10::Error& e) {
0052     std::cerr << "error loading the model\n" << e.what() << std::endl;
0053     CPPUNIT_ASSERT(false);
0054   }
0055   // Create a vector of inputs.
0056   std::vector<torch::jit::IValue> inputs;
0057   inputs.push_back(torch::ones(10, device));
0058 
0059   // Execute the model and turn its output into a tensor.
0060   at::Tensor output = module.forward(inputs).toTensor();
0061   std::cout << "output: " << output << '\n';
0062   CPPUNIT_ASSERT(output.item<float_t>() == 110.);
0063   std::cout << "ok\n";
0064 }