张量流v2中的变量赋值

use*_*953 3 python tensorflow

我正在将代码转换为 Tensorflow v2,但不断收到以下错误:

AssertionError:调用了引用已删除变量的函数。这可能意味着函数局部变量已创建,但未在程序的其他位置引用。这通常是一个错误;考虑在第一次调用时将变量存储在对象属性中。

这是重现错误的最小示例

import tensorflow as tf

class TEST:
    def __init__(self, a=1):
        self.a = tf.Variable(a)

    @tf.function
    def increment(self):
        self.a = self.a + 1
        return self.a

tst = TEST()
tst.increment()
Run Code Online (Sandbox Code Playgroud)

我应该如何解决这个问题?

jde*_*esa 5

当你这样做时:

self.a = self.a + 1
Run Code Online (Sandbox Code Playgroud)

您将使用该操作的结果覆盖 中的引用self.a,该引用最初与上面创建的变量相关联。您不会更新 TensorFlow 变量的值,而只是替换 Python 引用。您正在创建的新张量( 的结果self.a + 1)反过来在其计算中使用该变量。问题是,那一刻self.a被覆盖了,变量被遗忘了,不能再使用了。这有点像先有鸡还是先有蛋的问题,但tf.function认为这是无效的。如果您想拥有该变量并为其分配新值,请执行以下操作:

@tf.function
def increment(self):
    self.a.assign(self.a + 1)
    return self.a
Run Code Online (Sandbox Code Playgroud)

或者只是他的:

@tf.function
def increment(self):
    self.a.assign_add(1)
    return self.a
Run Code Online (Sandbox Code Playgroud)