Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2023-04-13 23:19:19

0001 /*
0002  * HelloWorld test of the TensorFlow interface.
0003  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
0004  *
0005  * Author: Marcel Rieger
0006  */
0007 
0008 #include <stdexcept>
0009 #include <cppunit/extensions/HelperMacros.h>
0010 
0011 #include "tensorflow/cc/saved_model/loader.h"
0012 #include "tensorflow/cc/saved_model/tag_constants.h"
0013 #include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
0014 
0015 #include "testBase.h"
0016 
0017 class testHelloWorld : public testBase {
0018   CPPUNIT_TEST_SUITE(testHelloWorld);
0019   CPPUNIT_TEST(test);
0020   CPPUNIT_TEST_SUITE_END();
0021 
0022 public:
0023   std::string pyScript() const override;
0024   void test() override;
0025 };
0026 
0027 CPPUNIT_TEST_SUITE_REGISTRATION(testHelloWorld);
0028 
0029 std::string testHelloWorld::pyScript() const { return "creategraph.py"; }
0030 
0031 void testHelloWorld::test() {
0032   std::string modelDir = dataPath_ + "/simplegraph";
0033   // Testing CPU
0034   std::cout << "Testing CPU backend" << std::endl;
0035   tensorflow::Backend backend = tensorflow::Backend::cpu;
0036 
0037   // object to load and run the graph / session
0038   tensorflow::Status status;
0039   tensorflow::Options options{backend};
0040   tensorflow::setLogging();
0041   tensorflow::RunOptions runOptions;
0042   tensorflow::SavedModelBundle bundle;
0043 
0044   // load everything
0045   status = tensorflow::LoadSavedModel(options.getSessionOptions(), runOptions, modelDir, {"serve"}, &bundle);
0046   if (!status.ok()) {
0047     std::cout << status.ToString() << std::endl;
0048     return;
0049   }
0050 
0051   // fetch the session
0052   tensorflow::Session* session = bundle.session.release();
0053 
0054   // prepare inputs
0055   tensorflow::Tensor input(tensorflow::DT_FLOAT, {1, 10});
0056   float* d = input.flat<float>().data();
0057   for (size_t i = 0; i < 10; i++, d++) {
0058     *d = float(i);
0059   }
0060   tensorflow::Tensor scale(tensorflow::DT_FLOAT, {});
0061   scale.scalar<float>()() = 1.0;
0062 
0063   // prepare outputs
0064   std::vector<tensorflow::Tensor> outputs;
0065 
0066   // session run
0067   status = session->Run({{"input", input}, {"scale", scale}}, {"output"}, {}, &outputs);
0068   if (!status.ok()) {
0069     std::cout << status.ToString() << std::endl;
0070     return;
0071   }
0072 
0073   // log the output tensor
0074   std::cout << outputs[0].DebugString() << std::endl;
0075 
0076   // close the session
0077   status = session->Close();
0078   if (!status.ok()) {
0079     std::cerr << "error while closing session" << std::endl;
0080   }
0081   delete session;
0082 }