Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-09-07 04:37:24

0001 /*
0002  * Unit tests for the TensorFlow AOT model interface (tfaot::Model).
0003  */
0004 
#include <iostream>
#include <stdexcept>
#include <tuple>

#include <cppunit/extensions/HelperMacros.h>

#include "PhysicsTools/TensorFlowAOT/interface/Model.h"

#include "tfaot-model-test-simple/model.h"
#include "tfaot-model-test-multi/model.h"
0012 
0013 class testInterface : public CppUnit::TestFixture {
0014   CPPUNIT_TEST_SUITE(testInterface);
0015   CPPUNIT_TEST(test);
0016   CPPUNIT_TEST_SUITE_END();
0017 
0018 public:
0019   void setUp() {}
0020   void tearDown() {}
0021   void test();
0022   void test_simple();
0023   void test_multi();
0024 };
0025 
0026 CPPUNIT_TEST_SUITE_REGISTRATION(testInterface);
0027 
0028 void testInterface::test() {
0029   test_simple();
0030   test_multi();
0031 }
0032 
0033 void testInterface::test_simple() {
0034   std::cout << std::endl;
0035   std::cout << "tesing simple model" << std::endl;
0036 
0037   // initialize the model
0038   auto model = tfaot::Model<tfaot_model::test_simple>();
0039 
0040   // register (optional) batch rules
0041   model.setBatchRule(1, {1});
0042   model.setBatchRule(3, {2, 2}, 1);
0043   model.setBatchRule("5:2,2,2");
0044 
0045   // test batching strategies
0046   CPPUNIT_ASSERT(model.getBatchStrategy().hasRule(1));
0047   CPPUNIT_ASSERT(model.getBatchStrategy().getRule(1).nSizes() == 1);
0048   CPPUNIT_ASSERT(model.getBatchStrategy().getRule(1).getLastPadding() == 0);
0049   CPPUNIT_ASSERT(!model.getBatchStrategy().hasRule(2));
0050   CPPUNIT_ASSERT(model.getBatchStrategy().hasRule(3));
0051   CPPUNIT_ASSERT(model.getBatchStrategy().getRule(3).nSizes() == 2);
0052   CPPUNIT_ASSERT(model.getBatchStrategy().getRule(3).getLastPadding() == 1);
0053   CPPUNIT_ASSERT(!model.getBatchStrategy().hasRule(4));
0054   CPPUNIT_ASSERT(model.getBatchStrategy().hasRule(5));
0055   CPPUNIT_ASSERT(model.getBatchStrategy().getRule(5).nSizes() == 3);
0056   CPPUNIT_ASSERT(model.getBatchStrategy().getRule(5).getLastPadding() == 1);
0057 
0058   // evaluate batch size 1
0059   tfaot::FloatArrays input_bs1 = {{0, 1, 2, 3}};
0060   tfaot::FloatArrays output_bs1;
0061   std::tie(output_bs1) = model.run<tfaot::FloatArrays>(1, input_bs1);
0062   CPPUNIT_ASSERT(output_bs1.size() == 1);
0063   CPPUNIT_ASSERT(output_bs1[0].size() == 2);
0064   std::cout << "output_bs1[0]: " << output_bs1[0][0] << ", " << output_bs1[0][1] << std::endl;
0065 
0066   // evaluate batch size 2
0067   tfaot::FloatArrays input_bs2 = {{0, 1, 2, 3}, {4, 5, 6, 7}};
0068   tfaot::FloatArrays output_bs2;
0069   std::tie(output_bs2) = model.run<tfaot::FloatArrays>(2, input_bs2);
0070   CPPUNIT_ASSERT(output_bs2.size() == 2);
0071   CPPUNIT_ASSERT(output_bs2[0].size() == 2);
0072   CPPUNIT_ASSERT(output_bs2[1].size() == 2);
0073   std::cout << "output_bs2[0]: " << output_bs2[0][0] << ", " << output_bs2[0][1] << std::endl;
0074   std::cout << "output_bs2[1]: " << output_bs2[1][0] << ", " << output_bs2[1][1] << std::endl;
0075 
0076   // there must be a batch rule for size 2 now
0077   CPPUNIT_ASSERT(model.getBatchStrategy().hasRule(2));
0078 
0079   // evaluate batch size 3
0080   tfaot::FloatArrays input_bs3 = {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}};
0081   tfaot::FloatArrays output_bs3;
0082   std::tie(output_bs3) = model.run<tfaot::FloatArrays>(3, input_bs3);
0083   CPPUNIT_ASSERT(output_bs3.size() == 3);
0084   CPPUNIT_ASSERT(output_bs3[0].size() == 2);
0085   CPPUNIT_ASSERT(output_bs3[1].size() == 2);
0086   CPPUNIT_ASSERT(output_bs3[2].size() == 2);
0087   std::cout << "output_bs3[0]: " << output_bs3[0][0] << ", " << output_bs3[0][1] << std::endl;
0088   std::cout << "output_bs3[1]: " << output_bs3[1][0] << ", " << output_bs3[1][1] << std::endl;
0089   std::cout << "output_bs3[2]: " << output_bs3[2][0] << ", " << output_bs3[2][1] << std::endl;
0090 
0091   // there still should be no batch rule for size 4
0092   CPPUNIT_ASSERT(!model.getBatchStrategy().hasRule(4));
0093 }
0094 
0095 void testInterface::test_multi() {
0096   std::cout << std::endl;
0097   std::cout << "tesing multi model" << std::endl;
0098 
0099   // initialize the model
0100   auto model = tfaot::Model<tfaot_model::test_multi>();
0101 
0102   // there should be no batch rule for size 2 yet
0103   CPPUNIT_ASSERT(!model.getBatchStrategy().hasRule(2));
0104 
0105   // evaluate batch size 2
0106   tfaot::FloatArrays input1_bs2 = {{0, 1, 2, 3}, {4, 5, 6, 7}};
0107   tfaot::DoubleArrays input2_bs2 = {{0, 1, 2, 3}, {4, 5, 6, 7}};
0108   tfaot::FloatArrays output1_bs2;
0109   tfaot::BoolArrays output2_bs2;
0110   std::tie(output1_bs2, output2_bs2) = model.run<tfaot::FloatArrays, tfaot::BoolArrays>(2, input1_bs2, input2_bs2);
0111   CPPUNIT_ASSERT(output1_bs2.size() == 2);
0112   CPPUNIT_ASSERT(output1_bs2[0].size() == 2);
0113   CPPUNIT_ASSERT(output1_bs2[1].size() == 2);
0114   std::cout << "output1_bs2[0]: " << output1_bs2[0][0] << ", " << output1_bs2[0][1] << std::endl;
0115   std::cout << "output1_bs2[1]: " << output1_bs2[1][0] << ", " << output1_bs2[1][1] << std::endl;
0116   CPPUNIT_ASSERT(output2_bs2.size() == 2);
0117   CPPUNIT_ASSERT(output2_bs2[0].size() == 2);
0118   CPPUNIT_ASSERT(output2_bs2[1].size() == 2);
0119   std::cout << "output2_bs2[0]: " << output2_bs2[0][0] << ", " << output2_bs2[0][1] << std::endl;
0120   std::cout << "output2_bs2[1]: " << output2_bs2[1][0] << ", " << output2_bs2[1][1] << std::endl;
0121 
0122   // there must be a batch rule for size 2 now
0123   CPPUNIT_ASSERT(model.getBatchStrategy().hasRule(2));
0124 }