从 Python 中的 ONNX 模型获取预测
我找不到任何人向外行解释如何将 onnx 模型加载到 python 脚本中,然后在输入图像时使用该模型进行预测。我能找到的只是这些代码行:
sess = rt.InferenceSession("onnx_model.onnx")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred = sess.run([label_name], {input_name: X.astype(np.float32)})[0]
但我不知道这意味着什么。无论我看到哪里,每个人似乎都已经知道它们的意思,所以没有人解释它。如果我可以运行这段代码,那将是一回事,但我不能。它给了我这个错误:
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid rank for input: Input3 Got: 2 Expected: 4 Please fix either the inputs or the model.
所以我需要真正知道这些东西的含义,这样我才能弄清楚如何修复错误。有懂行的请解释一下吗?
I can't find anyone who explains to a layman how to load an onnx model into a python script, then use that model to make a prediction when fed an image. All I could find were these lines of code:
sess = rt.InferenceSession("onnx_model.onnx")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred = sess.run([label_name], {input_name: X.astype(np.float32)})[0]
But I don't know what any of that means. And everywhere I look, everybody already seems to know what they mean, so nobody's explaining it. That would be one thing if I could just run this code, but I can't. It gives me this error:
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid rank for input: Input3 Got: 2 Expected: 4 Please fix either the inputs or the model.
So I need to actually know what those things mean so I can figure out how to fix the error. Will someone knowledgeable please explain?
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(1)
让我们首先检查您提供的代码,以使一切清楚。
该行将模型加载到会话对象中。这意味着模型中使用的层、函数和权重已准备好执行推理。
get_inputs 和 get_outputs 这两个方法各自检索有关模型的一些元信息,即模型期望的输入以及模型可以提供的输出。在这些行中的元信息中,只有第一个输入和第一个输入。实际使用输出,并且仅获取名称并将其保存到变量中。
对于最后一行,让我们逐个解决该问题。
这对模型进行推理,之后我们将检查此方法的输入,但现在,输出是不同输出的列表。这些输出都是 numpy 数组。在这种情况下,仅使用此列表中的第一个输出,并将其保存到
pred
变量中。这些是
sess.run
的输入。第一个是您想要由会话计算的输出名称的列表。第二个参数是一个字典,其中每个输入的名称映射到 numpy 数组。这些数组的维度应与模型创建期间提供的数组的维度相同。同样,这些数组的类型也应该与模型创建期间使用的类型相匹配。您遇到的错误似乎表明提供的数组没有预期的尺寸。这些预期的维度数量似乎是 4。
为了清楚地了解输入数组的确切形状和数据类型应该是什么,可以使用可视化工具,例如 Netron
Let's first start by going over the code you provided, to make everything clear.
This line loads the model into a session object. This means that the layers, functions and weights used in the model are made ready to perform inferences.
The two methods
get_inputs
andget_outputs
each retrieve some meta information about the model, that being what inputs the model expects, and what outputs it can provide. Off of this meta information in these lines, only the first input & output is actually used, and off of these, only the name is being gotten, and saved into variables.For the last line, let's tackle that part by part.
This performs a inference on the model, we'll go over the inputs to this method after this, but for now, the output is a list of different outputs. These outputs are each numpy arrays. In this case only the first output in this list is being used, and saved to the
pred
variableThese are the inputs to
sess.run
. The fist is a list of names of outputs that you want to be computed by the session. The second argument is a dict, where each input's name maps to numpy arrays. These arrays are are expected to be of the same dimension as the ones supplied during creation of the model. Similarly the types of these arrays should also match the types used during creation of the model.The error you encountered seems to indicate that the supplied array doesn't have the expected dimensions. These intended amount of dimensions seems to be 4.
To gain clarity about what the exact shape and data type of the input array should be, there are visualization tools, like Netron