Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:24:15

0001 /*
0002  * TensorFlow AOT test
0003  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
0004  *
0005  * Author: Marcel Rieger
0006  */
0007 
#include <iostream>
#include <stdexcept>

#include <cppunit/extensions/HelperMacros.h>

#include "testAOT_add/header.h"
0012 
0013 using AddComp = testAOT_add;
0014 
0015 class testAOT : public CppUnit::TestFixture {
0016   CPPUNIT_TEST_SUITE(testAOT);
0017   CPPUNIT_TEST(checkAll);
0018   CPPUNIT_TEST_SUITE_END();
0019 
0020 public:
0021   void checkAll();
0022 };
0023 
0024 CPPUNIT_TEST_SUITE_REGISTRATION(testAOT);
0025 
0026 void testAOT::checkAll() {
0027   {
0028     std::cout << "testing tf add" << std::endl;
0029     AddComp add;
0030 
0031     add.arg0() = 1;
0032     add.arg1() = 2;
0033     CPPUNIT_ASSERT(add.Run());
0034     CPPUNIT_ASSERT(add.error_msg() == "");
0035     CPPUNIT_ASSERT(add.result0() == 3);
0036     CPPUNIT_ASSERT(add.result0_data()[0] == 3);
0037     CPPUNIT_ASSERT(add.result0_data() == add.results()[0]);
0038 
0039     add.arg0_data()[0] = 123;
0040     add.arg1_data()[0] = 456;
0041     CPPUNIT_ASSERT(add.Run());
0042     CPPUNIT_ASSERT(add.error_msg() == "");
0043     CPPUNIT_ASSERT(add.result0() == 579);
0044     CPPUNIT_ASSERT(add.result0_data()[0] == 579);
0045     CPPUNIT_ASSERT(add.result0_data() == add.results()[0]);
0046 
0047     const AddComp& add_const = add;
0048     CPPUNIT_ASSERT(add_const.error_msg() == "");
0049     CPPUNIT_ASSERT(add_const.arg0() == 123);
0050     CPPUNIT_ASSERT(add_const.arg0_data()[0] == 123);
0051     CPPUNIT_ASSERT(add_const.arg1() == 456);
0052     CPPUNIT_ASSERT(add_const.arg1_data()[0] == 456);
0053     CPPUNIT_ASSERT(add_const.result0() == 579);
0054     CPPUNIT_ASSERT(add_const.result0_data()[0] == 579);
0055     CPPUNIT_ASSERT(add_const.result0_data() == add_const.results()[0]);
0056   }
0057 
0058   // run tests that use set_argN_data separately, to avoid accidentally re-using
0059   // non-existent buffers.
0060   {
0061     std::cout << "testing tf add no input buffer" << std::endl;
0062     AddComp add(AddComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
0063 
0064     int arg_x = 10;
0065     int arg_y = 32;
0066     add.set_arg0_data(&arg_x);
0067     add.set_arg1_data(&arg_y);
0068 
0069     CPPUNIT_ASSERT(add.Run());
0070     CPPUNIT_ASSERT(add.error_msg() == "");
0071     CPPUNIT_ASSERT(add.result0() == 42);
0072     CPPUNIT_ASSERT(add.result0_data()[0] == 42);
0073     CPPUNIT_ASSERT(add.result0_data() == add.results()[0]);
0074   }
0075 }