我正在使用 fastai 库 ( fast.ai ) 来训练图像分类器。fastai创建的模型实际上是pytorch模型。
type(model)
<class 'torch.nn.modules.container.Sequential'>
Run Code Online (Sandbox Code Playgroud)
现在,我想使用 pytorch 中的这个模型进行推理。到目前为止,这是我的代码:
torch.save(model,"./torch_model_v1")
the_model = torch.load("./torch_model_v1")
the_model.eval() # shows the entire network architecture
Run Code Online (Sandbox Code Playgroud)
根据此处显示的示例:http://pytorch.org/tutorials/beginner/data_loading_tutorial.html#sphx-glr-beginner-data-loading-tutorial-py,我知道我需要编写自己的数据加载类将覆盖 Dataset 类中的一些函数。但我不清楚的是我需要在测试时应用哪些转换?特别是,如何在测试时标准化图像?
另一个问题:我在 pytorch 中保存和加载模型的方法好吗?我在此处的教程中读到:http://pytorch.org/docs/master/notes/serialization.html不推荐我使用的方法。但原因尚不清楚。
只是澄清一下:the_model.eval()
不仅打印架构,还将模型设置为评估模式。
特别是,如何在测试时标准化图像?
这取决于您的型号。例如,对于模块,您必须以这种方式torchvision
标准化输入。
关于如何保存/加载模型,torch.save
/ torch.load
“将对象保存/加载到磁盘文件。”
因此,如果您保存the_model
,它将保存整个模型对象,包括其架构定义和其他一些内部方面。如果保存the_model.state_dict()
,它将保存仅包含模型状态(即参数和缓冲区)的字典。保存模型可能会以多种方式破坏代码,因此首选方法是仅保存和加载模型状态。但是,我不确定 fast.ai“模型文件”实际上是完整模型还是模型的状态。您必须检查此项,以便可以正确加载它。
归档时间: |
|
查看次数: |
12286 次 |
最近记录: |