File indexing completed on 2024-09-07 04:37:24
0001
0002
0003
0004
#include <iostream>
#include <stdexcept>
#include <tuple>

#include <cppunit/extensions/HelperMacros.h>

#include "PhysicsTools/TensorFlowAOT/interface/Model.h"

#include "tfaot-model-test-simple/model.h"
#include "tfaot-model-test-multi/model.h"
0012
0013 class testInterface : public CppUnit::TestFixture {
0014 CPPUNIT_TEST_SUITE(testInterface);
0015 CPPUNIT_TEST(test);
0016 CPPUNIT_TEST_SUITE_END();
0017
0018 public:
0019 void setUp() {}
0020 void tearDown() {}
0021 void test();
0022 void test_simple();
0023 void test_multi();
0024 };
0025
0026 CPPUNIT_TEST_SUITE_REGISTRATION(testInterface);
0027
0028 void testInterface::test() {
0029 test_simple();
0030 test_multi();
0031 }
0032
0033 void testInterface::test_simple() {
0034 std::cout << std::endl;
0035 std::cout << "tesing simple model" << std::endl;
0036
0037
0038 auto model = tfaot::Model<tfaot_model::test_simple>();
0039
0040
0041 model.setBatchRule(1, {1});
0042 model.setBatchRule(3, {2, 2}, 1);
0043 model.setBatchRule("5:2,2,2");
0044
0045
0046 CPPUNIT_ASSERT(model.getBatchStrategy().hasRule(1));
0047 CPPUNIT_ASSERT(model.getBatchStrategy().getRule(1).nSizes() == 1);
0048 CPPUNIT_ASSERT(model.getBatchStrategy().getRule(1).getLastPadding() == 0);
0049 CPPUNIT_ASSERT(!model.getBatchStrategy().hasRule(2));
0050 CPPUNIT_ASSERT(model.getBatchStrategy().hasRule(3));
0051 CPPUNIT_ASSERT(model.getBatchStrategy().getRule(3).nSizes() == 2);
0052 CPPUNIT_ASSERT(model.getBatchStrategy().getRule(3).getLastPadding() == 1);
0053 CPPUNIT_ASSERT(!model.getBatchStrategy().hasRule(4));
0054 CPPUNIT_ASSERT(model.getBatchStrategy().hasRule(5));
0055 CPPUNIT_ASSERT(model.getBatchStrategy().getRule(5).nSizes() == 3);
0056 CPPUNIT_ASSERT(model.getBatchStrategy().getRule(5).getLastPadding() == 1);
0057
0058
0059 tfaot::FloatArrays input_bs1 = {{0, 1, 2, 3}};
0060 tfaot::FloatArrays output_bs1;
0061 std::tie(output_bs1) = model.run<tfaot::FloatArrays>(1, input_bs1);
0062 CPPUNIT_ASSERT(output_bs1.size() == 1);
0063 CPPUNIT_ASSERT(output_bs1[0].size() == 2);
0064 std::cout << "output_bs1[0]: " << output_bs1[0][0] << ", " << output_bs1[0][1] << std::endl;
0065
0066
0067 tfaot::FloatArrays input_bs2 = {{0, 1, 2, 3}, {4, 5, 6, 7}};
0068 tfaot::FloatArrays output_bs2;
0069 std::tie(output_bs2) = model.run<tfaot::FloatArrays>(2, input_bs2);
0070 CPPUNIT_ASSERT(output_bs2.size() == 2);
0071 CPPUNIT_ASSERT(output_bs2[0].size() == 2);
0072 CPPUNIT_ASSERT(output_bs2[1].size() == 2);
0073 std::cout << "output_bs2[0]: " << output_bs2[0][0] << ", " << output_bs2[0][1] << std::endl;
0074 std::cout << "output_bs2[1]: " << output_bs2[1][0] << ", " << output_bs2[1][1] << std::endl;
0075
0076
0077 CPPUNIT_ASSERT(model.getBatchStrategy().hasRule(2));
0078
0079
0080 tfaot::FloatArrays input_bs3 = {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}};
0081 tfaot::FloatArrays output_bs3;
0082 std::tie(output_bs3) = model.run<tfaot::FloatArrays>(3, input_bs3);
0083 CPPUNIT_ASSERT(output_bs3.size() == 3);
0084 CPPUNIT_ASSERT(output_bs3[0].size() == 2);
0085 CPPUNIT_ASSERT(output_bs3[1].size() == 2);
0086 CPPUNIT_ASSERT(output_bs3[2].size() == 2);
0087 std::cout << "output_bs3[0]: " << output_bs3[0][0] << ", " << output_bs3[0][1] << std::endl;
0088 std::cout << "output_bs3[1]: " << output_bs3[1][0] << ", " << output_bs3[1][1] << std::endl;
0089 std::cout << "output_bs3[2]: " << output_bs3[2][0] << ", " << output_bs3[2][1] << std::endl;
0090
0091
0092 CPPUNIT_ASSERT(!model.getBatchStrategy().hasRule(4));
0093 }
0094
0095 void testInterface::test_multi() {
0096 std::cout << std::endl;
0097 std::cout << "tesing multi model" << std::endl;
0098
0099
0100 auto model = tfaot::Model<tfaot_model::test_multi>();
0101
0102
0103 CPPUNIT_ASSERT(!model.getBatchStrategy().hasRule(2));
0104
0105
0106 tfaot::FloatArrays input1_bs2 = {{0, 1, 2, 3}, {4, 5, 6, 7}};
0107 tfaot::DoubleArrays input2_bs2 = {{0, 1, 2, 3}, {4, 5, 6, 7}};
0108 tfaot::FloatArrays output1_bs2;
0109 tfaot::BoolArrays output2_bs2;
0110 std::tie(output1_bs2, output2_bs2) = model.run<tfaot::FloatArrays, tfaot::BoolArrays>(2, input1_bs2, input2_bs2);
0111 CPPUNIT_ASSERT(output1_bs2.size() == 2);
0112 CPPUNIT_ASSERT(output1_bs2[0].size() == 2);
0113 CPPUNIT_ASSERT(output1_bs2[1].size() == 2);
0114 std::cout << "output1_bs2[0]: " << output1_bs2[0][0] << ", " << output1_bs2[0][1] << std::endl;
0115 std::cout << "output1_bs2[1]: " << output1_bs2[1][0] << ", " << output1_bs2[1][1] << std::endl;
0116 CPPUNIT_ASSERT(output2_bs2.size() == 2);
0117 CPPUNIT_ASSERT(output2_bs2[0].size() == 2);
0118 CPPUNIT_ASSERT(output2_bs2[1].size() == 2);
0119 std::cout << "output2_bs2[0]: " << output2_bs2[0][0] << ", " << output2_bs2[0][1] << std::endl;
0120 std::cout << "output2_bs2[1]: " << output2_bs2[1][0] << ", " << output2_bs2[1][1] << std::endl;
0121
0122
0123 CPPUNIT_ASSERT(model.getBatchStrategy().hasRule(2));
0124 }