我正在编写一个 python 脚本,它将任何深度学习模型从流行框架(TensorFlow、Keras、PyTorch)转换为 ONNX 格式。目前,我使用tf2onnx进行tensorflow,使用keras2onnx进行 keras 到 ONNX 的转换,这些都有效。
\n现在 PyTorch 集成了 ONNX 支持,因此我可以直接从 PyTorch 保存 ONNX 模型。但问题是我需要该模型的输入张量形状,以便将其保存为 ONNX 格式。正如您可能已经猜到的那样,我正在编写此脚本来转换未知的深度学习模型。
\n这是 PyTorch 的 ONNX 转换教程。那里写着:
\n\n\n限制\xc2\xb6\nONNX 导出器是基于跟踪的导出器,这意味着它通过执行一次模型并导出在此运行期间实际运行的运算符来进行操作。这意味着,如果您的模型是动态的,例如,根据输入数据更改行为,则导出将\xe2\x80\x99 不准确。
\n
\n\n类似地,跟踪可能仅对特定的输入大小有效(这就是我们在跟踪时需要显式输入的原因之一)。大多数运算符导出与大小无关的版本,并且应该适用于不同的批量大小或输入大小。我们建议检查模型跟踪并确保跟踪的运算符看起来合理。
\n
我正在使用的代码片段是这样的:
\nimport torch\n\ndef convert_pytorch2onnx(self):\n """pytorch -> onnx"""\n\n model = torch.load(self._model_file_path)\n\n # Don\'t know how to get this INPUT_SHAPE\n dummy_input = torch.randn(INPUT_SHAPE)\n torch.onnx.export(model, dummy_input, self._onnx_file_path)\n return\nRun Code Online (Sandbox Code Playgroud)\n 那么我如何知道未知 PyTorch 模型的输入张量的 INPUT_SHAPE 呢?或者有没有其他方法可以将PyTorch模型转换为ONNX?
\n