给定两个 np.arrays;
a = np.array([1, 6, 5, 3, 8, 345, 34, 6, 2, 867])
b = np.array([867, 8, 34, 75])
Run Code Online (Sandbox Code Playgroud)
我想得到一个与 b 具有相同维度的 np.array,其中每个值是 b 中的值出现在 a 中的索引,或者 np.nan 如果 b 中的值不存在于 a 中。
结果应该是;
[9, 4, 6, nan]
Run Code Online (Sandbox Code Playgroud)
a 和 b 将始终具有相同的维度数,但维度的大小可能不同。
就像是;
(伪代码)
c = np.where(b in a)
Run Code Online (Sandbox Code Playgroud)
但适用于数组(“in”不适用)
我更喜欢“单行”或至少是完全在阵列级别的解决方案,并且不需要循环。
谢谢!
方法#1
这是一个np.searchsorted-
def find_indices(a,b,invalid_specifier=-1):
# Search for matching indices for each b in sorted version of a.
# We use sorter arg to account for the case when a might not be sorted
# using argsort on a
sidx = a.argsort()
idx = np.searchsorted(a,b,sorter=sidx)
# Remove out of bounds indices as they wont be matches
idx[idx==len(a)] = 0
# Get traced back indices corresponding to original version of a
idx0 = sidx[idx]
# Mask out invalid ones with invalid_specifier and return
return np.where(a[idx0]==b, idx0, invalid_specifier)
Run Code Online (Sandbox Code Playgroud)
给定样本的输出 -
In [41]: find_indices(a, b, invalid_specifier=np.nan)
Out[41]: array([ 9., 4., 6., nan])
Run Code Online (Sandbox Code Playgroud)
方法#2
另一个基于lookup正数 -
def find_indices_lookup(a,b,invalid_specifier=-1):
# Setup array where we will assign ranged numbers
N = max(a.max(), b.max())+1
lookup = np.full(N, invalid_specifier)
# We index into lookup with b to trace back the positions. Non matching ones
# would have invalid_specifier values as wount had been indexed by ranged ones
lookup[a] = np.arange(len(a))
indices = lookup[b]
return indices
Run Code Online (Sandbox Code Playgroud)
问题中没有提到效率作为一项要求,但无循环要求可能会出现在那里。使用尝试重新表示给定示例设置的设置进行测试,但通过1000x以下方式对其进行扩展:
In [98]: a = np.random.permutation(np.unique(np.random.randint(0,20000,10000)))
In [99]: b = np.random.permutation(np.unique(np.random.randint(0,20000,4000)))
# Solutions from this post
In [100]: %timeit find_indices(a,b,invalid_specifier=np.nan)
...: %timeit find_indices_lookup(a,b,invalid_specifier=np.nan)
1.35 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
220 µs ± 30.9 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
# @Quang Hoang-soln2
In [101]: %%timeit
...: commons, idx_a, idx_b = np.intersect1d(a,b, return_indices=True)
...: orders = np.argsort(idx_b)
...: output = np.full(b.shape, np.nan)
...: output[orders] = idx_a[orders]
1.63 ms ± 59.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# @Quang Hoang-soln1
In [102]: %%timeit
...: s = b == a[:,None]
...: np.where(s.any(0), np.argmax(s,0), np.nan)
137 ms ± 9.25 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Run Code Online (Sandbox Code Playgroud)