为了找到最小值的索引,我可以使用argmin:
import numpy as np
A = np.array([1, 7, 9, 2, 0.1, 17, 17, 1.5])
print A.argmin() # 4 because A[4] = 0.1
Run Code Online (Sandbox Code Playgroud)
但是如何找到k-最小值的指数?
我正在寻找类似的东西:
print A.argmin(numberofvalues=3)
# [4, 0, 7] because A[4] <= A[0] <= A[7] <= all other A[i]
Run Code Online (Sandbox Code Playgroud)
注意:在我的用例A中有大约10 000到100 000个值,我只对k = 10个最小值的索引感兴趣.k永远不会超过10.
unu*_*tbu 70
使用np.argpartition.它不会对整个数组进行排序.它只保证kth元素处于排序位置,所有较小的元素将在它之前移动.因此,第一个k元素将是k个最小元素.
import numpy as np
A = np.array([1, 7, 9, 2, 0.1, 17, 17, 1.5])
k = 3
idx = np.argpartition(A, k)
print(idx)
# [4 0 7 3 1 2 6 5]
Run Code Online (Sandbox Code Playgroud)
这将返回k最小值.请注意,这些可能不是按排序顺序排列的.
print(A[idx[:k]])
# [ 0.1 1. 1.5]
Run Code Online (Sandbox Code Playgroud)
要获得k最大值使用
idx = np.argpartition(A, -k)
# [4 0 7 3 1 2 6 5]
A[idx[-k:]]
# [ 9. 17. 17.]
Run Code Online (Sandbox Code Playgroud)
警告:不要(重新)使用idx = np.argpartition(A, k); A[idx[-k:]]以获得最大的k值.这并不总是有效.例如,这些不是以下3个最大值x:
x = np.array([100, 90, 80, 70, 60, 50, 40, 30, 20, 10, 0])
idx = np.argpartition(x, 3)
x[idx[-3:]]
array([ 70, 80, 100])
Run Code Online (Sandbox Code Playgroud)
这是一个比较np.argsort,它也可以工作,但只是对整个数组进行排序以获得结果.
In [2]: x = np.random.randn(100000)
In [3]: %timeit idx0 = np.argsort(x)[:100]
100 loops, best of 3: 8.26 ms per loop
In [4]: %timeit idx1 = np.argpartition(x, 100)[:100]
1000 loops, best of 3: 721 µs per loop
In [5]: np.alltrue(np.sort(np.argsort(x)[:100]) == np.sort(np.argpartition(x, 100)[:100]))
Out[5]: True
Run Code Online (Sandbox Code Playgroud)
Cor*_*mer 13
您可以使用numpy.argsort切片
>>> import numpy as np
>>> A = np.array([1, 7, 9, 2, 0.1, 17, 17, 1.5])
>>> np.argsort(A)[:3]
array([4, 0, 7], dtype=int32)
Run Code Online (Sandbox Code Playgroud)
对于n维数组,这个函数效果很好。不定数以可调用的形式返回。如果您想要返回索引列表,则需要在创建列表之前转置数组。
\n\n要检索k最大的,只需传入-k。
def get_indices_of_k_smallest(arr, k):\n idx = np.argpartition(arr.ravel(), k)\n return tuple(np.array(np.unravel_index(idx, arr.shape))[:, range(min(k, 0), max(k, 0))])\n # if you want it in a list of indices . . . \n # return np.array(np.unravel_index(idx, arr.shape))[:, range(k)].transpose().tolist()\nRun Code Online (Sandbox Code Playgroud)\n\n例子:
\n\nr = np.random.RandomState(1234)\narr = r.randint(1, 1000, 2 * 4 * 6).reshape(2, 4, 6)\n\nindices = get_indices_of_k_smallest(arr, 4)\nindices\n# (array([1, 0, 0, 1], dtype=int64),\n# array([3, 2, 0, 1], dtype=int64),\n# array([3, 0, 3, 3], dtype=int64))\n\narr[indices]\n# array([ 4, 31, 54, 77])\n\n%%timeit\nget_indices_of_k_smallest(arr, 4)\n# 17.1 \xc2\xb5s \xc2\xb1 651 ns per loop (mean \xc2\xb1 std. dev. of 7 runs, 100000 loops each)\nRun Code Online (Sandbox Code Playgroud)\n
| 归档时间: |
|
| 查看次数: |
42209 次 |
| 最近记录: |