我有一个非常大的数组,由0到N之间的整数组成,其中每个值至少出现一次.
我想知道,对于每个值k,我的数组中数组的值等于k的所有索引.
例如:
arr = np.array([0,1,2,3,2,1,0])
desired_output = {
0: np.array([0,6]),
1: np.array([1,5]),
2: np.array([2,4]),
3: np.array([3]),
}
Run Code Online (Sandbox Code Playgroud)
现在我正在完成这个循环range(N+1),并调用np.whereN次.
indices = {}
for value in range(max(arr)+1):
indices[value] = np.where(arr == value)[0]
Run Code Online (Sandbox Code Playgroud)
这个循环是我代码中最慢的部分.(arr==value评估和np.where通话都占用了大量的时间.)有没有更有效的方法来做到这一点?
我也试过玩,np.unique(arr, return_index=True)但只告诉我第一个索引,而不是所有索引.
方法#1
这是一个矢量化方法,将这些索引作为数组列表 -
sidx = arr.argsort()
unq, cut_idx = np.unique(arr[sidx],return_index=True)
indices = np.split(sidx,cut_idx)[1:]
Run Code Online (Sandbox Code Playgroud)
如果你想要将每个独特元素与其索引相对应的最终字典,最后我们可以使用循环理解 -
dict_out = {unq[i]:iterID for i,iterID in enumerate(indices)}
Run Code Online (Sandbox Code Playgroud)
方法#2
如果您只对阵列列表感兴趣,这里有一个替代性能 -
sidx = arr.argsort()
indices = np.split(sidx,np.flatnonzero(np.diff(arr[sidx])>0)+1)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
113 次 |
| 最近记录: |