How to find the confusion matrix in PyTorch and plot it for an image classifier

Sar*_*han 5 python image-processing computer-vision pytorch transfer-learning

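This is a VGG-16 model on which I performed transfer learning and fine-tuning. I trained it two weeks ago and measured the training and test accuracy, but now I also need the per-class accuracy of the model. I am trying to compute the confusion matrix and would also like to plot it. Training code: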

import torch
from torch.optim import Adam

# Training the model again from the last CNN block to the end of the network
# (model, device, criterion and trainloader are defined earlier in the notebook)
dataset = 'C:\\Users\\Sara Latif Khan\\OneDrive\\Desktop\\FYP_\\Scene15\\15-Scene'
model = model.to(device)
optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()))

# Train only the layers that still require gradients
num_epochs = 5

for e in range(num_epochs):
    cum_epoch_loss = 0  # cumulative loss over the batches of this epoch

    for batch, (images, labels) in enumerate(trainloader, 1):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logps = model(images)
        loss = criterion(logps, labels)
        loss.backward()
        optimizer.step()

        cum_epoch_loss += loss.item()
        print(f'Epoch ({e}/{num_epochs}) : Batch number ({batch}/{len(trainloader)}) : Batch loss : {loss.item()}')

    # Save one checkpoint per epoch instead of rewriting the file after every batch
    torch.save(model, dataset + '_model_' + str(e) + '.pt')
    print(f'Training loss : {cum_epoch_loss / len(trainloader)}')

This is the code I used to check the model's accuracy on the data from the test loader.

model.to('cpu')

model.eval()
with torch.no_grad():
    num_correct = 0
    total = 0
    
    #set_trace ()
    for batch, (images,labels) in enumerate(testloader,1):
        
        logps = model(images)
        output = torch.exp(logps)
        
        pred = torch.argmax(output,1)
        total += labels.size(0)
        
        num_correct += (pred==labels).sum().item()
        print(f'Batch ({batch} / {len(testloader)})')
        
        # to check the accuracy of model on 5 batches
        # if batch == 5:
            # break
            
    print(f'Accuracy of the model on {total} test images: {num_correct * 100 / total }% ')  

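Next, I need to find the per-class accuracy of the model. I am working in a Jupyter Notebook. Should I reload the saved model and compute the confusion matrix from it, or what is the appropriate approach?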

The*_*fer 3

You have to store all predictions and targets for the test set.

predictions, targets = [], []
model.eval()  # disable dropout etc. for evaluation
with torch.no_grad():  # no gradients are needed here
    for images, labels in testloader:
        logps = model(images)
        output = torch.exp(logps)
        pred = torch.argmax(output, 1)

        # convert to numpy arrays
        pred = pred.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy()

        for i in range(len(pred)):
            predictions.append(pred[i])
            targets.append(labels[i])
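A side note on the prediction step: the exponential is monotonic, so taking the argmax of the log-probabilities picks the same class as taking the argmax of the probabilities, and the `torch.exp` call is optional when only the predicted class is needed. A minimal NumPy sketch with made-up values (not the asker's model output):

```python
import numpy as np

# Made-up log-probabilities for two samples and three classes
logps = np.log(np.array([[0.1, 0.7, 0.2],
                         [0.5, 0.2, 0.3]]))

# Index of the largest entry per row, with and without exponentiating
pred_from_logps = np.argmax(logps, axis=1)
pred_from_probs = np.argmax(np.exp(logps), axis=1)

print(pred_from_logps, pred_from_probs)
```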

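Now you have all predictions and actual targets of the test set stored. The next step is to create the confusion matrix. I think I can just give you the function I always use: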

import numpy as np
import matplotlib.pyplot as plt

def create_confusion_matrix(y_true, y_pred, classes):
    """ creates and plots a confusion matrix given two lists (targets and predictions)
    :param list y_true: list of all targets (integer class indices)
    :param list y_pred: list of all predictions (integer class indices, e.g. from argmax)
    :param dict classes: a dictionary of the class names with their index representation,
        with the keys listed in index order
    """

    amount_classes = len(classes)

    # rows are the true classes, columns are the predicted classes
    confusion_matrix = np.zeros((amount_classes, amount_classes), dtype=int)
    for target, output in zip(y_true, y_pred):
        confusion_matrix[int(target)][int(output)] += 1

    fig, ax = plt.subplots(1)

    ax.matshow(confusion_matrix)
    ax.set_xticks(np.arange(amount_classes))
    ax.set_yticks(np.arange(amount_classes))

    ax.set_xticklabels(list(classes.keys()))
    ax.set_yticklabels(list(classes.keys()))

    plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")
    plt.setp(ax.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    plt.show()
    return confusion_matrix
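If you only need the counts and not the plot, the counting step can also be done without a Python loop. The following is a minimal sketch, assuming the targets and predictions are integer class indices; the helper name `count_confusions` and the toy labels are made up for illustration.

```python
import numpy as np

def count_confusions(y_true, y_pred, num_classes):
    """Count (true, predicted) pairs; rows are true classes, columns are predictions."""
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    cm = np.zeros((num_classes, num_classes), dtype=int)
    # np.add.at accumulates correctly even when the same cell is hit several times
    np.add.at(cm, (y_true, y_pred), 1)
    return cm

# Made-up labels for three classes
toy_targets = [0, 0, 1, 1, 2, 2, 2]
toy_predictions = [0, 1, 1, 1, 2, 0, 2]
toy_cm = count_confusions(toy_targets, toy_predictions, num_classes=3)
print(toy_cm)
```

Keeping the counting separate from the plotting makes it easier to reuse the matrix later, for example to derive per-class numbers.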

So y_true is the list of all targets, y_pred is the list of all predictions, and classes is a dictionary that maps the class names to their labels, e.g.:

classes = {"dog": 0, "cat": 1}

Then just call:

create_confusion_matrix(targets, predictions, classes)
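Since the question asks for the per-class accuracy, it can be read off a confusion matrix directly: each diagonal entry divided by its row sum, assuming rows are the true classes and columns the predicted ones. A minimal sketch, where the helper `per_class_accuracy` and the 3-class example matrix are hypothetical and only for illustration:

```python
import numpy as np

def per_class_accuracy(cm):
    """Return the fraction of correctly classified samples for each true class.

    cm[i][j] counts samples whose true class is i and predicted class is j,
    so the diagonal holds the correct predictions and each row sums to the
    number of test samples of that class.
    """
    cm = np.asarray(cm, dtype=float)
    row_totals = cm.sum(axis=1)
    # Classes without any test samples get NaN instead of a division error
    return np.divide(np.diag(cm), row_totals,
                     out=np.full(len(cm), np.nan), where=row_totals > 0)

# Hypothetical 3-class confusion matrix, only for illustration
example_cm = np.array([[8, 1, 1],
                       [2, 6, 2],
                       [0, 0, 5]])
accuracies = per_class_accuracy(example_cm)
overall = np.trace(example_cm) / example_cm.sum()

for idx, acc in enumerate(accuracies):
    print(f"class {idx}: {acc:.2%}")
print(f"overall: {overall:.2%}")
```

The overall accuracy computed from the trace should match the number printed by the test loop in the question, which is a quick sanity check that the matrix was filled correctly.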

You may need to adapt it a bit to fit your code, but I hope it works for you. :)