Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-06-07 02:29:49

/*
 * Tests for running a TensorFlow session with empty (zero-element) input tensors
 *
 */
0005 
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include <cppunit/extensions/HelperMacros.h>

#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

#include "testBase.h"
0012 
0013 class testEmptyInputs : public testBase {
0014   CPPUNIT_TEST_SUITE(testEmptyInputs);
0015   CPPUNIT_TEST(test);
0016   CPPUNIT_TEST_SUITE_END();
0017 
0018 public:
0019   std::string pyScript() const override;
0020   void test() override;
0021 };
0022 
0023 CPPUNIT_TEST_SUITE_REGISTRATION(testEmptyInputs);
0024 
0025 std::string testEmptyInputs::pyScript() const { return "createconstantgraph.py"; }
0026 
0027 void testEmptyInputs::test() {
0028   std::string pbFile = dataPath_ + "/constantgraph.pb";
0029 
0030   std::cout << "Testing CPU backend" << std::endl;
0031   tensorflow::Backend backend = tensorflow::Backend::cpu;
0032 
0033   // load the graph
0034   tensorflow::Options options{backend};
0035   tensorflow::GraphDef* graphDef = tensorflow::loadGraphDef(pbFile);
0036   CPPUNIT_ASSERT(graphDef != nullptr);
0037 
0038   // create a new session and add the graphDef
0039   const tensorflow::Session* session = tensorflow::createSession(graphDef, options);
0040   CPPUNIT_ASSERT(session != nullptr);
0041 
0042   // example evaluation with empty tensor
0043   tensorflow::Tensor input(tensorflow::DT_FLOAT, {1, 0});
0044   tensorflow::Tensor scale(tensorflow::DT_FLOAT, {});
0045   scale.scalar<float>()() = 1.0;
0046   std::vector<tensorflow::Tensor> outputs;
0047 
0048   // run using the convenience helper
0049   outputs.clear();
0050   tensorflow::run(session, {{"input", input}, {"scale", scale}}, {"output"}, &outputs);
0051   CPPUNIT_ASSERT(outputs.size() == 0);
0052 
0053   // cleanup
0054   CPPUNIT_ASSERT(tensorflow::closeSession(session));
0055   CPPUNIT_ASSERT(session == nullptr);
0056   delete graphDef;
0057 }