如何使用 PyTorch 在语义分割中获得 top k 的准确率?

JJT*_*JTT 6 python computer-vision image-segmentation pytorch semantic-segmentation

你如何计算语义分割中的 top k 准确率?在分类中,我们可以将topk准确度计算为:

correct = output.eq(gt.view(1, -1).expand_as(output))
Run Code Online (Sandbox Code Playgroud)

Sha*_*hai 3

您正在寻找torch.topk计算沿维度的前 k 个值的函数。
第二个输出torch.topk是“arg top k”:顶部值的 k 索引。

以下是如何在语义分割的上下文中使用它:假设您有形状为- - (dtype= ) 的
地面实况预测张量。您的模型预测形状的 每像素类别- - - ,其中是类别数(包括“背景”)。这些逻辑是函数将其转换为类概率之前的“原始”预测。由于我们只查看顶部,因此预测是“原始”还是“概率”并不重要。ybhwtorch.int64
logitsbchwc softmaxk

# compute the top k predicted classes, per pixel:
_, tk = torch.topk(logits, k, dim=1)
# you now have k predictions per pixel, and you want that one of them will match the true labels y:
correct_pixels = torch.eq(y[:, None, ...], tk).any(dim=1)
# take the mean of correct_pixels to get the overall average top-k accuracy:
top_k_acc = correct_pixels.mean()  
Run Code Online (Sandbox Code Playgroud)

请注意,此方法不考虑“忽略”像素。这可以通过对上面的代码稍加修改来完成:

valid = y != ignore_index
top_k_acc = correct_pixels[valid].mean()
Run Code Online (Sandbox Code Playgroud)