Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2021-02-14 13:33:24

0001 /*
0002  * Tests for loading meta graphs via the SavedModel interface.
0003  * Based on TensorFlow 2.1.
0004  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
0005  *
0006  * Author: Marcel Rieger
0007  */
0008 
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include <cppunit/extensions/HelperMacros.h>

#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

#include "testBase.h"
0015 
0016 class testMetaGraphLoading : public testBase {
0017   CPPUNIT_TEST_SUITE(testMetaGraphLoading);
0018   CPPUNIT_TEST(checkAll);
0019   CPPUNIT_TEST_SUITE_END();
0020 
0021 public:
0022   std::string pyScript() const override;
0023   void checkAll() override;
0024 };
0025 
0026 CPPUNIT_TEST_SUITE_REGISTRATION(testMetaGraphLoading);
0027 
0028 std::string testMetaGraphLoading::pyScript() const { return "creategraph.py"; }
0029 
0030 void testMetaGraphLoading::checkAll() {
0031   std::string exportDir = dataPath_ + "/simplegraph";
0032 
0033   // load the graph
0034   tensorflow::setLogging();
0035   tensorflow::MetaGraphDef* metaGraphDef = tensorflow::loadMetaGraphDef(exportDir);
0036   CPPUNIT_ASSERT(metaGraphDef != nullptr);
0037 
0038   // create a new, empty session
0039   tensorflow::Session* session1 = tensorflow::createSession();
0040   CPPUNIT_ASSERT(session1 != nullptr);
0041 
0042   // create a new session, using the meta graph
0043   tensorflow::Session* session2 = tensorflow::createSession(metaGraphDef, exportDir);
0044   CPPUNIT_ASSERT(session2 != nullptr);
0045 
0046   // check for exception
0047   CPPUNIT_ASSERT_THROW(tensorflow::createSession(nullptr, exportDir), cms::Exception);
0048 
0049   // example evaluation
0050   tensorflow::Tensor input(tensorflow::DT_FLOAT, {1, 10});
0051   float* d = input.flat<float>().data();
0052   for (size_t i = 0; i < 10; i++, d++) {
0053     *d = float(i);
0054   }
0055   tensorflow::Tensor scale(tensorflow::DT_FLOAT, {});
0056   scale.scalar<float>()() = 1.0;
0057 
0058   std::vector<tensorflow::Tensor> outputs;
0059   tensorflow::Status status = session2->Run({{"input", input}, {"scale", scale}}, {"output"}, {}, &outputs);
0060   if (!status.ok()) {
0061     std::cout << status.ToString() << std::endl;
0062     CPPUNIT_ASSERT(false);
0063   }
0064 
0065   // check the output
0066   CPPUNIT_ASSERT(outputs.size() == 1);
0067   std::cout << outputs[0].DebugString() << std::endl;
0068   CPPUNIT_ASSERT(outputs[0].matrix<float>()(0, 0) == 46.);
0069 
0070   // run again using the convenience helper
0071   outputs.clear();
0072   tensorflow::run(session2, {{"input", input}, {"scale", scale}}, {"output"}, &outputs);
0073   CPPUNIT_ASSERT(outputs.size() == 1);
0074   std::cout << outputs[0].DebugString() << std::endl;
0075   CPPUNIT_ASSERT(outputs[0].matrix<float>()(0, 0) == 46.);
0076 
0077   // check for exception
0078   CPPUNIT_ASSERT_THROW(tensorflow::run(session2, {{"foo", input}}, {"output"}, &outputs), cms::Exception);
0079 
0080   // cleanup
0081   CPPUNIT_ASSERT(tensorflow::closeSession(session1));
0082   CPPUNIT_ASSERT(tensorflow::closeSession(session2));
0083   delete metaGraphDef;
0084 }