Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-09-24 22:51:34

0001 /*
0002  * Tests for loading meta graphs via the SavedModel interface.
0003  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
0004  *
0005  * Author: Marcel Rieger
0006  */
0007 
0008 #include <stdexcept>
0009 #include <cppunit/extensions/HelperMacros.h>
0010 
0011 #include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
0012 
0013 #include "testBaseCUDA.h"
0014 
0015 class testMetaGraphLoadingCUDA : public testBaseCUDA {
0016   CPPUNIT_TEST_SUITE(testMetaGraphLoadingCUDA);
0017   CPPUNIT_TEST(test);
0018   CPPUNIT_TEST_SUITE_END();
0019 
0020 public:
0021   std::string pyScript() const override;
0022   void test() override;
0023 };
0024 
0025 CPPUNIT_TEST_SUITE_REGISTRATION(testMetaGraphLoadingCUDA);
0026 
0027 std::string testMetaGraphLoadingCUDA::pyScript() const { return "creategraph.py"; }
0028 
0029 void testMetaGraphLoadingCUDA::test() {
0030   if (!cms::cudatest::testDevices())
0031     return;
0032 
0033   std::vector<edm::ParameterSet> psets;
0034   edm::ServiceToken serviceToken = edm::ServiceRegistry::createSet(psets);
0035   edm::ServiceRegistry::Operate operate(serviceToken);
0036 
0037   // Setup the CUDA Service
0038   edmplugin::PluginManager::configure(edmplugin::standard::config());
0039 
0040   std::string const config = R"_(import FWCore.ParameterSet.Config as cms
0041 process = cms.Process('Test')
0042 process.add_(cms.Service('ResourceInformationService'))
0043 process.add_(cms.Service('CUDAService'))
0044 )_";
0045   std::unique_ptr<edm::ParameterSet> params;
0046   edm::makeParameterSets(config, params);
0047   edm::ServiceToken tempToken(edm::ServiceRegistry::createServicesFromConfig(std::move(params)));
0048   edm::ServiceRegistry::Operate operate2(tempToken);
0049   edm::Service<CUDAInterface> cuda;
0050   std::cout << "CUDA service enabled: " << cuda->enabled() << std::endl;
0051 
0052   std::cout << "Testing CUDA backend" << std::endl;
0053   tensorflow::Backend backend = tensorflow::Backend::cuda;
0054 
0055   // load the graph
0056   std::string exportDir = dataPath_ + "/simplegraph";
0057   tensorflow::Options options{backend};
0058   tensorflow::MetaGraphDef* metaGraphDef = tensorflow::loadMetaGraphDef(exportDir);
0059   CPPUNIT_ASSERT(metaGraphDef != nullptr);
0060 
0061   // create a new, empty session
0062   tensorflow::Session* session1 = tensorflow::createSession(options);
0063   CPPUNIT_ASSERT(session1 != nullptr);
0064 
0065   // create a new session, using the meta graph
0066   tensorflow::Session* session2 = tensorflow::createSession(metaGraphDef, exportDir, options);
0067   CPPUNIT_ASSERT(session2 != nullptr);
0068 
0069   // check for exception
0070   CPPUNIT_ASSERT_THROW(tensorflow::createSession(nullptr, exportDir, options), cms::Exception);
0071 
0072   // example evaluation
0073   tensorflow::Tensor input(tensorflow::DT_FLOAT, {1, 10});
0074   float* d = input.flat<float>().data();
0075   for (size_t i = 0; i < 10; i++, d++) {
0076     *d = float(i);
0077   }
0078   tensorflow::Tensor scale(tensorflow::DT_FLOAT, {});
0079   scale.scalar<float>()() = 1.0;
0080 
0081   std::vector<tensorflow::Tensor> outputs;
0082   tensorflow::Status status = session2->Run({{"input", input}, {"scale", scale}}, {"output"}, {}, &outputs);
0083   if (!status.ok()) {
0084     std::cout << status.ToString() << std::endl;
0085     CPPUNIT_ASSERT(false);
0086   }
0087 
0088   // check the output
0089   CPPUNIT_ASSERT(outputs.size() == 1);
0090   std::cout << outputs[0].DebugString() << std::endl;
0091   CPPUNIT_ASSERT(outputs[0].matrix<float>()(0, 0) == 46.);
0092 
0093   // run again using the convenience helper
0094   outputs.clear();
0095   tensorflow::run(session2, {{"input", input}, {"scale", scale}}, {"output"}, &outputs);
0096   CPPUNIT_ASSERT(outputs.size() == 1);
0097   std::cout << outputs[0].DebugString() << std::endl;
0098   CPPUNIT_ASSERT(outputs[0].matrix<float>()(0, 0) == 46.);
0099 
0100   // check for exception
0101   CPPUNIT_ASSERT_THROW(tensorflow::run(session2, {{"foo", input}}, {"output"}, &outputs), cms::Exception);
0102 
0103   // cleanup
0104   CPPUNIT_ASSERT(tensorflow::closeSession(session1));
0105   CPPUNIT_ASSERT(tensorflow::closeSession(session2));
0106   delete metaGraphDef;
0107 }