File indexing completed on 2024-04-06 12:23:46
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011 #ifndef PHYSICSTOOLS_ONNXRUNTIME_INTERFACE_ONNXRUNTIME_H_
0012 #define PHYSICSTOOLS_ONNXRUNTIME_INTERFACE_ONNXRUNTIME_H_
0013
0014 #include <vector>
0015 #include <map>
0016 #include <string>
0017 #include <memory>
0018
0019 #include "onnxruntime/onnxruntime_cxx_api.h"
0020
0021 namespace cms::Ort {
0022
/// A list of float buffers: the outer vector holds one entry per array and
/// each inner vector holds that array's values. Used in this header as both
/// the input-value and the return type of ONNXRuntime::run().
using FloatArrays = std::vector<std::vector<float>>;
0024
/// Backend selector accepted by ONNXRuntime::defaultSessionOptions().
/// `cpu` is the first enumerator (value 0) and `cuda` the second (value 1).
/// NOTE(review): presumably `cuda` requests the CUDA execution provider and
/// `cpu` the default CPU one -- confirm in the implementation file.
enum class Backend { cpu, cuda };
0029
0030 class ONNXRuntime {
0031 public:
0032 ONNXRuntime(const std::string& model_path, const ::Ort::SessionOptions* session_options = nullptr);
0033 ONNXRuntime(const ONNXRuntime&) = delete;
0034 ONNXRuntime& operator=(const ONNXRuntime&) = delete;
0035 ~ONNXRuntime();
0036
0037 static ::Ort::SessionOptions defaultSessionOptions(Backend backend = Backend::cpu);
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047 FloatArrays run(const std::vector<std::string>& input_names,
0048 FloatArrays& input_values,
0049 const std::vector<std::vector<int64_t>>& input_shapes = {},
0050 const std::vector<std::string>& output_names = {},
0051 int64_t batch_size = 1) const;
0052
0053
0054 const std::vector<std::string>& getOutputNames() const;
0055
0056
0057
0058 const std::vector<int64_t>& getOutputShape(const std::string& output_name) const;
0059
0060 private:
0061 static const ::Ort::Env env_;
0062 std::unique_ptr<::Ort::Session> session_;
0063
0064 std::vector<std::string> input_node_strings_;
0065 std::vector<const char*> input_node_names_;
0066 std::map<std::string, std::vector<int64_t>> input_node_dims_;
0067
0068 std::vector<std::string> output_node_strings_;
0069 std::vector<const char*> output_node_names_;
0070 std::map<std::string, std::vector<int64_t>> output_node_dims_;
0071 };
0072
0073 }
0074
0075 #endif