Numpy.where 与值列表一起使用

Ant*_*ier 10 python numpy

我有一个二维和一维数组。我希望找到至少包含一次 1d 数组中的值的两行,如下所示:

import numpy as np

A = np.array([[0, 3, 1],
           [9, 4, 6],
           [2, 7, 3],
           [1, 8, 9],
           [6, 2, 7],
           [4, 8, 0]])

B = np.array([0,1,2,3])

results = []

for elem in B:
    results.append(np.where(A==elem)[0])
Run Code Online (Sandbox Code Playgroud)

这有效并产生以下数组:

[array([0, 5], dtype=int64),
 array([0, 3], dtype=int64),
 array([2, 4], dtype=int64),
 array([0, 2], dtype=int64)]
Run Code Online (Sandbox Code Playgroud)

但这可能不是最好的处理方式。根据这个问题(搜索具有多个值的 Numpy 数组)给出的答案,我尝试了以下解决方案:

out1 = np.where(np.in1d(A, B))

num_arr = np.sort(B)
idx = np.searchsorted(B, A)
idx[idx==len(num_arr)] = 0 
out2 = A[A == num_arr[idx]]
Run Code Online (Sandbox Code Playgroud)

但这些给了我不正确的值:

In [36]: out1
Out[36]: (array([ 0,  1,  2,  6,  8,  9, 13, 17], dtype=int64),)

In [37]: out2
Out[37]: array([0, 3, 1, 2, 3, 1, 2, 0])
Run Code Online (Sandbox Code Playgroud)

感谢您的帮助

den*_*lov 6

如果您需要知道 A 的每一行是否包含数组 B 的任何元素,而不关心它是 B 的哪个特定元素,则可以使用以下脚本:

输入:

np.isin(A,B).sum(axis=1)>0 
Run Code Online (Sandbox Code Playgroud)

输出:

array([ True, False,  True,  True,  True,  True])
Run Code Online (Sandbox Code Playgroud)


Kas*_*mvd 2

由于您正在处理 2D 数组*B您可以使用广播来与 的 raveled 版本进行比较A。这将为您提供散乱形状的相应索引。然后您可以反转结果并使用 获取原始数组中相应的索引np.unravel_index

In [50]: d = np.where(B[:, None] == A.ravel())[1]

In [51]: np.unravel_index(d, A.shape)
Out[51]: (array([0, 5, 0, 3, 2, 4, 0, 2]), array([0, 2, 2, 0, 0, 1, 1, 2]))                 
                       ^
               # expected result 
Run Code Online (Sandbox Code Playgroud)

* 来自文档:对于 3 维数组,这在代码行方面肯定是高效的,并且对于小数据集,它在计算上也可以是高效的。然而,对于大型数据集,创建大型 3 维数组可能会导致性能下降。此外,广播是一种强大的工具,可用于编写简短且通常直观的代码,这些代码在 C 中非常有效地进行计算。但是,在某些情况下,广播会为特定算法使用不必要的大量内存。在这些情况下,最好用 Python 编写算法的外循环。这也可能会产生更可读的代码,因为随着广播中维度数量的增加,使用广播的算法往往会变得更难以解释。