#include <torch/script.h>
#include "testBaseCUDA.h"
#include <iostream>
#include <memory>
#include <vector>
#include "HeterogeneousCore/CUDAServices/interface/CUDAInterface.h"

class testSimpleDNNCUDA : public testBasePyTorchCUDA {
  CPPUNIT_TEST_SUITE(testSimpleDNNCUDA);
  CPPUNIT_TEST(test);
  CPPUNIT_TEST_SUITE_END();

public:
  std::string pyScript() const override;
  void test() override;
};

CPPUNIT_TEST_SUITE_REGISTRATION(testSimpleDNNCUDA);

std::string testSimpleDNNCUDA::pyScript() const { return "create_simple_dnn.py"; }

void testSimpleDNNCUDA::test() {
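  // Make an (initially empty) set of edm services active for the duration of the test.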
  std::vector<edm::ParameterSet> psets;
  edm::ServiceToken serviceToken = edm::ServiceRegistry::createSet(psets);
  edm::ServiceRegistry::Operate operate(serviceToken);

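  // Set up the CUDA service: configure the plugin manager, then build the
  // ResourceInformationService and CUDAService from a minimal Python process configuration.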
  edmplugin::PluginManager::configure(edmplugin::standard::config());

  std::string const config = R"_(import FWCore.ParameterSet.Config as cms
process = cms.Process('Test')
process.add_(cms.Service('ResourceInformationService'))
process.add_(cms.Service('CUDAService'))
)_";
  std::unique_ptr<edm::ParameterSet> params;
  edm::makeParameterSets(config, params);
  edm::ServiceToken tempToken(edm::ServiceRegistry::createServicesFromConfig(std::move(params)));
  edm::ServiceRegistry::Operate operate2(tempToken);
  edm::Service<CUDAInterface> cuda;
  std::cout << "CUDA service enabled: " << cuda->enabled() << std::endl;

  std::cout << "Testing CUDA backend" << std::endl;

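  // Locate the serialized TorchScript model and select the CUDA device.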
  std::string model_path = dataPath_ + "/simple_dnn.pt";
  torch::Device device(torch::kCUDA);
  torch::jit::script::Module module;
  try {
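    // Deserialize the TorchScript module from the file and move it to the GPU.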
    module = torch::jit::load(model_path);
    module.to(device);
  } catch (const c10::Error& e) {
    std::cerr << "error loading the model\n" << e.what() << std::endl;
    CPPUNIT_ASSERT(false);
  }

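  // Create the model inputs: a single tensor of ones allocated on the CUDA device.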
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::ones(10, device));

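  // Execute the model and turn its output into a tensor; the test expects the value 110 for the all-ones input.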
  at::Tensor output = module.forward(inputs).toTensor();
  std::cout << "output: " << output << '\n';
  CPPUNIT_ASSERT(output.item<float_t>() == 110.);
  std::cout << "ok\n";
}