pytorch代码中的KL散度与公式有何关系?

Sta*_*ham 3 python autoencoder pytorch loss-function

在 VAE 教程中,两个正态分布的 kl 散度定义为: 在此输入图像描述

而在很多代码中,例如hereherehere,代码实现为:

 KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
Run Code Online (Sandbox Code Playgroud)

或者

def latent_loss(z_mean, z_stddev):
    mean_sq = z_mean * z_mean
    stddev_sq = z_stddev * z_stddev
    return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1)
Run Code Online (Sandbox Code Playgroud)

它们有何关系?为什么代码中没有“tr”或“.transpose()”?

jod*_*dag 6

您发布的代码中的表达式假设 X 是不相关的多元高斯随机变量。协方差矩阵的行列式中缺少交叉项,这一点很明显。因此,均值向量和协方差矩阵采用以下形式

在此输入图像描述

使用它,我们可以快速推导出原始表达式的组件的以下等效表示

在此输入图像描述

将这些代回原始表达式给出

在此输入图像描述