Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2021-02-14 13:32:43

#include <cppunit/extensions/HelperMacros.h>

#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include "PhysicsTools/MXNet/interface/Predictor.h"
#include "FWCore/ParameterSet/interface/FileInPath.h"

using namespace mxnet::cpp;
0007 
0008 class testMXNetCppPredictor : public CppUnit::TestFixture {
0009   CPPUNIT_TEST_SUITE(testMXNetCppPredictor);
0010   CPPUNIT_TEST(checkAll);
0011   CPPUNIT_TEST_SUITE_END();
0012 
0013 public:
0014   void checkAll();
0015 };
0016 
0017 CPPUNIT_TEST_SUITE_REGISTRATION(testMXNetCppPredictor);
0018 
0019 void testMXNetCppPredictor::checkAll() {
0020   std::string model_path = edm::FileInPath("PhysicsTools/MXNet/test/data/testmxnet-symbol.json").fullPath();
0021   std::string param_path = edm::FileInPath("PhysicsTools/MXNet/test/data/testmxnet-0000.params").fullPath();
0022 
0023   // load model and params
0024   Block *block = nullptr;
0025   CPPUNIT_ASSERT_NO_THROW(block = new Block(model_path, param_path));
0026   CPPUNIT_ASSERT(block != nullptr);
0027 
0028   // create predictor
0029   Predictor predictor(*block);
0030 
0031   // set input shape
0032   std::vector<std::string> input_names{"data"};
0033   std::vector<std::vector<unsigned>> input_shapes{{1, 3}};
0034   CPPUNIT_ASSERT_NO_THROW(predictor.set_input_shapes(input_names, input_shapes));
0035 
0036   // run predictor
0037   std::vector<std::vector<float>> data{{
0038       1,
0039       2,
0040       3,
0041   }};
0042   std::vector<float> outputs;
0043   CPPUNIT_ASSERT_NO_THROW(outputs = predictor.predict(data));
0044 
0045   // check outputs
0046   CPPUNIT_ASSERT(outputs.size() == 3);
0047   CPPUNIT_ASSERT(outputs.at(0) == 42);
0048   CPPUNIT_ASSERT(outputs.at(1) == 42);
0049   CPPUNIT_ASSERT(outputs.at(2) == 42);
0050 
0051   delete block;
0052 }