Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:23:46

0001 /*
0002  * ONNXRuntime.h
0003  *
0004  * A convenience wrapper of the ONNXRuntime C++ API.
0005  * Based on https://github.com/microsoft/onnxruntime/blob/master/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/CXX_Api_Sample.cpp.
0006  *
0007  *  Created on: Jun 28, 2019
0008  *      Author: hqu
0009  */
0010 
0011 #ifndef PHYSICSTOOLS_ONNXRUNTIME_INTERFACE_ONNXRUNTIME_H_
0012 #define PHYSICSTOOLS_ONNXRUNTIME_INTERFACE_ONNXRUNTIME_H_
0013 
0014 #include <vector>
0015 #include <map>
0016 #include <string>
0017 #include <memory>
0018 
0019 #include "onnxruntime/onnxruntime_cxx_api.h"
0020 
0021 namespace cms::Ort {
0022 
0023   using FloatArrays = std::vector<std::vector<float>>;  // one std::vector<float> per model input/output node (see ONNXRuntime::run)
0024 
0025   enum class Backend {  // Backend selector accepted by ONNXRuntime::defaultSessionOptions().
0026     cpu,  // the default argument of defaultSessionOptions()
0027     cuda,  // NOTE(review): presumably selects CUDA execution; the handling is not visible in this header - confirm
0028   };
0029 
0030   class ONNXRuntime {  // Non-copyable convenience wrapper that holds a single ::Ort::Session.
0031   public:  // NOTE(review): the constructor below is not `explicit`, so a std::string converts to ONNXRuntime implicitly.
0032     ONNXRuntime(const std::string& model_path, const ::Ort::SessionOptions* session_options = nullptr);  // presumably loads the model at `model_path`; how a null `session_options` is handled is not visible here - confirm
0033     ONNXRuntime(const ONNXRuntime&) = delete;  // copying is disabled
0034     ONNXRuntime& operator=(const ONNXRuntime&) = delete;  // copy-assignment is disabled
0035     ~ONNXRuntime();  // declared here, defined out of line
0036     // Returns an ::Ort::SessionOptions for the requested backend; `backend` defaults to Backend::cpu.
0037     static ::Ort::SessionOptions defaultSessionOptions(Backend backend = Backend::cpu);
0038 
0039     // Run inference and return the requested outputs.
0040     // input_names: names of the input nodes.
0041     // input_values: one input array per input node, in the same order as `input_names` (taken by non-const reference).
0042     // input_shapes: one `int64_t` shape array per input node; may be left empty if the model has no dynamic axes.
0043     // output_names: names of the output nodes to read; an empty list selects all output nodes.
0044     // batch_size: number of samples in the batch; each array in `input_values` must be laid out as (batch_size, ...).
0045     // Returns: a FloatArrays (std::vector<std::vector<float>>) ordered to match `output_names`,
0046     // or ordered as in `getOutputNames()` when `output_names` is empty.
0047     FloatArrays run(const std::vector<std::string>& input_names,
0048                     FloatArrays& input_values,
0049                     const std::vector<std::vector<int64_t>>& input_shapes = {},
0050                     const std::vector<std::string>& output_names = {},
0051                     int64_t batch_size = 1) const;
0052 
0053     // Get the names of all the output nodes.
0054     const std::vector<std::string>& getOutputNames() const;
0055 
0056     // Get the shape of an output node.
0057     // The 0th dim depends on the batch size and is therefore set to -1.
0058     const std::vector<int64_t>& getOutputShape(const std::string& output_name) const;
0059 
0060   private:
0061     static const ::Ort::Env env_;  // static: a single Env shared by every ONNXRuntime instance
0062     std::unique_ptr<::Ort::Session> session_;  // session uniquely owned by this object
0063     // Input node bookkeeping. NOTE(review): the const char* entries presumably point into input_node_strings_ - confirm in the implementation.
0064     std::vector<std::string> input_node_strings_;
0065     std::vector<const char*> input_node_names_;
0066     std::map<std::string, std::vector<int64_t>> input_node_dims_;  // string key (presumably the node name) -> shape
0067     // Output node bookkeeping; same member layout as the input members above.
0068     std::vector<std::string> output_node_strings_;
0069     std::vector<const char*> output_node_names_;
0070     std::map<std::string, std::vector<int64_t>> output_node_dims_;  // string key (presumably the node name) -> shape
0071   };
0072 
0073 }  // namespace cms::Ort
0074 
0075 #endif /* PHYSICSTOOLS_ONNXRUNTIME_INTERFACE_ONNXRUNTIME_H_ */