Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:24:16

0001 /*
0002  * Tests for interacting with the SessionCache.
0003  *
0004  * Author: Marcel Rieger
0005  */
0006 
#include <iostream>
#include <stdexcept>
#include <string>

#include <cppunit/extensions/HelperMacros.h>

#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

#include "testBase.h"
0013 
/// CppUnit fixture exercising tensorflow::SessionCache.
/// The suite macros below register exactly one test case, `test`.
/// NOTE(review): the python script named by pyScript() is presumably run by
/// testBase to produce the graph file that test() loads from dataPath_ —
/// confirm against testBase.h.
class testSessionCache : public testBase {
  CPPUNIT_TEST_SUITE(testSessionCache);
  CPPUNIT_TEST(test);
  CPPUNIT_TEST_SUITE_END();

public:
  // Returns the name of the python script associated with this fixture.
  std::string pyScript() const override;
  // Creates a SessionCache on the CPU backend, checks its graph/session
  // pointers, then closes it and checks both pointers were reset.
  void test() override;
};
0023 
// Add the fixture's suite to CppUnit's default test-factory registry so the
// test runner discovers it without explicit wiring.
CPPUNIT_TEST_SUITE_REGISTRATION(testSessionCache);
0025 
0026 std::string testSessionCache::pyScript() const { return "createconstantgraph.py"; }
0027 
0028 void testSessionCache::test() {
0029   std::string pbFile = dataPath_ + "/constantgraph.pb";
0030 
0031   std::cout << "Testing CPU backend" << std::endl;
0032   tensorflow::Backend backend = tensorflow::Backend::cpu;
0033   tensorflow::Options options{backend};
0034 
0035   // load the graph and the session
0036   tensorflow::SessionCache cache(pbFile, options);
0037   CPPUNIT_ASSERT(cache.graph.load() != nullptr);
0038   CPPUNIT_ASSERT(cache.session.load() != nullptr);
0039 
0040   // get a const session pointer
0041   const tensorflow::Session* session = cache.getSession();
0042   CPPUNIT_ASSERT(session != nullptr);
0043 
0044   // cleanup
0045   cache.closeSession();
0046   CPPUNIT_ASSERT(cache.graph.load() == nullptr);
0047   CPPUNIT_ASSERT(cache.session.load() == nullptr);
0048 }