Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:24:15

0001 /*
0002  * HelloWorld test of the TensorFlow interface.
0003  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
0004  *
0005  * Author: Marcel Rieger
0006  */
0007 
0008 #include <stdexcept>
0009 #include <cppunit/extensions/HelperMacros.h>
0010 
0011 #include "tensorflow/cc/saved_model/loader.h"
0012 #include "tensorflow/cc/saved_model/tag_constants.h"
0013 #include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
0014 
0015 #include "testBaseCUDA.h"
0016 
// CppUnit fixture exercising the TensorFlow "HelloWorld" graph on the CUDA
// backend. Common setup (CUDA test helpers, dataPath_, python graph creation)
// comes from testBaseCUDA — presumably the base runs pyScript() to produce the
// test model; confirm against testBaseCUDA.h.
class testHelloWorldCUDA : public testBaseCUDA {
  CPPUNIT_TEST_SUITE(testHelloWorldCUDA);
  CPPUNIT_TEST(test);
  CPPUNIT_TEST_SUITE_END();

public:
  // Name of the python script that creates the SavedModel used by test().
  std::string pyScript() const override;
  // Loads the SavedModel, runs one inference on the CUDA backend, logs results.
  void test() override;
};
0026 
0027 CPPUNIT_TEST_SUITE_REGISTRATION(testHelloWorldCUDA);
0028 
0029 std::string testHelloWorldCUDA::pyScript() const { return "creategraph.py"; }
0030 
0031 void testHelloWorldCUDA::test() {
0032   if (!cms::cudatest::testDevices())
0033     return;
0034 
0035   std::vector<edm::ParameterSet> psets;
0036   edm::ServiceToken serviceToken = edm::ServiceRegistry::createSet(psets);
0037   edm::ServiceRegistry::Operate operate(serviceToken);
0038 
0039   // Setup the CUDA Service
0040   edmplugin::PluginManager::configure(edmplugin::standard::config());
0041 
0042   std::string const config = R"_(import FWCore.ParameterSet.Config as cms
0043 process = cms.Process('Test')
0044 process.add_(cms.Service('ResourceInformationService'))
0045 process.add_(cms.Service('CUDAService'))
0046 )_";
0047   std::unique_ptr<edm::ParameterSet> params;
0048   edm::makeParameterSets(config, params);
0049   edm::ServiceToken tempToken(edm::ServiceRegistry::createServicesFromConfig(std::move(params)));
0050   edm::ServiceRegistry::Operate operate2(tempToken);
0051   edm::Service<CUDAInterface> cuda;
0052   std::cout << "CUDA service enabled: " << cuda->enabled() << std::endl;
0053 
0054   std::cout << "Testing CUDA backend" << std::endl;
0055   tensorflow::Backend backend = tensorflow::Backend::cuda;
0056 
0057   // object to load and run the graph / session
0058   tensorflow::Status status;
0059   tensorflow::Options options{backend};
0060   tensorflow::setLogging("0");
0061   tensorflow::RunOptions runOptions;
0062   tensorflow::SavedModelBundle bundle;
0063 
0064   // load everything
0065   std::string modelDir = dataPath_ + "/simplegraph";
0066   status = tensorflow::LoadSavedModel(options.getSessionOptions(), runOptions, modelDir, {"serve"}, &bundle);
0067   if (!status.ok()) {
0068     std::cout << status.ToString() << std::endl;
0069     return;
0070   }
0071 
0072   // fetch the session
0073   tensorflow::Session* session = bundle.session.release();
0074 
0075   // prepare inputs
0076   tensorflow::Tensor input(tensorflow::DT_FLOAT, {1, 10});
0077   float* d = input.flat<float>().data();
0078   for (size_t i = 0; i < 10; i++, d++) {
0079     *d = float(i);
0080   }
0081   tensorflow::Tensor scale(tensorflow::DT_FLOAT, {});
0082   scale.scalar<float>()() = 1.0;
0083 
0084   // prepare outputs
0085   std::vector<tensorflow::Tensor> outputs;
0086 
0087   // session run
0088   status = session->Run({{"input", input}, {"scale", scale}}, {"output"}, {}, &outputs);
0089   if (!status.ok()) {
0090     std::cout << status.ToString() << std::endl;
0091     return;
0092   }
0093 
0094   // log the output tensor
0095   std::cout << outputs[0].DebugString() << std::endl;
0096 
0097   // close the session
0098   status = session->Close();
0099   if (!status.ok()) {
0100     std::cerr << "error while closing session" << std::endl;
0101   }
0102   delete session;
0103 }