我们如何将 .pth 模型转换为 .pb 文件?

Raf*_*ael 7 tensorflow pytorch

我已经通过使用 pytorch 获得了完整的模型,但是我想将 .pth 文件转换为 .pb,它可以在 Tensorflow 中使用。有没有人有一些想法?

Dis*_*ani 7

您可以使用ONNX:开放神经网络交换格式

要将.pth文件转换为.pb首先,您需要将 PyTorch 中定义的模型导出到 ONNX,然后将 ONNX 模型导入 Tensorflow(PyTorch => ONNX => Tensorflow)

这是MNISTModel的一个例子来使用ONNX转换一个PyTorch模型Tensorflowonnx /教程

将训练好的模型保存到文件

torch.save(model.state_dict(), 'output/mnist.pth')
Run Code Online (Sandbox Code Playgroud)

从文件加载训练好的模型

trained_model = Net()
trained_model.load_state_dict(torch.load('output/mnist.pth'))

# Export the trained model to ONNX
dummy_input = Variable(torch.randn(1, 1, 28, 28)) # one black and white 28 x 28 picture will be the input to the model
torch.onnx.export(trained_model, dummy_input, "output/mnist.onnx")
Run Code Online (Sandbox Code Playgroud)

加载 ONNX 文件

model = onnx.load('output/mnist.onnx')

# Import the ONNX model to Tensorflow
tf_rep = prepare(model)
Run Code Online (Sandbox Code Playgroud)

将 Tensorflow 模型保存到文件中

tf_rep.export_graph('output/mnist.pb')
Run Code Online (Sandbox Code Playgroud)

@tsveti_iko在评论中 指出

注意:prepare()内置于 中onnx-tf,因此您首先需要像这样通过控制台安装它pip install onnx-tf,然后像这样在代码中导入它:import onnx from onnx_tf.backend import prepare然后您最终可以按照答案中的描述使用它。

  • 注意:`prepare()` 内置在 `onnx-tf` 中,因此您首先需要通过控制台安装它,如 `pip install onnx-tf`,然后将其导入到代码中,如下所示:` import onnx` `from onnx_tf.backend importprepare` 然后你终于可以按照答案中的描述使用它了。 (2认同)