测试numpy数组是否只包含零

IUn*_*own 78 python numpy

我们用零填充numpy数组如下:

np.zeros((N,N+1))
Run Code Online (Sandbox Code Playgroud)

但是我们如何检查给定n*n numpy数组矩阵中的所有元素是否为零.
如果所有值都为零,则该方法只需返回True.

Stu*_*erg 136

这里发布的其他答案将有效,但使用的最清晰,最有效的功能是numpy.any():

>>> all_zeros = not np.any(a)
Run Code Online (Sandbox Code Playgroud)

要么

>>> all_zeros = not a.any()
Run Code Online (Sandbox Code Playgroud)
  • 这是首选,numpy.all(a==0)因为它使用较少的RAM.(它不需要a==0术语创建的临时数组.)
  • 此外,它比numpy.count_nonzero(a)在找到第一个非零元素时立即返回更快.
    • 编辑:正如@Rachel在评论中指出的那样,np.any()不再使用"短路"逻辑,因此您不会看到小型阵列的速度优势.

  • 一分钟前,numpy 的 `any` 和 `all` **不会**短路。我相信它们是`logical_or.reduce` 和`logical_and.reduce` 的糖。相互比较和我的短路 `is_in`:`all_false = np.zeros(10**8)` `all_true = np.ones(10**8)` `%timeit np.any(all_false) 91.5 ms ± 1.82 ms 每个循环``%timeit np.any(all_true) 93.7 ms ± 6.16 ms 每个循环``%timeit is_in(1, all_true) 293 ns ± 1.65 ns 每个循环` (4认同)
  • 这是一个很好的观点,谢谢。看起来短路 *used* 是行为,但在某些时候丢失了。[这个问题](/sf/ask/3204008811/)的答案中有一些有趣的讨论。 (4认同)

Pra*_*mar 62

看看numpy.count_nonzero.

>>> np.count_nonzero(np.eye(4))
4
>>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]])
5
Run Code Online (Sandbox Code Playgroud)

  • 只有当所有值都为0时,你才希望`np.count_nonzero(np.eye(4))`返回'True`. (6认同)

J. *_*rde 44

我在这里使用np.all,如果你有一个数组a:

>>> np.all(a==0)
Run Code Online (Sandbox Code Playgroud)

  • 我喜欢这个答案也检查非零值.例如,可以通过执行`np.all(a == a [0])`来检查数组中的所有元素是否相同.非常感谢! (3认同)
  • 该解决方案也比“np.count_nonzero”更高效。%timeit num_of_non_zeros = np.count_nonzero(zeros_vector) 每个循环 18.2 µs ± 386 ns(平均值 ± 标准偏差,7 次运行,每次 100000 次循环) %timeit num_of_non_zeros = np.all((zeros_vector == 0)) 7.31 µs ±每个循环 41.6 ns(7 次运行的平均值 ± 标准差,每次 100000 个循环) (2认同)

Rac*_*hel 10

正如另一个答案所说,如果您知道这0是数组中唯一可能的虚假元素,则可以利用真/假评估。如果数组中没有任何真元素,则数组中的所有元素都是假的。*

>>> a = np.zeros(10)
>>> not np.any(a)
True
Run Code Online (Sandbox Code Playgroud)

然而,答案声称这any比其他选项更快,部分原因是短路。截至 2018 年,Numpy'sallany 没有短路.

如果你经常做这种事情,很容易使用以下方法制作自己的短路版本numba

import numba as nb

# short-circuiting replacement for np.any()
@nb.jit(nopython=True)
def sc_any(array):
    for x in array.flat:
        if x:
            return True
    return False

# short-circuiting replacement for np.all()
@nb.jit(nopython=True)
def sc_all(array):
    for x in array.flat:
        if not x:
            return False
    return True
Run Code Online (Sandbox Code Playgroud)

即使没有短路,这些也往往比 Numpy 的版本快。count_nonzero是最慢的。

检查性能的一些输入:

import numpy as np

n = 10**8
middle = n//2
all_0 = np.zeros(n, dtype=int)
all_1 = np.ones(n, dtype=int)
mid_0 = np.ones(n, dtype=int)
mid_1 = np.zeros(n, dtype=int)
np.put(mid_0, middle, 0)
np.put(mid_1, middle, 1)
# mid_0 = [1 1 1 ... 1 0 1 ... 1 1 1]
# mid_1 = [0 0 0 ... 0 1 0 ... 0 0 0]
Run Code Online (Sandbox Code Playgroud)

查看:

## count_nonzero
%timeit np.count_nonzero(all_0) 
# 220 ms ± 8.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.count_nonzero(all_1)
# 150 ms ± 4.56 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

### all
# np.all
%timeit np.all(all_1)
%timeit np.all(mid_0)
%timeit np.all(all_0)
# 56.8 ms ± 3.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.4 ms ± 1.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 55.9 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_all
%timeit sc_all(all_1)
%timeit sc_all(mid_0)
%timeit sc_all(all_0)
# 44.4 ms ± 2.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.7 ms ± 599 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 288 ns ± 6.36 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

### any
# np.any
%timeit np.any(all_0)
%timeit np.any(mid_1)
%timeit np.any(all_1)
# 60.7 ms ± 1.38 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 60 ms ± 287 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.7 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_any
%timeit sc_any(all_0)
%timeit sc_any(mid_1)
%timeit sc_any(all_1)
# 41.7 ms ± 1.24 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.4 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 287 ns ± 12.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
Run Code Online (Sandbox Code Playgroud)

* 有用的allany等效的:

np.all(a) == np.logical_not(np.any(np.logical_not(a)))
np.any(a) == np.logical_not(np.all(np.logical_not(a)))
not np.all(a) == np.any(np.logical_not(a))
not np.any(a) == np.all(np.logical_not(a))
Run Code Online (Sandbox Code Playgroud)