File indexing completed on 2024-09-24 22:51:34
0001
0002
0003
0004
0005
0006
0007 #include <stdexcept>
0008 #include <cppunit/extensions/HelperMacros.h>
0009
0010 #include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
0011
0012 #include "testBaseCUDA.h"
0013
0014 class testSessionCacheCUDA : public testBaseCUDA {
0015 CPPUNIT_TEST_SUITE(testSessionCacheCUDA);
0016 CPPUNIT_TEST(test);
0017 CPPUNIT_TEST_SUITE_END();
0018
0019 public:
0020 std::string pyScript() const override;
0021 void test() override;
0022 };
0023
0024 CPPUNIT_TEST_SUITE_REGISTRATION(testSessionCacheCUDA);
0025
0026 std::string testSessionCacheCUDA::pyScript() const { return "createconstantgraph.py"; }
0027
0028 void testSessionCacheCUDA::test() {
0029 if (!cms::cudatest::testDevices())
0030 return;
0031
0032 std::vector<edm::ParameterSet> psets;
0033 edm::ServiceToken serviceToken = edm::ServiceRegistry::createSet(psets);
0034 edm::ServiceRegistry::Operate operate(serviceToken);
0035
0036
0037 edmplugin::PluginManager::configure(edmplugin::standard::config());
0038
0039 std::string const config = R"_(import FWCore.ParameterSet.Config as cms
0040 process = cms.Process('Test')
0041 process.add_(cms.Service('ResourceInformationService'))
0042 process.add_(cms.Service('CUDAService'))
0043 )_";
0044 std::unique_ptr<edm::ParameterSet> params;
0045 edm::makeParameterSets(config, params);
0046 edm::ServiceToken tempToken(edm::ServiceRegistry::createServicesFromConfig(std::move(params)));
0047 edm::ServiceRegistry::Operate operate2(tempToken);
0048 edm::Service<CUDAInterface> cuda;
0049 std::cout << "CUDA service enabled: " << cuda->enabled() << std::endl;
0050
0051 std::cout << "Testing CUDA backend" << std::endl;
0052 tensorflow::Backend backend = tensorflow::Backend::cuda;
0053
0054
0055 std::string pbFile = dataPath_ + "/constantgraph.pb";
0056 tensorflow::Options options{backend};
0057
0058
0059 tensorflow::SessionCache cache(pbFile, options);
0060
0061 CPPUNIT_ASSERT(cache.graph.load() != nullptr);
0062 CPPUNIT_ASSERT(cache.session.load() != nullptr);
0063
0064
0065 const tensorflow::Session* session = cache.getSession();
0066 CPPUNIT_ASSERT(session != nullptr);
0067
0068
0069 cache.closeSession();
0070 CPPUNIT_ASSERT(cache.graph.load() == nullptr);
0071 CPPUNIT_ASSERT(cache.session.load() == nullptr);
0072 }