What is the most efficient way to find the first np.nan value?

piR*_*red 13 python numpy

Consider the array a

import numpy as np

a = np.array([3, 3, np.nan, 3, 3, np.nan])

I could do

np.isnan(a).argmax()

But this requires finding every np.nan just to find the first one.
Is there a more efficient way?
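
For reference, the baseline above can be wrapped as a small helper. This is only a sketch: the name first_nan_index and the -1 "not found" convention are assumptions made for illustration, not part of the question. It also makes one pitfall explicit: argmax on an all-False mask returns 0, so the result has to be checked against the mask.

```python
import numpy as np

def first_nan_index(a):
    """Return the index of the first NaN in a 1-D float array, or -1 if none.

    Hypothetical helper: the name and the -1 convention are illustrative.
    """
    if a.size == 0:
        return -1                # argmax raises ValueError on an empty array
    mask = np.isnan(a)           # full pass over the array, builds a bool mask
    idx = int(mask.argmax())     # index of the first True (0 if all are False)
    return idx if mask[idx] else -1

a = np.array([3, 3, np.nan, 3, 3, np.nan])
print(first_nan_index(a))                     # 2
print(first_nan_index(np.array([1.0, 2.0])))  # -1
```

The inefficiency the question is about is the first step: np.isnan(a) always visits every element before argmax gets a chance to stop early.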


I've been trying to figure out whether I can pass a parameter to np.argpartition so that np.nan gets sorted first rather than last.


EDIT regarding [dup].
This question is different for several reasons.

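  1. That question and its answers deal with equality of values. This one is about isnan.
  2. Those answers all suffer from the same problem my own approach faces. Note that I provided a perfectly valid answer but highlighted its inefficiency. I am looking to fix that inefficiency.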

EDIT regarding the second [dup].

It still addresses equality, and the question and its answers are old and very likely outdated.

fug*_*ede 11

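It might also be worth looking into numba.jit. Without it, the vectorized version will likely beat a straightforward pure-Python search in most scenarios, but once the code is compiled, the plain search takes the lead, at least in my tests: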

In [63]: a = np.array([np.nan if i % 10000 == 9999 else 3 for i in range(100000)])

In [70]: %paste
import numba

def naive(a):
    for i in range(len(a)):
        if np.isnan(a[i]):
            return i

def short(a):
    return np.isnan(a).argmax()

@numba.jit
def naive_jit(a):
    for i in range(len(a)):
        if np.isnan(a[i]):
            return i

@numba.jit
def short_jit(a):
    return np.isnan(a).argmax()
## -- End pasted text --

In [71]: %timeit naive(a)
100 loops, best of 3: 7.22 ms per loop

In [72]: %timeit short(a)
The slowest run took 4.59 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 37.7 µs per loop

In [73]: %timeit naive_jit(a)
The slowest run took 6821.16 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 6.79 µs per loop

In [74]: %timeit short_jit(a)
The slowest run took 395.51 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 144 µs per loop

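Edit: As @hpaulj points out in their answer, numpy actually ships with an optimized short-circuiting search whose performance is comparable to the JITted search above: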

In [26]: %paste
def plain(a):
    return a.argmax()

@numba.jit
def plain_jit(a):
    return a.argmax()
## -- End pasted text --

In [35]: %timeit naive(a)
100 loops, best of 3: 7.13 ms per loop

In [36]: %timeit plain(a)
The slowest run took 4.37 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 7.04 µs per loop

In [37]: %timeit naive_jit(a)
100000 loops, best of 3: 6.91 µs per loop

In [38]: %timeit plain_jit(a)
10000 loops, best of 3: 125 µs per loop
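
If numba is not available, one way to get part of the short-circuit benefit in plain NumPy is to scan the array in blocks and stop at the first block that contains a NaN. The following is a minimal sketch of that idea, not something benchmarked in this answer; the function name first_nan_chunked and the chunk parameter (with its default of 4096) are assumptions chosen for illustration.

```python
import numpy as np

def first_nan_chunked(a, chunk=4096):
    """Return the index of the first NaN in a 1-D float array, or -1 if none.

    Hypothetical helper: scans block by block, so blocks after the first
    NaN are never touched. `chunk` is an illustrative tuning parameter.
    """
    for start in range(0, len(a), chunk):
        mask = np.isnan(a[start:start + chunk])   # vectorized test on one block
        if mask.any():
            return start + int(mask.argmax())     # offset back into the full array
    return -1

a = np.array([np.nan if i % 10000 == 9999 else 3 for i in range(100000)])
print(first_nan_chunked(a))   # 9999
```

The trade-off is between Python-level loop overhead (small blocks) and wasted work past the first NaN (large blocks), so a suitable block size depends on the data.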


hpa*_*ulj 8

I'll nominate

a.argmax()

Using @fuglede's test array:

In [1]: a = np.array([np.nan if i % 10000 == 9999 else 3 for i in range(100000)])
In [2]: np.isnan(a).argmax()
Out[2]: 9999
In [3]: np.argmax(a)
Out[3]: 9999
In [4]: a.argmax()
Out[4]: 9999

In [5]: timeit a.argmax()
The slowest run took 29.94 ....
10000 loops, best of 3: 20.3 µs per loop

In [6]: timeit np.isnan(a).argmax()
The slowest run took 7.82 ...
1000 loops, best of 3: 462 µs per loop

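I don't have numba installed, so I can't compare against that. But my speedup relative to short (about 23x: 462 µs vs 20.3 µs) is greater than @fuglede's 6x.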

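I'm testing in Py3, which accepts < np.nan, while Py2 raises a runtime warning. But a search through the code suggests this does not depend on that comparison.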

In /numpy/core/src/multiarray/calculation.c, PyArray_ArgMax plays with the axes (moving the one of interest to the end) and delegates the work to arg_func = PyArray_DESCR(ap)->f->argmax, a function that depends on the dtype.

In numpy/core/src/multiarray/arraytypes.c.src, it looks like BOOL_argmax short-circuits, returning as soon as it encounters a True.

for (; i < n; i++) {
    if (ip[i]) {
        *max_ind = i;
        return 0;
    }
}

And @fname@_argmax short-circuits on a maximal nan. np.nan is treated as 'maximal' in argmin as well.

#if @isfloat@
    if (@isnan@(mp)) {
        /* nan encountered; it's maximal */
        return 0;
    }
#endif

Comments from experienced C coders are welcome, but it appears to me that, at least for np.nan, a plain argmax is as fast as we can get.

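Playing with the 9999 used to generate a shows that the a.argmax time depends on that value, which is consistent with short-circuiting.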
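
One caveat when relying on this: if a float array contains no NaN at all, a.argmax() simply returns the position of the largest value, so the result should be checked. The sketch below illustrates that check; the helper name first_nan_argmax and the -1 "not found" convention are assumptions added for illustration, and the approach only makes sense for floating-point dtypes.

```python
import numpy as np

def first_nan_argmax(a):
    """Return the index of the first NaN in a 1-D float array, or -1 if none.

    Hypothetical helper: relies on argmax treating NaN as maximal and
    reporting the first NaN it meets.
    """
    if a.size == 0:
        return -1                  # argmax raises ValueError on an empty array
    idx = int(a.argmax())          # first NaN if any, else index of the maximum
    return idx if np.isnan(a[idx]) else -1

a = np.array([np.nan if i % 10000 == 9999 else 3 for i in range(100000)])
print(a.argmax(), a.argmin())                        # 9999 9999
print(first_nan_argmax(a))                           # 9999
print(first_nan_argmax(np.array([1.0, 5.0, 2.0])))   # -1
```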


Kas*_*mvd 6

Here is a pythonic approach using itertools.takewhile():

from itertools import takewhile
sum(1 for _ in takewhile(np.isfinite, a))

Benchmark against the generator-expression-within-next approach:

In [118]: a = np.repeat(a, 10000)

In [120]: %timeit next(i for i, j in enumerate(a) if np.isnan(j))
100 loops, best of 3: 12.4 ms per loop

In [121]: %timeit sum(1 for _ in takewhile(np.isfinite, a))
100 loops, best of 3: 11.5 ms per loop
Run Code Online (Sandbox Code Playgroud)

But both are still (by far) slower than the numpy approach:

In [119]: %timeit np.isnan(a).argmax()
100000 loops, best of 3: 16.8 µs per loop

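The problem with the generator-expression approach is its use of the enumerate function. It first returns an enumerate object from the numpy array (an iterator-like object), and calling the generator function and the iterator's next attribute takes time.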
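
Two details of the pure-Python variants are worth spelling out. np.isfinite is also False for inf, so the takewhile one-liner above stops at the first inf as well as the first nan; and next raises StopIteration when there is no match unless it is given a default. A small sketch of both points follows, where the sample array and the -1 default are assumptions chosen for illustration:

```python
from itertools import takewhile

import numpy as np

# Illustrative array: an inf appears before the first NaN.
a = np.array([3, np.inf, np.nan, 3])

# Stops at the inf, because np.isfinite(np.inf) is False.
count_isfinite = sum(1 for _ in takewhile(np.isfinite, a))

# Stops only at a NaN.
count_notnan = sum(1 for _ in takewhile(lambda x: not np.isnan(x), a))

# Generator inside next, with a default instead of StopIteration.
first = next((i for i, x in enumerate(a) if np.isnan(x)), -1)

print(count_isfinite, count_notnan, first)   # 1 2 2
```

Note that the takewhile count equals len(a) when the array holds no NaN, so that case needs its own check.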

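  • This array is too small for these methods to beat numpy's functions. Maybe try on a larger array? (3 upvotes)
  • @ayhan Actually it was a repeated array. I just forgot to add the relevant command. Here is the new one. (2 upvotes)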