有没有办法获取 numpy 数组(Python)每行的前 k 个值?

Jes*_*ABI 6 python loops numpy python-2.7 python-3.x

给定一个如下形式的 numpy 数组:

x = [[4.,3.,2.,1.,8.],[1.2,3.1,0.,9.2,5.5],[0.2,7.0,4.4,0.2,1.3]]
Run Code Online (Sandbox Code Playgroud)

有没有办法在 python 中保留每行中的前 3 个值并将其他值设置为零(无需显式循环)。上面示例的结果将是

x = [[4.,3.,0.,0.,8.],[0.,3.1,0.,9.2,5.5],[0.0,7.0,4.4,0.0,1.3]]
Run Code Online (Sandbox Code Playgroud)

一个例子的代码

import numpy as np
arr = np.array([1.2,3.1,0.,9.2,5.5,3.2])
indexes=arr.argsort()[-3:][::-1]
a = list(range(6))
A=set(indexes); B=set(a)
zero_ind=(B.difference(A)) 
arr[list(zero_ind)]=0
Run Code Online (Sandbox Code Playgroud)

输出:

array([0. , 0. , 0. , 9.2, 5.5, 3.2])
Run Code Online (Sandbox Code Playgroud)

上面是我的一维 numpy 数组的示例代码(有很多行)。循环遍历 numpy 数组的每一行并重复执行相同的计算将非常昂贵。有没有更简单的方法?

Mim*_*EAM 0

这是使用列表理解来查看数组并应用 keep_top_3 函数的替代方案

import numpy as np
import heapq

def keep_top_3(arr): 
    smallest = heapq.nlargest(3, arr)[-1]  # find the top 3 and use the smallest as cut off
    arr[arr < smallest] = 0 # replace anything lower than the cut off with 0
    return arr 

x = [[4.,3.,2.,1.,8.],[1.2,3.1,0.,9.2,5.5],[0.2,7.0,4.4,0.2,1.3]]
result = [keep_top_3(np.array(arr)) for arr  in x]
Run Code Online (Sandbox Code Playgroud)

我希望这有帮助 :)