Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2023-04-13 23:19:20

0001 /*
0002  * Tests for running inference using custom thread pools.
0003  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
0004  *
0005  * Author: Marcel Rieger
0006  */
0007 
#include <cppunit/extensions/HelperMacros.h>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

#include "testBase.h"
0014 
0015 class testGraphLoading : public testBase {
0016   CPPUNIT_TEST_SUITE(testGraphLoading);
0017   CPPUNIT_TEST(test);
0018   CPPUNIT_TEST_SUITE_END();
0019 
0020 public:
0021   std::string pyScript() const override;
0022   void test() override;
0023 };
0024 
0025 CPPUNIT_TEST_SUITE_REGISTRATION(testGraphLoading);
0026 
0027 std::string testGraphLoading::pyScript() const { return "createconstantgraph.py"; }
0028 
0029 void testGraphLoading::test() {
0030   std::string pbFile = dataPath_ + "/constantgraph.pb";
0031 
0032   std::cout << "Testing CPU backend" << std::endl;
0033   tensorflow::Backend backend = tensorflow::Backend::cpu;
0034   tensorflow::Options options{backend};
0035 
0036   // initialize the TBB threadpool
0037   int nThreads = 4;
0038   tensorflow::TBBThreadPool::instance(nThreads);
0039   options.setThreading(nThreads);
0040 
0041   // load the graph
0042   tensorflow::GraphDef* graphDef = tensorflow::loadGraphDef(pbFile);
0043   CPPUNIT_ASSERT(graphDef != nullptr);
0044 
0045   // create a new session and add the graphDef
0046   tensorflow::Session* session = tensorflow::createSession(graphDef, options);
0047   CPPUNIT_ASSERT(session != nullptr);
0048 
0049   // prepare inputs
0050   tensorflow::Tensor input(tensorflow::DT_FLOAT, {1, 10});
0051   float* d = input.flat<float>().data();
0052   for (size_t i = 0; i < 10; i++, d++) {
0053     *d = float(i);
0054   }
0055   tensorflow::Tensor scale(tensorflow::DT_FLOAT, {});
0056   scale.scalar<float>()() = 1.0;
0057 
0058   // "no_threads" pool
0059   std::vector<tensorflow::Tensor> outputs;
0060   tensorflow::run(session, {{"input", input}, {"scale", scale}}, {"output"}, &outputs, "no_threads");
0061   CPPUNIT_ASSERT(outputs.size() == 1);
0062   std::cout << outputs[0].DebugString() << std::endl;
0063   CPPUNIT_ASSERT(outputs[0].matrix<float>()(0, 0) == 46.);
0064 
0065   // "tbb" pool
0066   outputs.clear();
0067   tensorflow::run(session, {{"input", input}, {"scale", scale}}, {"output"}, &outputs, "tbb");
0068   CPPUNIT_ASSERT(outputs.size() == 1);
0069   std::cout << outputs[0].DebugString() << std::endl;
0070   CPPUNIT_ASSERT(outputs[0].matrix<float>()(0, 0) == 46.);
0071 
0072   // tensorflow defaut pool using a new session
0073   tensorflow::Session* session2 = tensorflow::createSession(graphDef, options);
0074   CPPUNIT_ASSERT(session != nullptr);
0075   outputs.clear();
0076   tensorflow::run(session2, {{"input", input}, {"scale", scale}}, {"output"}, &outputs, "tensorflow");
0077   CPPUNIT_ASSERT(outputs.size() == 1);
0078   std::cout << outputs[0].DebugString() << std::endl;
0079   CPPUNIT_ASSERT(outputs[0].matrix<float>()(0, 0) == 46.);
0080 
0081   // force an exception
0082   CPPUNIT_ASSERT_THROW(
0083       tensorflow::run(session, {{"input", input}, {"scale", scale}}, {"output"}, &outputs, "not_existing"),
0084       cms::Exception);
0085 
0086   // cleanup
0087   CPPUNIT_ASSERT(tensorflow::closeSession(session));
0088   CPPUNIT_ASSERT(tensorflow::closeSession(session2));
0089   delete graphDef;
0090 }