计算pytorch张量中的唯一元素

sac*_*ruk 7 pytorch

假设我有以下张量:y = torch.randint(0, 3, (10,))。你会如何计算其中的 0、1 和 2?

我能想到的唯一方法是使用,collections.Counter(y)但想知道是否有更“pytorch”的方法来做到这一点。例如,一个用例是构建用于预测的混淆矩阵。

Iva*_*van 10

您可以使用torch.unique以下选项return_counts

>>> x = torch.randint(0, 3, (10,))
tensor([1, 1, 0, 2, 1, 0, 1, 1, 2, 1])

>>> x.unique(return_counts=True)
(tensor([0, 1, 2]), tensor([2, 6, 2]))
Run Code Online (Sandbox Code Playgroud)