Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2023-03-17 11:16:10

0001 #!/usr/bin/env python3
0002 # Copyright (c) Microsoft Corporation. All rights reserved.
0003 # Licensed under the MIT License.
0004 
0005 """
0006 .. _l-example-simple-usage:
0007 Load and predict with ONNX Runtime and a very simple model
0008 ==========================================================
This example demonstrates how to load a model and compute
the output for an input tensor. It also shows how to
retrieve the definition of its inputs and outputs.
0012 """
0013 
0014 import onnxruntime as rt
0015 import numpy
0016 from onnxruntime.datasets import get_example
0017 from onnxruntime import datasets
0018 
#########################
# Let's load a very simple model.
# It is the ONNX backend node test
# `onnx/backend/test/data/node/test_sigmoid <https://github.com/onnx/onnx/tree/main/onnx/backend/test/data/node/test_sigmoid>`_,
# a copy of which is looked up through ``onnxruntime.datasets.get_example``.
# The session is restricted to the CPU execution provider.

example1 = get_example("sigmoid.onnx")
sess = rt.InferenceSession(
    example1,
    providers=["CPUExecutionProvider"],
)
0025 
#########################
# Let's see the input name, shape and type.
# Only the first input is inspected; the sigmoid model is expected to have
# a single input.

first_input = sess.get_inputs()[0]
input_name, input_shape, input_type = (
    first_input.name,
    first_input.shape,
    first_input.type,
)
print("input name", input_name)
print("input shape", input_shape)
print("input type", input_type)
0035 
#########################
# Let's see the output name, shape and type.
# Only the first output is inspected; the sigmoid model is expected to
# produce a single output.

first_output = sess.get_outputs()[0]
output_name, output_shape, output_type = (
    first_output.name,
    first_output.shape,
    first_output.type,
)
print("output name", output_name)
print("output shape", output_shape)
print("output type", output_type)
0045 
#########################
# Let's compute its outputs (or predictions if it is a machine learned model).
# The input is drawn from a seeded NumPy ``Generator`` so the printed result
# is reproducible, and it is generated directly as float32 (the element type
# fed to the model) instead of drawing float64 values and casting them.
# The shape (3, 4, 5) is assumed to match ``input_shape`` printed above.

rng = numpy.random.default_rng(seed=0)
x = rng.random((3, 4, 5), dtype=numpy.float32)
# ``run`` takes the list of requested output names and a feed dict mapping
# input names to arrays; ``res`` holds one array per requested output.
res = sess.run([output_name], {input_name: x})
print(res)