// File indexing completed on 2024-06-07 02:29:49
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <cppunit/extensions/HelperMacros.h>

#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

#include "testBase.h"
0012
0013 class testEmptyInputs : public testBase {
0014 CPPUNIT_TEST_SUITE(testEmptyInputs);
0015 CPPUNIT_TEST(test);
0016 CPPUNIT_TEST_SUITE_END();
0017
0018 public:
0019 std::string pyScript() const override;
0020 void test() override;
0021 };
0022
0023 CPPUNIT_TEST_SUITE_REGISTRATION(testEmptyInputs);
0024
0025 std::string testEmptyInputs::pyScript() const { return "createconstantgraph.py"; }
0026
0027 void testEmptyInputs::test() {
0028 std::string pbFile = dataPath_ + "/constantgraph.pb";
0029
0030 std::cout << "Testing CPU backend" << std::endl;
0031 tensorflow::Backend backend = tensorflow::Backend::cpu;
0032
0033
0034 tensorflow::Options options{backend};
0035 tensorflow::GraphDef* graphDef = tensorflow::loadGraphDef(pbFile);
0036 CPPUNIT_ASSERT(graphDef != nullptr);
0037
0038
0039 const tensorflow::Session* session = tensorflow::createSession(graphDef, options);
0040 CPPUNIT_ASSERT(session != nullptr);
0041
0042
0043 tensorflow::Tensor input(tensorflow::DT_FLOAT, {1, 0});
0044 tensorflow::Tensor scale(tensorflow::DT_FLOAT, {});
0045 scale.scalar<float>()() = 1.0;
0046 std::vector<tensorflow::Tensor> outputs;
0047
0048
0049 outputs.clear();
0050 tensorflow::run(session, {{"input", input}, {"scale", scale}}, {"output"}, &outputs);
0051 CPPUNIT_ASSERT(outputs.size() == 0);
0052
0053
0054 CPPUNIT_ASSERT(tensorflow::closeSession(session));
0055 CPPUNIT_ASSERT(session == nullptr);
0056 delete graphDef;
0057 }