从另一个数组中获取匹配项的索引

Vin*_*ent 1 python numpy

给定两个 np.arrays;

a = np.array([1, 6, 5, 3, 8, 345, 34, 6, 2, 867])
b = np.array([867, 8, 34, 75])
Run Code Online (Sandbox Code Playgroud)

我想得到一个与 b 具有相同维度的 np.array,其中每个值是 b 中的值出现在 a 中的索引,或者 np.nan 如果 b 中的值不存在于 a 中。

结果应该是;

[9, 4, 6, nan]
Run Code Online (Sandbox Code Playgroud)

a 和 b 将始终具有相同的维度数,但维度的大小可能不同。

就像是;

(伪代码)

c = np.where(b in a)
Run Code Online (Sandbox Code Playgroud)

但适用于数组(“in”不适用)

我更喜欢“单行”或至少是完全在阵列级别的解决方案,并且不需要循环。

谢谢!

Div*_*kar 6

方法#1

这是一个np.searchsorted-

def find_indices(a,b,invalid_specifier=-1):
    # Search for matching indices for each b in sorted version of a. 
    # We use sorter arg to account for the case when a might not be sorted 
    # using argsort on a
    sidx = a.argsort()
    idx = np.searchsorted(a,b,sorter=sidx)

    # Remove out of bounds indices as they wont be matches
    idx[idx==len(a)] = 0

    # Get traced back indices corresponding to original version of a
    idx0 = sidx[idx]
    
    # Mask out invalid ones with invalid_specifier and return
    return np.where(a[idx0]==b, idx0, invalid_specifier)
Run Code Online (Sandbox Code Playgroud)

给定样本的输出 -

In [41]: find_indices(a, b, invalid_specifier=np.nan)
Out[41]: array([ 9.,  4.,  6., nan])
Run Code Online (Sandbox Code Playgroud)

方法#2

另一个基于lookup正数 -

def find_indices_lookup(a,b,invalid_specifier=-1):
    # Setup array where we will assign ranged numbers
    N = max(a.max(), b.max())+1
    lookup = np.full(N, invalid_specifier)
    
    # We index into lookup with b to trace back the positions. Non matching ones
    # would have invalid_specifier values as wount had been indexed by ranged ones
    lookup[a] = np.arange(len(a))
    indices  = lookup[b]
    return indices
Run Code Online (Sandbox Code Playgroud)

基准测试

问题中没有提到效率作为一项要求,但无循环要求可能会出现在那里。使用尝试重新表示给定示例设置的设置进行测试,但通过1000x以下方式对其进行扩展:

In [98]: a = np.random.permutation(np.unique(np.random.randint(0,20000,10000)))

In [99]: b = np.random.permutation(np.unique(np.random.randint(0,20000,4000)))

# Solutions from this post
In [100]: %timeit find_indices(a,b,invalid_specifier=np.nan)
     ...: %timeit find_indices_lookup(a,b,invalid_specifier=np.nan)
1.35 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
220 µs ± 30.9 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

# @Quang Hoang-soln2
In [101]: %%timeit
     ...: commons, idx_a, idx_b = np.intersect1d(a,b, return_indices=True)
     ...: orders = np.argsort(idx_b)
     ...: output = np.full(b.shape, np.nan)
     ...: output[orders] = idx_a[orders]
1.63 ms ± 59.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

# @Quang Hoang-soln1
In [102]: %%timeit
     ...: s = b == a[:,None]
     ...: np.where(s.any(0), np.argmax(s,0), np.nan)
137 ms ± 9.25 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Run Code Online (Sandbox Code Playgroud)