Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:24:15

0001 /*
0002  * Tests for loading meta graphs via the SavedModel interface.
0003  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
0004  *
0005  * Author: Marcel Rieger
0006  */
0007 
#include <cstddef>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <cppunit/extensions/HelperMacros.h>

#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

#include "testBase.h"
0014 
// CppUnit fixture that loads a meta graph through the SavedModel interface of the
// PhysicsTools/TensorFlow helpers, builds sessions from it and evaluates the graph.
class testMetaGraphLoading : public testBase {
  // Suite declaration: test() is the only test case registered for this fixture.
  CPPUNIT_TEST_SUITE(testMetaGraphLoading);
  CPPUNIT_TEST(test);
  CPPUNIT_TEST_SUITE_END();

public:
  // Name of the Python script associated with this test; presumably executed by testBase
  // to produce the exported graph read from dataPath_ (TODO confirm against testBase).
  std::string pyScript() const override;

  // Loads the meta graph, creates sessions, runs the graph and checks its output.
  void test() override;
};

// Register the suite with CppUnit's default test-factory registry so the runner finds it.
CPPUNIT_TEST_SUITE_REGISTRATION(testMetaGraphLoading);
0026 
0027 std::string testMetaGraphLoading::pyScript() const { return "creategraph.py"; }
0028 
0029 void testMetaGraphLoading::test() {
0030   std::string exportDir = dataPath_ + "/simplegraph";
0031 
0032   std::cout << "Testing CPU backend" << std::endl;
0033   tensorflow::Backend backend = tensorflow::Backend::cpu;
0034 
0035   // load the graph
0036   tensorflow::Options options{backend};
0037   tensorflow::MetaGraphDef* metaGraphDef = tensorflow::loadMetaGraphDef(exportDir);
0038   CPPUNIT_ASSERT(metaGraphDef != nullptr);
0039 
0040   // create a new, empty session
0041   tensorflow::Session* session1 = tensorflow::createSession(options);
0042   CPPUNIT_ASSERT(session1 != nullptr);
0043 
0044   // create a new session, using the meta graph
0045   tensorflow::Session* session2 = tensorflow::createSession(metaGraphDef, exportDir, options);
0046   CPPUNIT_ASSERT(session2 != nullptr);
0047 
0048   // check for exception
0049   CPPUNIT_ASSERT_THROW(tensorflow::createSession(nullptr, exportDir, options), cms::Exception);
0050 
0051   // example evaluation
0052   tensorflow::Tensor input(tensorflow::DT_FLOAT, {1, 10});
0053   float* d = input.flat<float>().data();
0054   for (size_t i = 0; i < 10; i++, d++) {
0055     *d = float(i);
0056   }
0057   tensorflow::Tensor scale(tensorflow::DT_FLOAT, {});
0058   scale.scalar<float>()() = 1.0;
0059 
0060   std::vector<tensorflow::Tensor> outputs;
0061   tensorflow::Status status = session2->Run({{"input", input}, {"scale", scale}}, {"output"}, {}, &outputs);
0062   if (!status.ok()) {
0063     std::cout << status.ToString() << std::endl;
0064     CPPUNIT_ASSERT(false);
0065   }
0066 
0067   // check the output
0068   CPPUNIT_ASSERT(outputs.size() == 1);
0069   std::cout << outputs[0].DebugString() << std::endl;
0070   CPPUNIT_ASSERT(outputs[0].matrix<float>()(0, 0) == 46.);
0071 
0072   // run again using the convenience helper
0073   outputs.clear();
0074   tensorflow::run(session2, {{"input", input}, {"scale", scale}}, {"output"}, &outputs);
0075   CPPUNIT_ASSERT(outputs.size() == 1);
0076   std::cout << outputs[0].DebugString() << std::endl;
0077   CPPUNIT_ASSERT(outputs[0].matrix<float>()(0, 0) == 46.);
0078 
0079   // check for exception
0080   CPPUNIT_ASSERT_THROW(tensorflow::run(session2, {{"foo", input}}, {"output"}, &outputs), cms::Exception);
0081 
0082   // cleanup
0083   CPPUNIT_ASSERT(tensorflow::closeSession(session1));
0084   CPPUNIT_ASSERT(tensorflow::closeSession(session2));
0085   delete metaGraphDef;
0086 }