File indexing completed on 2023-04-13 23:19:19
0001
0002
0003
0004
0005
0006
0007
0008 #include <stdexcept>
0009 #include <cppunit/extensions/HelperMacros.h>
0010
0011 #include "tensorflow/cc/saved_model/loader.h"
0012 #include "tensorflow/cc/saved_model/tag_constants.h"
0013 #include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
0014
0015 #include "testBase.h"
0016
0017 class testHelloWorld : public testBase {
0018 CPPUNIT_TEST_SUITE(testHelloWorld);
0019 CPPUNIT_TEST(test);
0020 CPPUNIT_TEST_SUITE_END();
0021
0022 public:
0023 std::string pyScript() const override;
0024 void test() override;
0025 };
0026
0027 CPPUNIT_TEST_SUITE_REGISTRATION(testHelloWorld);
0028
0029 std::string testHelloWorld::pyScript() const { return "creategraph.py"; }
0030
0031 void testHelloWorld::test() {
0032 std::string modelDir = dataPath_ + "/simplegraph";
0033
0034 std::cout << "Testing CPU backend" << std::endl;
0035 tensorflow::Backend backend = tensorflow::Backend::cpu;
0036
0037
0038 tensorflow::Status status;
0039 tensorflow::Options options{backend};
0040 tensorflow::setLogging();
0041 tensorflow::RunOptions runOptions;
0042 tensorflow::SavedModelBundle bundle;
0043
0044
0045 status = tensorflow::LoadSavedModel(options.getSessionOptions(), runOptions, modelDir, {"serve"}, &bundle);
0046 if (!status.ok()) {
0047 std::cout << status.ToString() << std::endl;
0048 return;
0049 }
0050
0051
0052 tensorflow::Session* session = bundle.session.release();
0053
0054
0055 tensorflow::Tensor input(tensorflow::DT_FLOAT, {1, 10});
0056 float* d = input.flat<float>().data();
0057 for (size_t i = 0; i < 10; i++, d++) {
0058 *d = float(i);
0059 }
0060 tensorflow::Tensor scale(tensorflow::DT_FLOAT, {});
0061 scale.scalar<float>()() = 1.0;
0062
0063
0064 std::vector<tensorflow::Tensor> outputs;
0065
0066
0067 status = session->Run({{"input", input}, {"scale", scale}}, {"output"}, {}, &outputs);
0068 if (!status.ok()) {
0069 std::cout << status.ToString() << std::endl;
0070 return;
0071 }
0072
0073
0074 std::cout << outputs[0].DebugString() << std::endl;
0075
0076
0077 status = session->Close();
0078 if (!status.ok()) {
0079 std::cerr << "error while closing session" << std::endl;
0080 }
0081 delete session;
0082 }