Keras 模型中的“predict_step”禁用了急切执行

Ram*_*mon 1 machine-learning keras tensorflow tf.keras tensorflow2.0

predict_step为什么张量流在a 函数内禁用急切执行tf.keras.Model?也许我弄错了,但这里有一个例子:

from __future__ import annotations
from functools import wraps
import tensorflow as tf

def print_execution(func):
    @wraps(func)
    def wrapper(self: SimpleModel, data):
        print(tf.executing_eagerly())  # Prints False
        return func(self, data)
    return wrapper

class SimpleModel(tf.keras.Model):
    def __init__(self):
        super().__init__()

    def call(self, inputs, training=None, mask=None):
        return inputs

    @print_execution
    def predict_step(self, data):
        return super().predict_step(data)

if __name__ == "__main__":
    x = tf.random.uniform((2, 2))
    print(tf.executing_eagerly())  # Prints True
    model = SimpleModel()
    pred = model.predict(x)
Run Code Online (Sandbox Code Playgroud)

这是预期的行为吗?有没有办法强制predict_step以急切模式运行?

M.I*_*nat 5

如果您想predict_step以 eager 模式运行该函数,可以按如下方式执行。请注意,它将把所有内容设置为 eager 模式。

import tensorflow as tf
tf.config.run_functions_eagerly(True)
Run Code Online (Sandbox Code Playgroud)

通常tf.function处于Graph模式。使用上面的语句,它们也可以设置为Eager模式src

根据您的评论,据我所知,如果您run_eagerly在编译模型时设置,应该不会有任何差异。这是来自官方声明src - model.compile

run_eagerly:布尔。默认为 False。如果为 True,则该模型的逻辑将不会包装在tf. 功能。建议将其保留为“无”,除非您的模型无法在 tf . 功能


关于您的第一个查询,为什么要在a 的函数TensorFlow内禁用急切执行?predict_steptf.keras.Model

主要原因之一是提供模型的最佳性能。而且它不仅与predict_step,而且还train_steptest_step。基本上tf. keras模型被编译成静态图。为了使它们以 Eager 模式运行,需要执行上述方法。但请注意,在这种情况下使用 eager 模式可能会减慢您的训练速度。对于良好的集体,tf. keras模型以图形模式编译。