如何将numpy.argpartition的输出应用于二维数组?

dre*_*cko 7 python arrays indexing performance numpy

我有一个较大的2d numpy数组,我想提取每行的最低10个元素及其索引.由于我的数组很大,我宁愿不对整个数组进行排序.

我听说过这个argpartition()函数,我可以用它获得最低10个元素的索引:

top10indexes = np.argpartition(myBigArray,10)[:,:10]
Run Code Online (Sandbox Code Playgroud)

请注意,argpartition()默认情况下分区轴为-1,这就是我想要的.此处的结果与myBigArray具有相同的形状,其中包含各个行的索引,使得前10个索引指向10个最低值.

我现在如何提取myBigArray与这些索引相对应的元素?

明显的花哨索引喜欢myBigArray[top10indexes]myBigArray[:,top10indexes]做一些完全不同的事情.我还可以使用列表推导,例如:

array([row[idxs] for row,idxs in zip(myBigArray,top10indexes)])
Run Code Online (Sandbox Code Playgroud)

但这会导致性能损失迭代numpy行并将结果转换回数组.

nb:我可以np.partition()用来获取值,它们甚至可能对应于索引(或者可能不是......),但如果我可以避免它,我不想再进行两次分区.

Sau*_*tro 8

您可以通过执行以下操作来避免使用拼合副本以及提取所有值的需要:

num = 10
top = np.argpartition(myBigArray, num, axis=1)[:, :num]
myBigArray[np.arange(myBigArray.shape[0])[:, None], top]
Run Code Online (Sandbox Code Playgroud)

对于NumPy> = 1.9.0,这将是非常有效和可比的np.take().

  • 我使用`flatten()`删除了我的答案.我找出了为什么它不起作用,但没有看到任何简单的方法来解决它,而没有有效地制作你的更复杂的版本! (2认同)