我最近开始使用pytorch。我之前一直在使用 tensorflow 框架。我有一段用 tensorflow 实现的代码,现在我想将其转换为 pytorch 版本。
我是pytorch新手,不熟悉它的功能,转换过程也不是很顺利,想咨询一下。
这是我要转换的代码?
def kl_loss_compute(logits1, logits2):
""" KL loss
"""
pred1 = tf.nn.softmax(logits1)
pred2 = tf.nn.softmax(logits2)
loss = tf.reduce_mean(tf.reduce_sum(pred2 * tf.log(1e-8 + pred2 / (pred1 + 1e-8)), 1))
return loss
Run Code Online (Sandbox Code Playgroud)
logits1 和 logits2 是 FC 层的输出。它们的形状是 [batch, n]