File indexing completed on 2023-04-13 23:19:20
0001
0002
0003
0004
0005
0006
0007
#include <cppunit/extensions/HelperMacros.h>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

#include "testBase.h"
0014
0015 class testGraphLoading : public testBase {
0016 CPPUNIT_TEST_SUITE(testGraphLoading);
0017 CPPUNIT_TEST(test);
0018 CPPUNIT_TEST_SUITE_END();
0019
0020 public:
0021 std::string pyScript() const override;
0022 void test() override;
0023 };
0024
0025 CPPUNIT_TEST_SUITE_REGISTRATION(testGraphLoading);
0026
0027 std::string testGraphLoading::pyScript() const { return "createconstantgraph.py"; }
0028
0029 void testGraphLoading::test() {
0030 std::string pbFile = dataPath_ + "/constantgraph.pb";
0031
0032 std::cout << "Testing CPU backend" << std::endl;
0033 tensorflow::Backend backend = tensorflow::Backend::cpu;
0034 tensorflow::Options options{backend};
0035
0036
0037 int nThreads = 4;
0038 tensorflow::TBBThreadPool::instance(nThreads);
0039 options.setThreading(nThreads);
0040
0041
0042 tensorflow::GraphDef* graphDef = tensorflow::loadGraphDef(pbFile);
0043 CPPUNIT_ASSERT(graphDef != nullptr);
0044
0045
0046 tensorflow::Session* session = tensorflow::createSession(graphDef, options);
0047 CPPUNIT_ASSERT(session != nullptr);
0048
0049
0050 tensorflow::Tensor input(tensorflow::DT_FLOAT, {1, 10});
0051 float* d = input.flat<float>().data();
0052 for (size_t i = 0; i < 10; i++, d++) {
0053 *d = float(i);
0054 }
0055 tensorflow::Tensor scale(tensorflow::DT_FLOAT, {});
0056 scale.scalar<float>()() = 1.0;
0057
0058
0059 std::vector<tensorflow::Tensor> outputs;
0060 tensorflow::run(session, {{"input", input}, {"scale", scale}}, {"output"}, &outputs, "no_threads");
0061 CPPUNIT_ASSERT(outputs.size() == 1);
0062 std::cout << outputs[0].DebugString() << std::endl;
0063 CPPUNIT_ASSERT(outputs[0].matrix<float>()(0, 0) == 46.);
0064
0065
0066 outputs.clear();
0067 tensorflow::run(session, {{"input", input}, {"scale", scale}}, {"output"}, &outputs, "tbb");
0068 CPPUNIT_ASSERT(outputs.size() == 1);
0069 std::cout << outputs[0].DebugString() << std::endl;
0070 CPPUNIT_ASSERT(outputs[0].matrix<float>()(0, 0) == 46.);
0071
0072
0073 tensorflow::Session* session2 = tensorflow::createSession(graphDef, options);
0074 CPPUNIT_ASSERT(session != nullptr);
0075 outputs.clear();
0076 tensorflow::run(session2, {{"input", input}, {"scale", scale}}, {"output"}, &outputs, "tensorflow");
0077 CPPUNIT_ASSERT(outputs.size() == 1);
0078 std::cout << outputs[0].DebugString() << std::endl;
0079 CPPUNIT_ASSERT(outputs[0].matrix<float>()(0, 0) == 46.);
0080
0081
0082 CPPUNIT_ASSERT_THROW(
0083 tensorflow::run(session, {{"input", input}, {"scale", scale}}, {"output"}, &outputs, "not_existing"),
0084 cms::Exception);
0085
0086
0087 CPPUNIT_ASSERT(tensorflow::closeSession(session));
0088 CPPUNIT_ASSERT(tensorflow::closeSession(session2));
0089 delete graphDef;
0090 }