考虑一下阵列 a
a = np.array([3, 3, np.nan, 3, 3, np.nan])
Run Code Online (Sandbox Code Playgroud)
我可以
np.isnan(a).argmax()
Run Code Online (Sandbox Code Playgroud)
但这需要找到所有np.nan只是为了找到第一个.
有更有效的方法吗?
我一直在试图弄清楚我是否可以传递一个参数,np.argpartition使np.nanget首先排序而不是last.
关于[dup]的编辑.
这个问题有几个不同的原因.
isnan.关于第二次[dup]的编辑.
解决平等和问题/答案仍然很老,很可能已经过时了.
fug*_*ede 11
也许值得研究numba.jit; 没有它,矢量化版本可能会在大多数情况下击败直接的纯Python搜索,但在编译代码后,普通搜索将起带头作用,至少在我的测试中:
In [63]: a = np.array([np.nan if i % 10000 == 9999 else 3 for i in range(100000)])
In [70]: %paste
import numba
def naive(a):
for i in range(len(a)):
if np.isnan(a[i]):
return i
def short(a):
return np.isnan(a).argmax()
@numba.jit
def naive_jit(a):
for i in range(len(a)):
if np.isnan(a[i]):
return i
@numba.jit
def short_jit(a):
return np.isnan(a).argmax()
## -- End pasted text --
In [71]: %timeit naive(a)
100 loops, best of 3: 7.22 ms per loop
In [72]: %timeit short(a)
The slowest run took 4.59 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 37.7 µs per loop
In [73]: %timeit naive_jit(a)
The slowest run took 6821.16 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 6.79 µs per loop
In [74]: %timeit short_jit(a)
The slowest run took 395.51 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 144 µs per loop
Run Code Online (Sandbox Code Playgroud)
编辑:正如@hpaulj在他们的回答中指出的那样,numpy实际上发布了一个优化的短路搜索,其性能与上面的JITted搜索相当:
In [26]: %paste
def plain(a):
return a.argmax()
@numba.jit
def plain_jit(a):
return a.argmax()
## -- End pasted text --
In [35]: %timeit naive(a)
100 loops, best of 3: 7.13 ms per loop
In [36]: %timeit plain(a)
The slowest run took 4.37 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 7.04 µs per loop
In [37]: %timeit naive_jit(a)
100000 loops, best of 3: 6.91 µs per loop
In [38]: %timeit plain_jit(a)
10000 loops, best of 3: 125 µs per loop
Run Code Online (Sandbox Code Playgroud)
我会提名
a.argmax()
Run Code Online (Sandbox Code Playgroud)
使用@fuglede's测试数组:
In [1]: a = np.array([np.nan if i % 10000 == 9999 else 3 for i in range(100000)])
In [2]: np.isnan(a).argmax()
Out[2]: 9999
In [3]: np.argmax(a)
Out[3]: 9999
In [4]: a.argmax()
Out[4]: 9999
In [5]: timeit a.argmax()
The slowest run took 29.94 ....
10000 loops, best of 3: 20.3 µs per loop
In [6]: timeit np.isnan(a).argmax()
The slowest run took 7.82 ...
1000 loops, best of 3: 462 µs per loop
Run Code Online (Sandbox Code Playgroud)
我没有numba安装,所以可以比较一下.但是我的加速比short大于@fuglede's6倍.
我在Py3中测试,它接受<np.nan,而Py2引发运行时警告.但代码搜索表明这不依赖于这种比较.
/numpy/core/src/multiarray/calculation.c PyArray_ArgMax玩轴(将感兴趣的东西移动到最后),并将动作委托给arg_func = PyArray_DESCR(ap)->f->argmax依赖于dtype的函数.
在numpy/core/src/multiarray/arraytypes.c.src它看起来像BOOL_argmax短路,一旦遇到一个回来True.
for (; i < n; i++) {
if (ip[i]) {
*max_ind = i;
return 0;
}
}
Run Code Online (Sandbox Code Playgroud)
并且@fname@_argmax最大限度地短路nan.np.nan也是'最大' argmin.
#if @isfloat@
if (@isnan@(mp)) {
/* nan encountered; it's maximal */
return 0;
}
#endif
Run Code Online (Sandbox Code Playgroud)
有经验的c编码员的评论是受欢迎的,但在我看来,至少对于np.nan一个平原来说,argmax我们能得到的速度一样快.
使用9999生成a显示a.argmax时间取决于该值,与短路一致.
这是一个使用以下方法的pythonic方法itertools.takewhile():
from itertools import takewhile
sum(1 for _ in takewhile(np.isfinite, a))
Run Code Online (Sandbox Code Playgroud)
使用generator_expression_within_ next方法进行基准测试:1
In [118]: a = np.repeat(a, 10000)
In [120]: %timeit next(i for i, j in enumerate(a) if np.isnan(j))
100 loops, best of 3: 12.4 ms per loop
In [121]: %timeit sum(1 for _ in takewhile(np.isfinite, a))
100 loops, best of 3: 11.5 ms per loop
Run Code Online (Sandbox Code Playgroud)
但仍然(到目前为止)比numpy方法慢:
In [119]: %timeit np.isnan(a).argmax()
100000 loops, best of 3: 16.8 µs per loop
Run Code Online (Sandbox Code Playgroud)
这种方法的问题在于使用 enumerate功能.它enumerate首先从numpy数组返回一个对象(这是一个像对象一样的迭代器),并且调用迭代器的生成器函数和next属性需要时间.
| 归档时间: |
|
| 查看次数: |
915 次 |
| 最近记录: |