小编ahu*_*y_1的帖子

如何将一些 tensorflow 代码转换为 pytorch 版本

我最近开始使用pytorch。我之前一直在使用 tensorflow 框架。我有一段用 tensorflow 实现的代码,现在我想将其转换为 pytorch 版本。

我是pytorch新手,不熟悉它的功能,转换过程也不是很顺利,想咨询一下。

这是我要转换的代码?

def kl_loss_compute(logits1, logits2):
    """ KL loss
    """
    pred1 = tf.nn.softmax(logits1)
    pred2 = tf.nn.softmax(logits2)
    loss = tf.reduce_mean(tf.reduce_sum(pred2 * tf.log(1e-8 + pred2 / (pred1 + 1e-8)), 1))

    return loss
Run Code Online (Sandbox Code Playgroud)

蟒蛇:3.6,Ubuntu:16.04

logits1 和 logits2 是 FC 层的输出。它们的形状是 [batch, n]

python tensorflow pytorch

3
推荐指数
1
解决办法
5684
查看次数

标签 统计

python ×1

pytorch ×1

tensorflow ×1