Che*_*hie 20 python numpy softmax
我尝试使用以下代码实现soft-max(out_vec是numpy浮点数的向量):
numerator = np.exp(out_vec)
denominator = np.sum(np.exp(out_vec))
out_vec = numerator/denominator
Run Code Online (Sandbox Code Playgroud)
但是,因为有一个溢出错误np.exp(out_vec).因此,我检查(手动)上限np.exp()是什么,并发现np.exp(709)是一个数字,但np.exp(710)被认为是np.inf.因此,为了避免溢出错误,我修改了我的代码如下:
out_vec[out_vec > 709] = 709 #prevent np.exp overflow
numerator = np.exp(out_vec)
denominator = np.sum(np.exp(out_vec))
out_vec = numerator/denominator
Run Code Online (Sandbox Code Playgroud)
现在,我得到一个不同的错误:
RuntimeWarning: invalid value encountered in greater out_vec[out_vec > 709] = 709
Run Code Online (Sandbox Code Playgroud)
我添加的产品线有什么问题?我查了一下这个特定的错误,我找到的只是人们对如何忽略错误的建议.简单地忽略错误对我没有帮助,因为每次我的代码遇到这个错误时都不会给出通常的结果.
kvo*_*iev 29
你的问题是由引起的NaN或Inf你的元素out_vec阵列.您可以使用以下代码来避免此问题:
if np.isnan(np.sum(out_vec)):
out_vec = out_vec[~numpy.isnan(out_vec)] # just remove nan elements from vector
out_vec[out_vec > 709] = 709
...
Run Code Online (Sandbox Code Playgroud)
或者您可以使用以下代码将NaN值保留在数组中:
out_vec[ np.array([e > 709 if ~np.isnan(e) else False for e in out_vec], dtype=bool) ] = 709
Run Code Online (Sandbox Code Playgroud)
jue*_*erg 13
在我的情况下,在比较之前调用此警告时没有显示警告(我将NaN值进行比较)
np.warnings.filterwarnings('ignore')
Run Code Online (Sandbox Code Playgroud)
小智 6
IMO更好的方法是使用更加数值稳定的指数和的实现.
from scipy.misc import logsumexp
out_vec = np.exp(out_vec - logsumexp(out_vec))
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
55541 次 |
| 最近记录: |