我正在将代码转换为 Tensorflow v2,但不断收到以下错误:
AssertionError:调用了引用已删除变量的函数。这可能意味着函数局部变量已创建,但未在程序的其他位置引用。这通常是一个错误;考虑在第一次调用时将变量存储在对象属性中。
这是重现错误的最小示例
import tensorflow as tf
class TEST:
def __init__(self, a=1):
self.a = tf.Variable(a)
@tf.function
def increment(self):
self.a = self.a + 1
return self.a
tst = TEST()
tst.increment()
Run Code Online (Sandbox Code Playgroud)
我应该如何解决这个问题?
当你这样做时:
self.a = self.a + 1
Run Code Online (Sandbox Code Playgroud)
您将使用该操作的结果覆盖 中的引用self.a,该引用最初与上面创建的变量相关联。您不会更新 TensorFlow 变量的值,而只是替换 Python 引用。您正在创建的新张量( 的结果self.a + 1)反过来在其计算中使用该变量。问题是,那一刻self.a被覆盖了,变量被遗忘了,不能再使用了。这有点像先有鸡还是先有蛋的问题,但tf.function认为这是无效的。如果您想拥有该变量并为其分配新值,请执行以下操作:
@tf.function
def increment(self):
self.a.assign(self.a + 1)
return self.a
Run Code Online (Sandbox Code Playgroud)
或者只是他的:
@tf.function
def increment(self):
self.a.assign_add(1)
return self.a
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
1171 次 |
| 最近记录: |