Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:23:46

#include <cppunit/extensions/HelperMacros.h>

#include "PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h"
#include "FWCore/ParameterSet/interface/FileInPath.h"
#include "HeterogeneousCore/CUDAUtilities/interface/requireDevices.h"

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>
0008 
0009 using namespace cms::Ort;
0010 
0011 class testONNXRuntime : public CppUnit::TestFixture {
0012   CPPUNIT_TEST_SUITE(testONNXRuntime);
0013   CPPUNIT_TEST(checkCPU);
0014   CPPUNIT_TEST(checkGPU);
0015   CPPUNIT_TEST_SUITE_END();
0016 
0017 private:
0018   void test(Backend backend);
0019 
0020 public:
0021   void checkCPU();
0022   void checkGPU();
0023 };
0024 
0025 CPPUNIT_TEST_SUITE_REGISTRATION(testONNXRuntime);
0026 
0027 void testONNXRuntime::test(Backend backend) {
0028   std::string model_path = edm::FileInPath("PhysicsTools/ONNXRuntime/test/data/model.onnx").fullPath();
0029   auto session_options = ONNXRuntime::defaultSessionOptions(backend);
0030   ONNXRuntime rt(model_path, &session_options);
0031   for (const unsigned batch_size : {1, 2, 4}) {
0032     FloatArrays input_values{
0033         std::vector<float>(batch_size * 2, 1),
0034     };
0035     FloatArrays outputs;
0036     CPPUNIT_ASSERT_NO_THROW(outputs = rt.run({"X"}, input_values, {}, {"Y"}, batch_size));
0037     CPPUNIT_ASSERT(outputs.size() == 1);
0038     CPPUNIT_ASSERT(outputs[0].size() == batch_size);
0039     for (const auto &v : outputs[0]) {
0040       CPPUNIT_ASSERT(v == 3);
0041     }
0042   }
0043 }
0044 
0045 void testONNXRuntime::checkCPU() { test(Backend::cpu); }
0046 
0047 void testONNXRuntime::checkGPU() {
0048   if (cms::cudatest::testDevices()) {
0049     test(Backend::cuda);
0050   }
0051 }