小编Eth*_* Li的帖子

尝试理解 AutoGraph 和 tf.function:tf.function 中的打印丢失

def train_one_step():
    with tf.GradientTape() as tape:
        a = tf.random.normal([1, 3, 1])
        b = tf.random.normal([1, 3, 1])
        loss = mse(a, b)

    tf.print('inner tf print', loss)
    print("inner py print", loss)

    return loss


@tf.function
def train():
    loss = train_one_step()

    tf.print('outer tf print', loss)
    print('outer py print', loss)

    return loss

loss = train()
tf.print('outest tf print', loss)
print("outest py print", loss)
Run Code Online (Sandbox Code Playgroud)

我试图更多地了解 tf.function 。我用不同的方法在四个地方打印了损失。它产生这样的结果

inner py print Tensor("mean_absolute_error/weighted_loss/value:0", shape=(), dtype=float32)
outer py print Tensor("mean_absolute_error/weighted_loss/value:0", shape=(), dtype=float32)
inner tf print 1.82858419
outer tf print 1.82858419 …
Run Code Online (Sandbox Code Playgroud)

tensorflow tensorflow2.0

5
推荐指数
1
解决办法
3251
查看次数

标签 统计

tensorflow ×1

tensorflow2.0 ×1