快速查找数组数组中的数组索引

pet*_*hka 3 python arrays search numpy multidimensional-array

假设我有一个长度为4的numpy数组:

In [41]: arr
Out[41]:
array([[  1,  15,   0,   0],
       [ 30,  10,   0,   0],
       [ 30,  20,   0,   0],
       ...,
       [104, 139, 146,  75],
       [  9,  11, 146,  74],
       [  9, 138, 146,  75]], dtype=uint8)
Run Code Online (Sandbox Code Playgroud)

我想知道:

  1. arr包括真的[1, 2, 3, 4]吗?
  2. 如果真有什么指数[1, 2, 3, 4]arr

我想尽可能快地发现它.

假设arr包含8550420个元素.我已经检查了几种方法timeit:

  1. 只是为了检查而不获得索引:any(all([1, 2, 3, 4] == elt) for elt in arr).在我的机器上运行10次,平均花费了15.5秒
  2. for基于解决方案:

    for i,e in enumerate(arr): if list(e) == [1, 2, 3, 4]: break

    它平均花了大约5.7秒

是否存在一些更快的解决方案,例如基于numpy?

unu*_*tbu 6

这是Jaime的想法,我只是喜欢它:

import numpy as np

def asvoid(arr):
    """View the array as dtype np.void (bytes)
    This collapses ND-arrays to 1D-arrays, so you can perform 1D operations on them.
    https://stackoverflow.com/a/16216866/190597 (Jaime)"""    
    arr = np.ascontiguousarray(arr)
    return arr.view(np.dtype((np.void, arr.dtype.itemsize * arr.shape[-1])))

def find_index(arr, x):
    arr_as1d = asvoid(arr)
    x = asvoid(x)
    return np.nonzero(arr_as1d == x)[0]


arr = np.array([[  1,  15,   0,   0],
                [ 30,  10,   0,   0],
                [ 30,  20,   0,   0],
                [1, 2, 3, 4],
                [104, 139, 146,  75],
                [  9,  11, 146,  74],
                [  9, 138, 146,  75]], dtype='uint8')

arr = np.tile(arr,(1221488,1))
x = np.array([1,2,3,4], dtype='uint8')

print(find_index(arr, x))
Run Code Online (Sandbox Code Playgroud)

产量

[      3      10      17 ..., 8550398 8550405 8550412]
Run Code Online (Sandbox Code Playgroud)

我们的想法是将数组的每一视为一个字符串.例如,

In [15]: x
Out[15]: 
array([^A^B^C^D], 
      dtype='|V4')
Run Code Online (Sandbox Code Playgroud)

这些字符串看起来像垃圾,但它们实际上只是每行被视为字节的基础数据.然后,您可以比较arr_as1d == x以找到相等的x.


还有另一种方法:

def find_index2(arr, x):
    return np.where((arr == x).all(axis=1))[0]
Run Code Online (Sandbox Code Playgroud)

但事实证明并不那么快:

In [34]: %timeit find_index(arr, x)
1 loops, best of 3: 209 ms per loop

In [35]: %timeit find_index2(arr, x)
1 loops, best of 3: 370 ms per loop
Run Code Online (Sandbox Code Playgroud)