Cross-entropy function (Python)

Jas*_*y.W 11 python machine-learning neural-network cross-entropy

I am learning neural networks and I want to write a cross_entropy function in Python, where cross entropy is defined as

H(T, P) = -(1/N) ∑_{i=1}^{N} ∑_{j=1}^{k} t_{i,j} · log(p_{i,j})

where N is the number of samples, k is the number of classes, log is the natural logarithm, t_{i,j} is 1 if sample i is in class j and 0 otherwise, and p_{i,j} is the predicted probability that sample i is in class j. To avoid numerical issues with the logarithm, clip the predictions to the range [10^{-12}, 1 - 10^{-12}].

Following the description above, I wrote the code below: it clips the predictions to the range [epsilon, 1 - epsilon] and then computes the cross entropy according to the formula.

import numpy as np

def cross_entropy(predictions, targets, epsilon=1e-12):
    """
    Computes cross entropy between targets (encoded as one-hot vectors)
    and predictions. 
    Input: predictions (N, k) ndarray
           targets (N, k) ndarray        
    Returns: scalar
    """
    predictions = np.clip(predictions, epsilon, 1. - epsilon)
    ce = - np.mean(np.log(predictions) * targets) 
    return ce

The following code checks whether the cross_entropy function is correct.

predictions = np.array([[0.25,0.25,0.25,0.25],
                        [0.01,0.01,0.01,0.96]])
targets = np.array([[0,0,0,1],
                    [0,0,0,1]])
ans = 0.71355817782  #Correct answer
x = cross_entropy(predictions, targets)
print(np.isclose(x,ans))

The code above prints False, i.e. my cross_entropy implementation is incorrect. I then printed the result of cross_entropy(predictions, targets): it gives 0.178389544455 instead of the correct answer ans = 0.71355817782. Could someone help me find what is wrong with my code?

Das*_*enz 19

You are not far off at all, but remember that you are taking the mean of N sums, where N = 2 in this case. So your code could read:

import numpy as np

def cross_entropy(predictions, targets, epsilon=1e-12):
    """
    Computes cross entropy between targets (encoded as one-hot vectors)
    and predictions. 
    Input: predictions (N, k) ndarray
           targets (N, k) ndarray        
    Returns: scalar
    """
    predictions = np.clip(predictions, epsilon, 1. - epsilon)
    N = predictions.shape[0]
    ce = -np.sum(targets*np.log(predictions+1e-9))/N
    return ce

predictions = np.array([[0.25,0.25,0.25,0.25],
                        [0.01,0.01,0.01,0.96]])
targets = np.array([[0,0,0,1],
                   [0,0,0,1]])
ans = 0.71355817782  #Correct answer
x = cross_entropy(predictions, targets)
print(np.isclose(x,ans))

Here I think it is a little clearer if you stick with np.sum(). I also added 1e-9 inside np.log() to avoid the possibility of log(0) in the computation. Hope this helps!
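The discrepancy the answer describes can be checked directly: np.mean divides by the total number of elements (N·k = 8 here), while the formula calls for dividing by the number of samples N = 2, so the buggy and correct results differ by exactly k = 4 (0.178389544455 × 4 = 0.71355817782). A minimal sketch of this check, using the same arrays as above:

```python
import numpy as np

predictions = np.array([[0.25, 0.25, 0.25, 0.25],
                        [0.01, 0.01, 0.01, 0.96]])
targets = np.array([[0, 0, 0, 1],
                    [0, 0, 0, 1]])

terms = targets * np.log(predictions)      # (N, k) array of t_ij * log(p_ij)

mean_ce = -np.mean(terms)                  # divides by N*k = 8 (the buggy version)
sum_ce = -np.sum(terms) / terms.shape[0]   # divides by N = 2 (the formula)

print(mean_ce, sum_ce)  # differ by exactly a factor of k = 4
```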

Note: as @Peter pointed out in the comments, the 1e-9 offset is indeed redundant if your epsilon value is greater than 0.
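As a side note (not part of the original answers): when the targets are available as integer class labels rather than one-hot vectors, the same quantity can be computed without building the one-hot matrix, since only the predicted probability of the true class contributes to each sample's term. A sketch using NumPy fancy indexing (the function name is illustrative):

```python
import numpy as np

def cross_entropy_from_labels(predictions, labels, epsilon=1e-12):
    """Cross entropy where `labels` holds the integer class index per sample.

    Equivalent to the one-hot version above: picks out p_{i, labels[i]}
    for each row i, takes logs, and averages over the N samples.
    """
    predictions = np.clip(predictions, epsilon, 1. - epsilon)
    N = predictions.shape[0]
    return -np.sum(np.log(predictions[np.arange(N), labels])) / N

predictions = np.array([[0.25, 0.25, 0.25, 0.25],
                        [0.01, 0.01, 0.01, 0.96]])
labels = np.array([3, 3])  # same targets as above, written as class indices
print(cross_entropy_from_labels(predictions, labels))  # ≈ 0.71355817782
```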