小编Sno*_*w24的帖子

计算混淆矩阵的更快方法?

我正在计算我的混淆矩阵,如下所示,用于图像语义分割,这是一种非常冗长的方法:

def confusion_matrix(preds, labels, conf_m, sample_size):
    preds = normalize(preds,0.9) # returns [0,1] tensor
    preds = preds.flatten()
    labels = labels.flatten()
    for i in range(len(preds)):
        if preds[i]==1 and labels[i]==1:
            conf_m[0,0] += 1/(len(preds)*sample_size) # TP
        elif preds[i]==1 and labels[i]==0:
            conf_m[0,1] += 1/(len(preds)*sample_size) # FP
        elif preds[i]==0 and labels[i]==0:
            conf_m[1,0] += 1/(len(preds)*sample_size) # TN
        elif preds[i]==0 and labels[i]==1:
            conf_m[1,1] += 1/(len(preds)*sample_size) # FN 
    return conf_m
Run Code Online (Sandbox Code Playgroud)

在预测循环中:

conf_m = torch.zeros(2,2) # two classes (object or no-object)
for img,label in enumerate(data):
    ...
    out = Net(img)
    conf_m …
Run Code Online (Sandbox Code Playgroud)

metrics python-3.x confusion-matrix pytorch

3
推荐指数
1
解决办法
1260
查看次数

标签 统计

confusion-matrix ×1

metrics ×1

python-3.x ×1

pytorch ×1