Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2021-02-14 13:33:24

0001 /*
0002  * TensorFlow AOT test
0003  * Based on TensorFlow 2.1.
0004  * For more info, see https://gitlab.cern.ch/mrieger/CMSSW-DNN.
0005  *
0006  * Author: Marcel Rieger
0007  */
0008 
#include <cppunit/extensions/HelperMacros.h>
#include <iostream>
#include <stdexcept>

#include "testAOT_add/header.h"
0013 
0014 using AddComp = testAOT_add;
0015 
0016 class testAOT : public CppUnit::TestFixture {
0017   CPPUNIT_TEST_SUITE(testAOT);
0018   CPPUNIT_TEST(checkAll);
0019   CPPUNIT_TEST_SUITE_END();
0020 
0021 public:
0022   void checkAll();
0023 };
0024 
0025 CPPUNIT_TEST_SUITE_REGISTRATION(testAOT);
0026 
0027 void testAOT::checkAll() {
0028   {
0029     std::cout << "testing tf add" << std::endl;
0030     AddComp add;
0031 
0032     add.arg0() = 1;
0033     add.arg1() = 2;
0034     CPPUNIT_ASSERT(add.Run());
0035     CPPUNIT_ASSERT(add.error_msg() == "");
0036     CPPUNIT_ASSERT(add.result0() == 3);
0037     CPPUNIT_ASSERT(add.result0_data()[0] == 3);
0038     CPPUNIT_ASSERT(add.result0_data() == add.results()[0]);
0039 
0040     add.arg0_data()[0] = 123;
0041     add.arg1_data()[0] = 456;
0042     CPPUNIT_ASSERT(add.Run());
0043     CPPUNIT_ASSERT(add.error_msg() == "");
0044     CPPUNIT_ASSERT(add.result0() == 579);
0045     CPPUNIT_ASSERT(add.result0_data()[0] == 579);
0046     CPPUNIT_ASSERT(add.result0_data() == add.results()[0]);
0047 
0048     const AddComp& add_const = add;
0049     CPPUNIT_ASSERT(add_const.error_msg() == "");
0050     CPPUNIT_ASSERT(add_const.arg0() == 123);
0051     CPPUNIT_ASSERT(add_const.arg0_data()[0] == 123);
0052     CPPUNIT_ASSERT(add_const.arg1() == 456);
0053     CPPUNIT_ASSERT(add_const.arg1_data()[0] == 456);
0054     CPPUNIT_ASSERT(add_const.result0() == 579);
0055     CPPUNIT_ASSERT(add_const.result0_data()[0] == 579);
0056     CPPUNIT_ASSERT(add_const.result0_data() == add_const.results()[0]);
0057   }
0058 
0059   // run tests that use set_argN_data separately, to avoid accidentally re-using
0060   // non-existent buffers.
0061   {
0062     std::cout << "testing tf add no input buffer" << std::endl;
0063     AddComp add(AddComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
0064 
0065     int arg_x = 10;
0066     int arg_y = 32;
0067     add.set_arg0_data(&arg_x);
0068     add.set_arg1_data(&arg_y);
0069 
0070     CPPUNIT_ASSERT(add.Run());
0071     CPPUNIT_ASSERT(add.error_msg() == "");
0072     CPPUNIT_ASSERT(add.result0() == 42);
0073     CPPUNIT_ASSERT(add.result0_data()[0] == 42);
0074     CPPUNIT_ASSERT(add.result0_data() == add.results()[0]);
0075   }
0076 }