File indexing completed on 2023-04-13 23:19:20
0001
0002
0003
0004
0005
0006
#include <iostream>
#include <stdexcept>
#include <string>

#include <cppunit/extensions/HelperMacros.h>

#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

#include "testBase.h"
0013
0014 class testSessionCache : public testBase {
0015 CPPUNIT_TEST_SUITE(testSessionCache);
0016 CPPUNIT_TEST(test);
0017 CPPUNIT_TEST_SUITE_END();
0018
0019 public:
0020 std::string pyScript() const override;
0021 void test() override;
0022 };
0023
0024 CPPUNIT_TEST_SUITE_REGISTRATION(testSessionCache);
0025
0026 std::string testSessionCache::pyScript() const { return "createconstantgraph.py"; }
0027
0028 void testSessionCache::test() {
0029 std::string pbFile = dataPath_ + "/constantgraph.pb";
0030
0031 std::cout << "Testing CPU backend" << std::endl;
0032 tensorflow::Backend backend = tensorflow::Backend::cpu;
0033 tensorflow::Options options{backend};
0034
0035
0036 tensorflow::SessionCache cache(pbFile, options);
0037 CPPUNIT_ASSERT(cache.graph.load() != nullptr);
0038 CPPUNIT_ASSERT(cache.session.load() != nullptr);
0039
0040
0041 const tensorflow::Session* session = cache.getSession();
0042 CPPUNIT_ASSERT(session != nullptr);
0043
0044
0045 cache.closeSession();
0046 CPPUNIT_ASSERT(cache.graph.load() == nullptr);
0047 CPPUNIT_ASSERT(cache.session.load() == nullptr);
0048 }