这个问题与很多与警告有关的问题非常相似 RuntimeWarning: invalid value encountered in greater/less/etc
但是,我找不到解决我特定问题的方法,我认为应该有一个.
所以,我有一个numpy.ndarray类似于这个:
array([[ nan, 1., nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
...,
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan]])
Run Code Online (Sandbox Code Playgroud)
我想计算array > 0.5,它给出了我想要的结果,但是有与之比较的警告nan:
__main__:1: RuntimeWarning: invalid value encountered in greater
Out[68]:
array([[False, True, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]], dtype=bool)
Run Code Online (Sandbox Code Playgroud)
我基本上想要计算array > 0.5,但没有出现警告.
我的限制:
with np.errstate(invalid='ignore'):我想出了一个简单的解决方案:
nan原始矩阵(array[np.isnan(array)] = -np.inf),在我进行比较后恢复它(array[array == -np.inf] = np.nan)但是我觉得这些计算只是浪费时间(我认为)它应该存在直接的方法来实现这一点.我一直在探索numpy.ma模块和numpy.where功能,但我找不到我想要的这种"直接"解决方案.
有什么想法吗?
只要比较包含至少一个NaN的数组,就会发出警告.解决方案是用于masking仅比较非NaN元素,我们将尝试使用通用解决方案来覆盖所有类型的比较comparison based NumPy ufuncs,如下所示 -
def compare_nan_array(func, a, thresh):
out = ~np.isnan(a)
out[out] = func(a[out] , thresh)
return out
Run Code Online (Sandbox Code Playgroud)
这个想法是:
获取非NaN的掩码.
用它来从输入数组中获取非NaN值.然后执行所需的比较(大于,大于等等)以获得另一个掩码,该掩码表示掩蔽位置的比较掩码输出.
使用它来细化非NaN的掩码,这是最终输出.
样品运行 -
In [41]: np.random.seed(0)
In [42]: a = np.random.randint(0,9,(4,5)).astype(float)
In [43]: a.ravel()[np.random.choice(a.size, 16, replace=0)] = np.nan
In [44]: a
Out[44]:
array([[ nan, nan, nan, nan, nan],
[ nan, nan, nan, 4., 7.],
[ nan, nan, nan, 1., nan],
[ nan, 7., nan, nan, nan]])
In [45]: a > 5 # Shows warning with the usual comparison
__main__:1: RuntimeWarning: invalid value encountered in greater
Out[45]:
array([[False, False, False, False, False],
[False, False, False, False, True],
[False, False, False, False, False],
[False, True, False, False, False]], dtype=bool)
# With suggested masking based method
In [46]: compare_nan_array(np.greater, a, 5)
Out[46]:
array([[False, False, False, False, False],
[False, False, False, False, True],
[False, False, False, False, False],
[False, True, False, False, False]], dtype=bool)
Run Code Online (Sandbox Code Playgroud)
让我们通过测试来测试通用行为lesser than 5-
In [47]: a < 5
__main__:1: RuntimeWarning: invalid value encountered in less
Out[47]:
array([[False, False, False, False, False],
[False, False, False, True, False],
[False, False, False, True, False],
[False, False, False, False, False]], dtype=bool)
In [48]: compare_nan_array(np.less, a, 5)
Out[48]:
array([[False, False, False, False, False],
[False, False, False, True, False],
[False, False, False, True, False],
[False, False, False, False, False]], dtype=bool)
Run Code Online (Sandbox Code Playgroud)
有一个更好的方法 - 您不想永远抑制警告,因为它可以帮助您稍后发现其他错误。
遵循在这个问题中找到的建议:RuntimeWarning:individual value seen individe
如果结果是你想要的,你可以写:
with np.errstate(invalid='ignore'):
result = (array > 0.5)
# ... use result, and your warnings are not suppressed.
Run Code Online (Sandbox Code Playgroud)
否则,您可以通过复制数组来满足您的限制:
to_compare = array.copy()
to_compare[np.isnan(to_compare)] = 0.5 # you don't need -np.inf, anything <= 0.5 is OK
result = (to_compare > 0.5)
Run Code Online (Sandbox Code Playgroud)
而且您不需要“恢复”阵列中的 NaN。
| 归档时间: |
|
| 查看次数: |
3482 次 |
| 最近记录: |