File indexing completed on 2024-04-06 12:23:46
0001 #include <cppunit/extensions/HelperMacros.h>
0002
0003 #include "PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h"
0004 #include "FWCore/ParameterSet/interface/FileInPath.h"
0005 #include "HeterogeneousCore/CUDAUtilities/interface/requireDevices.h"
0006
0007 #include <iostream>
0008
0009 using namespace cms::Ort;
0010
0011 class testONNXRuntime : public CppUnit::TestFixture {
0012 CPPUNIT_TEST_SUITE(testONNXRuntime);
0013 CPPUNIT_TEST(checkCPU);
0014 CPPUNIT_TEST(checkGPU);
0015 CPPUNIT_TEST_SUITE_END();
0016
0017 private:
0018 void test(Backend backend);
0019
0020 public:
0021 void checkCPU();
0022 void checkGPU();
0023 };
0024
0025 CPPUNIT_TEST_SUITE_REGISTRATION(testONNXRuntime);
0026
0027 void testONNXRuntime::test(Backend backend) {
0028 std::string model_path = edm::FileInPath("PhysicsTools/ONNXRuntime/test/data/model.onnx").fullPath();
0029 auto session_options = ONNXRuntime::defaultSessionOptions(backend);
0030 ONNXRuntime rt(model_path, &session_options);
0031 for (const unsigned batch_size : {1, 2, 4}) {
0032 FloatArrays input_values{
0033 std::vector<float>(batch_size * 2, 1),
0034 };
0035 FloatArrays outputs;
0036 CPPUNIT_ASSERT_NO_THROW(outputs = rt.run({"X"}, input_values, {}, {"Y"}, batch_size));
0037 CPPUNIT_ASSERT(outputs.size() == 1);
0038 CPPUNIT_ASSERT(outputs[0].size() == batch_size);
0039 for (const auto &v : outputs[0]) {
0040 CPPUNIT_ASSERT(v == 3);
0041 }
0042 }
0043 }
0044
0045 void testONNXRuntime::checkCPU() { test(Backend::cpu); }
0046
0047 void testONNXRuntime::checkGPU() {
0048 if (cms::cudatest::testDevices()) {
0049 test(Backend::cuda);
0050 }
0051 }