如何测试所有行是否在numpy中相等

ele*_*ora 16 python arrays numpy

在numpy中,如果所有行在2d数组中相等,是否有一种很好的惯用方法?

我可以做点什么

np.all([np.array_equal(M[0], M[i]) for i in xrange(1,len(M))])
Run Code Online (Sandbox Code Playgroud)

这似乎将python列表与numpy数组混合,这些数组很难看,也可能很慢.

有更好/更整洁的方式吗?

Ale*_*ley 19

一种方法是检查数组的每一行arr是否等于第一行arr[0]:

(arr == arr[0]).all()
Run Code Online (Sandbox Code Playgroud)

==对于整数值,使用相等是正常的,但是如果arr包含浮点值,则可以使用它np.isclose来检查给定容差内的相等性:

np.isclose(a, a[0]).all()
Run Code Online (Sandbox Code Playgroud)

如果你的数组包含NaN并且你想避免棘手的NaN != NaN问题,你可以将这种方法与np.isnan:

(np.isclose(a, a[0]) | np.isnan(a)).all()
Run Code Online (Sandbox Code Playgroud)

  • 检查相等性,而不是等于0的差异,可能会更快一点. (2认同)
  • 现在有 `np.allclose`。 (2认同)

Ash*_*ary 5

只需检查数组中唯一项的数字是否为1:

>>> arr = np.array([[1]*10 for _ in xrange(5)])
>>> len(np.unique(arr)) == 1
True
Run Code Online (Sandbox Code Playgroud)

从unutbu的答案中获得灵感的解决方案:

>>> arr = np.array([[1]*10 for _ in xrange(5)])
>>> np.all(np.all(arr == arr[0,:], axis = 1))
True
Run Code Online (Sandbox Code Playgroud)

您的代码的一个问题是您在申请之前首先创建了整个列表np.all().由于您的版本中没有发生短路,而不是如果您使用all()带有生成器表达式的Python会更好:

时间比较:

>>> M = arr = np.array([[3]*100] + [[2]*100 for _ in xrange(1000)])
>>> %timeit np.all(np.all(arr == arr[0,:], axis = 1))
1000 loops, best of 3: 272 µs per loop
>>> %timeit (np.diff(M, axis=0) == 0).all()
1000 loops, best of 3: 596 µs per loop
>>> %timeit np.all([np.array_equal(M[0], M[i]) for i in xrange(1,len(M))])
100 loops, best of 3: 10.6 ms per loop
>>> %timeit all(np.array_equal(M[0], M[i]) for i in xrange(1,len(M)))
100000 loops, best of 3: 11.3 µs per loop

>>> M = arr = np.array([[2]*100 for _ in xrange(1000)])
>>> %timeit np.all(np.all(arr == arr[0,:], axis = 1))
1000 loops, best of 3: 330 µs per loop
>>> %timeit (np.diff(M, axis=0) == 0).all()
1000 loops, best of 3: 594 µs per loop
>>> %timeit np.all([np.array_equal(M[0], M[i]) for i in xrange(1,len(M))])
100 loops, best of 3: 9.51 ms per loop
>>> %timeit all(np.array_equal(M[0], M[i]) for i in xrange(1,len(M)))
100 loops, best of 3: 9.44 ms per loop
Run Code Online (Sandbox Code Playgroud)

  • @ user2179021我的系统需要"650μs",所以仍然慢于我的第二个答案. (2认同)

luc*_*yan 5

值得一提的是,上述版本不适用于多维数组。

例如:对于三维方形图像张量img[256, 256, 3] ,我们需要检查图像中是否存在相同的 RGB [256, 256] 层。在这种情况下,我们需要使用广播

(img == img[:, :, 0, np.newaxis]).all()

因为 simpleimg[:, :, 0]给了我们 [256, 256],但是我们需要 [256, 256, 1] 来通过层进行广播。