我有一个 pyTorch 代码来训练一个模型,该模型应该能够检测产品图像中的占位符图像。我没有自己编写代码,因为我对 CNN 和机器学习非常缺乏经验。
我的老板告诉我计算该模型的f1 分数,我发现其公式为((precision * recall)/(precision + recall)),但我不知道如何获得精确度和召回率。有人能告诉我如何从以下代码中获取这两个参数吗?(抱歉,代码很长,但我真的不知道什么是必要的,什么不是)
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)
data_dir = "data"
# Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception]
model_name = "resnet"
# Number …Run Code Online (Sandbox Code Playgroud) 我创建了一个 pyTorch 模型来对图像进行分类。我通过 state_dict 和整个模型保存了一次:
torch.save(model.state_dict(), "model1_statedict")
torch.save(model, "model1_complete")
Run Code Online (Sandbox Code Playgroud)
我该如何使用这些模型?我想用一些图像来检查它们是否良好。
我正在加载模型:
model = torch.load(path_model)
model.eval()
Run Code Online (Sandbox Code Playgroud)
这工作正常,但我不知道如何使用它来预测新图片。