Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:24:16

0001 /*
0002  * Tests for interacting with the SessionCache.
0003  *
0004  * Author: Marcel Rieger
0005  */
0006 
0007 #include <stdexcept>
0008 #include <cppunit/extensions/HelperMacros.h>
0009 
0010 #include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
0011 
0012 #include "testBaseCUDA.h"
0013 
/*
 * CppUnit fixture exercising tensorflow::SessionCache with the CUDA backend.
 * Derives from testBaseCUDA (testBaseCUDA.h), which provides dataPath_ used
 * by test(); pyScript() overrides a virtual of that base class.
 */
class testSessionCacheCUDA : public testBaseCUDA {
  // Registers test() as the single test case of this suite.
  CPPUNIT_TEST_SUITE(testSessionCacheCUDA);
  CPPUNIT_TEST(test);
  CPPUNIT_TEST_SUITE_END();

  // Explicit access specifier: the suite macros above change the current
  // access level, so the overrides below must re-open a public section.
public:
  // Name of the Python script associated with this test (see definition).
  std::string pyScript() const override;
  // The test case: builds a CUDA-backed SessionCache and tears it down.
  void test() override;
};
0023 
// Add the suite to CppUnit's default test registry so the test runner finds it.
CPPUNIT_TEST_SUITE_REGISTRATION(testSessionCacheCUDA);
0025 
0026 std::string testSessionCacheCUDA::pyScript() const { return "createconstantgraph.py"; }
0027 
0028 void testSessionCacheCUDA::test() {
0029   if (!cms::cudatest::testDevices())
0030     return;
0031 
0032   std::vector<edm::ParameterSet> psets;
0033   edm::ServiceToken serviceToken = edm::ServiceRegistry::createSet(psets);
0034   edm::ServiceRegistry::Operate operate(serviceToken);
0035 
0036   // Setup the CUDA Service
0037   edmplugin::PluginManager::configure(edmplugin::standard::config());
0038 
0039   std::string const config = R"_(import FWCore.ParameterSet.Config as cms
0040 process = cms.Process('Test')
0041 process.add_(cms.Service('ResourceInformationService'))
0042 process.add_(cms.Service('CUDAService'))
0043 )_";
0044   std::unique_ptr<edm::ParameterSet> params;
0045   edm::makeParameterSets(config, params);
0046   edm::ServiceToken tempToken(edm::ServiceRegistry::createServicesFromConfig(std::move(params)));
0047   edm::ServiceRegistry::Operate operate2(tempToken);
0048   edm::Service<CUDAInterface> cuda;
0049   std::cout << "CUDA service enabled: " << cuda->enabled() << std::endl;
0050 
0051   std::cout << "Testing CUDA backend" << std::endl;
0052   tensorflow::Backend backend = tensorflow::Backend::cuda;
0053 
0054   // load the graph and the session
0055   std::string pbFile = dataPath_ + "/constantgraph.pb";
0056   tensorflow::setLogging();
0057   tensorflow::Options options{backend};
0058 
0059   // load the graph and the session
0060   tensorflow::SessionCache cache(pbFile, options);
0061 
0062   CPPUNIT_ASSERT(cache.graph.load() != nullptr);
0063   CPPUNIT_ASSERT(cache.session.load() != nullptr);
0064 
0065   // get a const session pointer
0066   const tensorflow::Session* session = cache.getSession();
0067   CPPUNIT_ASSERT(session != nullptr);
0068 
0069   // cleanup
0070   cache.closeSession();
0071   CPPUNIT_ASSERT(cache.graph.load() == nullptr);
0072   CPPUNIT_ASSERT(cache.session.load() == nullptr);
0073 }