File indexing completed on 2024-04-06 12:23:46
0001
0002
0003
0004
0005
0006
0007
0008 #include "PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h"
0009
0010 #include "FWCore/Utilities/interface/Exception.h"
0011 #include "FWCore/Utilities/interface/thread_safety_macros.h"
0012 #include <algorithm>
0013 #include <cassert>
0014 #include <functional>
0015 #include <iostream>
0016 #include <memory>
0017 #include <numeric>
0018
0019 namespace cms::Ort {
0020
using namespace ::Ort;

// Single shared ONNX Runtime environment for all ONNXRuntime instances;
// logging is restricted to errors, with an empty log identifier.
const Env ONNXRuntime::env_(ORT_LOGGING_LEVEL_ERROR, "");
0024
// Creates an inference session for the ONNX model at `model_path` and caches
// the model's input/output node names and declared tensor shapes for run().
// If `session_options` is null, defaultSessionOptions() is used instead.
ONNXRuntime::ONNXRuntime(const std::string& model_path, const SessionOptions* session_options) {
  // create the session with caller-provided options, or fall back to defaults
  if (session_options) {
    session_ = std::make_unique<Session>(env_, model_path.c_str(), *session_options);
  } else {
    session_ = std::make_unique<Session>(env_, model_path.c_str(), defaultSessionOptions());
  }
  AllocatorWithDefaultOptions allocator;

  // cache input node names and shapes
  size_t num_input_nodes = session_->GetInputCount();
  input_node_strings_.resize(num_input_nodes);
  input_node_names_.resize(num_input_nodes);
  input_node_dims_.clear();

  for (size_t i = 0; i < num_input_nodes; i++) {
    // copy the name into an owned std::string before the allocated
    // string returned by GetInputNameAllocated is released
    std::string input_name(session_->GetInputNameAllocated(i, allocator).get());
    input_node_strings_[i] = input_name;
    // c_str() pointers stay valid: input_node_strings_ was resized above
    // and is not modified afterwards
    input_node_names_[i] = input_node_strings_[i].c_str();

    // query the declared tensor shape of this input
    auto type_info = session_->GetInputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();

    input_node_dims_[input_name] = tensor_info.GetShape();
  }

  // cache output node names and shapes (same pattern as for inputs)
  size_t num_output_nodes = session_->GetOutputCount();
  output_node_strings_.resize(num_output_nodes);
  output_node_names_.resize(num_output_nodes);
  output_node_dims_.clear();

  for (size_t i = 0; i < num_output_nodes; i++) {
    std::string output_name(session_->GetOutputNameAllocated(i, allocator).get());
    output_node_strings_[i] = output_name;
    output_node_names_[i] = output_node_strings_[i].c_str();

    auto type_info = session_->GetOutputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    output_node_dims_[output_name] = tensor_info.GetShape();

    // mark the leading (batch) dimension as free; run() sets the actual
    // batch size at inference time
    output_node_dims_[output_name].at(0) = -1;
  }
}
0073
0074 ONNXRuntime::~ONNXRuntime() {}
0075
0076 SessionOptions ONNXRuntime::defaultSessionOptions(Backend backend) {
0077 SessionOptions sess_opts;
0078 sess_opts.SetIntraOpNumThreads(1);
0079 if (backend == Backend::cuda) {
0080
0081 OrtCUDAProviderOptions options;
0082 sess_opts.AppendExecutionProvider_CUDA(options);
0083 }
0084 return sess_opts;
0085 }
0086
0087 FloatArrays ONNXRuntime::run(const std::vector<std::string>& input_names,
0088 FloatArrays& input_values,
0089 const std::vector<std::vector<int64_t>>& input_shapes,
0090 const std::vector<std::string>& output_names,
0091 int64_t batch_size) const {
0092 assert(input_names.size() == input_values.size());
0093 assert(input_shapes.empty() || input_names.size() == input_shapes.size());
0094 assert(batch_size > 0);
0095
0096
0097 std::vector<Value> input_tensors;
0098 auto memory_info = MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
0099 for (const auto& name : input_node_strings_) {
0100 auto iter = std::find(input_names.begin(), input_names.end(), name);
0101 if (iter == input_names.end()) {
0102 throw cms::Exception("RuntimeError") << "Input " << name << " is not provided!";
0103 }
0104 auto input_pos = iter - input_names.begin();
0105 auto value = input_values.begin() + input_pos;
0106 std::vector<int64_t> input_dims;
0107 if (input_shapes.empty()) {
0108 input_dims = input_node_dims_.at(name);
0109 input_dims[0] = batch_size;
0110 } else {
0111 input_dims = input_shapes[input_pos];
0112
0113 if (input_dims[0] != batch_size) {
0114 throw cms::Exception("RuntimeError") << "The first element of `input_shapes` (" << input_dims[0]
0115 << ") does not match the given `batch_size` (" << batch_size << ")";
0116 }
0117 }
0118 auto expected_len = std::accumulate(input_dims.begin(), input_dims.end(), 1, std::multiplies<int64_t>());
0119 if (expected_len != (int64_t)value->size()) {
0120 throw cms::Exception("RuntimeError")
0121 << "Input array " << name << " has a wrong size of " << value->size() << ", expected " << expected_len;
0122 }
0123 auto input_tensor =
0124 Value::CreateTensor<float>(memory_info, value->data(), value->size(), input_dims.data(), input_dims.size());
0125 assert(input_tensor.IsTensor());
0126 input_tensors.emplace_back(std::move(input_tensor));
0127 }
0128
0129
0130 std::vector<const char*> run_output_node_names;
0131 if (output_names.empty()) {
0132 run_output_node_names = output_node_names_;
0133 } else {
0134 for (const auto& name : output_names) {
0135 run_output_node_names.push_back(name.c_str());
0136 }
0137 }
0138
0139
0140 auto output_tensors = session_->Run(RunOptions{nullptr},
0141 input_node_names_.data(),
0142 input_tensors.data(),
0143 input_tensors.size(),
0144 run_output_node_names.data(),
0145 run_output_node_names.size());
0146
0147
0148 FloatArrays outputs;
0149 for (auto& output_tensor : output_tensors) {
0150 assert(output_tensor.IsTensor());
0151
0152
0153 auto tensor_info = output_tensor.GetTensorTypeAndShapeInfo();
0154 auto length = tensor_info.GetElementCount();
0155
0156 auto floatarr = output_tensor.GetTensorMutableData<float>();
0157 outputs.emplace_back(floatarr, floatarr + length);
0158 }
0159 assert(outputs.size() == run_output_node_names.size());
0160
0161 return outputs;
0162 }
0163
0164 const std::vector<std::string>& ONNXRuntime::getOutputNames() const {
0165 if (session_) {
0166 return output_node_strings_;
0167 } else {
0168 throw cms::Exception("RuntimeError") << "Needs to call createSession() first before getting the output names!";
0169 }
0170 }
0171
0172 const std::vector<int64_t>& ONNXRuntime::getOutputShape(const std::string& output_name) const {
0173 auto iter = output_node_dims_.find(output_name);
0174 if (iter == output_node_dims_.end()) {
0175 throw cms::Exception("RuntimeError") << "Output name " << output_name << " is invalid!";
0176 } else {
0177 return iter->second;
0178 }
0179 }
0180
0181 }