Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2021-02-14 13:33:24

0001 /*
0002  * Tests for running inference using custom thread pools.
0003  * Based on TensorFlow 2.1.
0004  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
0005  *
0006  * Author: Marcel Rieger
0007  */
0008 
0009 #include <cppunit/extensions/HelperMacros.h>
0010 #include <stdexcept>
0011 
0012 #include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
0013 
0014 #include "testBase.h"
0015 
0016 class testGraphLoading : public testBase {
0017   CPPUNIT_TEST_SUITE(testGraphLoading);
0018   CPPUNIT_TEST(checkAll);
0019   CPPUNIT_TEST_SUITE_END();
0020 
0021 public:
0022   std::string pyScript() const override;
0023   void checkAll() override;
0024 };
0025 
0026 CPPUNIT_TEST_SUITE_REGISTRATION(testGraphLoading);
0027 
0028 std::string testGraphLoading::pyScript() const { return "createconstantgraph.py"; }
0029 
0030 void testGraphLoading::checkAll() {
0031   std::string pbFile = dataPath_ + "/constantgraph.pb";
0032 
0033   // initialize the TBB threadpool
0034   int nThreads = 4;
0035   tensorflow::TBBThreadPool::instance(nThreads);
0036 
0037   // load the graph
0038   tensorflow::setLogging();
0039   tensorflow::GraphDef* graphDef = tensorflow::loadGraphDef(pbFile);
0040   CPPUNIT_ASSERT(graphDef != nullptr);
0041 
0042   // create a new session and add the graphDef
0043   tensorflow::Session* session = tensorflow::createSession(graphDef);
0044   CPPUNIT_ASSERT(session != nullptr);
0045 
0046   // prepare inputs
0047   tensorflow::Tensor input(tensorflow::DT_FLOAT, {1, 10});
0048   float* d = input.flat<float>().data();
0049   for (size_t i = 0; i < 10; i++, d++) {
0050     *d = float(i);
0051   }
0052   tensorflow::Tensor scale(tensorflow::DT_FLOAT, {});
0053   scale.scalar<float>()() = 1.0;
0054 
0055   // "no_threads" pool
0056   std::vector<tensorflow::Tensor> outputs;
0057   tensorflow::run(session, {{"input", input}, {"scale", scale}}, {"output"}, &outputs, "no_threads");
0058   CPPUNIT_ASSERT(outputs.size() == 1);
0059   std::cout << outputs[0].DebugString() << std::endl;
0060   CPPUNIT_ASSERT(outputs[0].matrix<float>()(0, 0) == 46.);
0061 
0062   // "tbb" pool
0063   outputs.clear();
0064   tensorflow::run(session, {{"input", input}, {"scale", scale}}, {"output"}, &outputs, "tbb");
0065   CPPUNIT_ASSERT(outputs.size() == 1);
0066   std::cout << outputs[0].DebugString() << std::endl;
0067   CPPUNIT_ASSERT(outputs[0].matrix<float>()(0, 0) == 46.);
0068 
0069   // tensorflow defaut pool using a new session
0070   tensorflow::Session* session2 = tensorflow::createSession(graphDef, nThreads);
0071   CPPUNIT_ASSERT(session != nullptr);
0072   outputs.clear();
0073   tensorflow::run(session2, {{"input", input}, {"scale", scale}}, {"output"}, &outputs, "tensorflow");
0074   CPPUNIT_ASSERT(outputs.size() == 1);
0075   std::cout << outputs[0].DebugString() << std::endl;
0076   CPPUNIT_ASSERT(outputs[0].matrix<float>()(0, 0) == 46.);
0077 
0078   // force an exception
0079   CPPUNIT_ASSERT_THROW(
0080       tensorflow::run(session, {{"input", input}, {"scale", scale}}, {"output"}, &outputs, "not_existing"),
0081       cms::Exception);
0082 
0083   // cleanup
0084   CPPUNIT_ASSERT(tensorflow::closeSession(session));
0085   CPPUNIT_ASSERT(tensorflow::closeSession(session2));
0086   delete graphDef;
0087 }