Numpy 数组集差异

sla*_*law 6 python arrays numpy

我有两个具有重叠行的 numpy 数组:

import numpy as np

a = np.array([[1,2], [1,5], [3,4], [3,5], [4,1], [4,6]])
b = np.array([[1,5], [3,4], [4,6]])
Run Code Online (Sandbox Code Playgroud)

你可以假设:

  1. 行已排序
  2. 每个数组中的行是唯一的
  3. 数组b始终是数组的子集a

我想得到一个数组,其中包含所有a不在b.

IE,:

[[1 2]
 [3 5]
 [4 1]]
Run Code Online (Sandbox Code Playgroud)

考虑到a并且b可能非常非常大,解决这个问题的最有效方法是什么?

BPL*_*BPL 6

这是您的问题的可能解决方案:

import numpy as np

a = np.array([[1, 2], [3, 4], [3, 5], [4, 1], [4, 6]])
b = np.array([[3, 4], [4, 6]])

a1_rows = a.view([('', a.dtype)] * a.shape[1])
a2_rows = b.view([('', b.dtype)] * b.shape[1])
c = np.setdiff1d(a1_rows, a2_rows).view(a.dtype).reshape(-1, a.shape[1])
print c
Run Code Online (Sandbox Code Playgroud)

我认为在这里使用numpy.setdiff1d是正确的选择