如何从 numpy 数组中获取前 K 个值的索引

Des*_*wal 2 python numpy keras pytorch

假设我有来自 Pytorch 或 Keras 预测的概率,结果是使用 softmax 函数

from scipy.special import softmax
probs = softmax(np.random.randn(20,10),1) # 20 instances and 10 class probabilities
probs
Run Code Online (Sandbox Code Playgroud)

我想从这个 numpy 数组中找到前 5 个索引。我想做的就是对结果运行一个循环,如下所示:

for index in top_5_indices:
    if index in result:
        print('Found')
Run Code Online (Sandbox Code Playgroud)

如果我的结果进入前 5 名,我就会得到结果。

Pytorchtop-k功能,我已经看到了numpy.argpartition,但我不知道如何完成这个?

Vik*_*ova 6

numpy 中的 argpartition(a, k) 函数围绕第 k 个最小元素重新排列输入数组 a 的索引,以便较小元素的所有索引最终位于左侧,较大元素的所有索引最终位于右侧。不需要对所有元素进行排序可以节省时间:argpartition 需要 O(n) 时间,而 argsort 需要 O(n log n) 时间。

所以你可以得到 5 个最大元素的索引,如下所示:

np.argpartition(probs,-5)[-5:]
Run Code Online (Sandbox Code Playgroud)