File indexing completed on 2024-04-06 12:24:15
0001
0002
0003
0004
0005
0006
0007
#include <memory>
#include <stdexcept>
#include <cppunit/extensions/HelperMacros.h>

#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

#include "testBase.h"
0014
0015 class testMetaGraphLoading : public testBase {
0016 CPPUNIT_TEST_SUITE(testMetaGraphLoading);
0017 CPPUNIT_TEST(test);
0018 CPPUNIT_TEST_SUITE_END();
0019
0020 public:
0021 std::string pyScript() const override;
0022 void test() override;
0023 };
0024
0025 CPPUNIT_TEST_SUITE_REGISTRATION(testMetaGraphLoading);
0026
0027 std::string testMetaGraphLoading::pyScript() const { return "creategraph.py"; }
0028
0029 void testMetaGraphLoading::test() {
0030 std::string exportDir = dataPath_ + "/simplegraph";
0031
0032 std::cout << "Testing CPU backend" << std::endl;
0033 tensorflow::Backend backend = tensorflow::Backend::cpu;
0034
0035
0036 tensorflow::Options options{backend};
0037 tensorflow::MetaGraphDef* metaGraphDef = tensorflow::loadMetaGraphDef(exportDir);
0038 CPPUNIT_ASSERT(metaGraphDef != nullptr);
0039
0040
0041 tensorflow::Session* session1 = tensorflow::createSession(options);
0042 CPPUNIT_ASSERT(session1 != nullptr);
0043
0044
0045 tensorflow::Session* session2 = tensorflow::createSession(metaGraphDef, exportDir, options);
0046 CPPUNIT_ASSERT(session2 != nullptr);
0047
0048
0049 CPPUNIT_ASSERT_THROW(tensorflow::createSession(nullptr, exportDir, options), cms::Exception);
0050
0051
0052 tensorflow::Tensor input(tensorflow::DT_FLOAT, {1, 10});
0053 float* d = input.flat<float>().data();
0054 for (size_t i = 0; i < 10; i++, d++) {
0055 *d = float(i);
0056 }
0057 tensorflow::Tensor scale(tensorflow::DT_FLOAT, {});
0058 scale.scalar<float>()() = 1.0;
0059
0060 std::vector<tensorflow::Tensor> outputs;
0061 tensorflow::Status status = session2->Run({{"input", input}, {"scale", scale}}, {"output"}, {}, &outputs);
0062 if (!status.ok()) {
0063 std::cout << status.ToString() << std::endl;
0064 CPPUNIT_ASSERT(false);
0065 }
0066
0067
0068 CPPUNIT_ASSERT(outputs.size() == 1);
0069 std::cout << outputs[0].DebugString() << std::endl;
0070 CPPUNIT_ASSERT(outputs[0].matrix<float>()(0, 0) == 46.);
0071
0072
0073 outputs.clear();
0074 tensorflow::run(session2, {{"input", input}, {"scale", scale}}, {"output"}, &outputs);
0075 CPPUNIT_ASSERT(outputs.size() == 1);
0076 std::cout << outputs[0].DebugString() << std::endl;
0077 CPPUNIT_ASSERT(outputs[0].matrix<float>()(0, 0) == 46.);
0078
0079
0080 CPPUNIT_ASSERT_THROW(tensorflow::run(session2, {{"foo", input}}, {"output"}, &outputs), cms::Exception);
0081
0082
0083 CPPUNIT_ASSERT(tensorflow::closeSession(session1));
0084 CPPUNIT_ASSERT(tensorflow::closeSession(session2));
0085 delete metaGraphDef;
0086 }