使用tf.function的Tensorflow 2.0模型非常慢,并且每次火车数量变化时都会重新编译。渴望的速度快大约4倍

mat*_*ick 9 keras tensorflow2.0

我有从未编译的keras代码构建的模型,并且正在尝试通过自定义训练循环运行它们。

TF 2.0急切(默认)代码在CPU(笔记本电脑)上运行约30秒钟。当我用包装的tf.function调用方法创建一个keras模型时,它运行的速度非常慢,而且启动时间似乎很长,尤其是“第一次”。

例如,在tf.function代码中,对10个样本的初始训练花费40s,而对10个样本的后续训练花费2s。

在20个样本上,初始花费50s,后续花费4s。

第一次采样1个样本需要2秒钟,后续过程需要200毫秒。

如此看来,每次火车呼叫都在创建一个新图,其中复杂度随火车数量而增加!

我只是在做这样的事情:

@tf.function
def train(n=10):
    step = 0
    loss = 0.0
    accuracy = 0.0
    for i in range(n):
        step += 1
        d, dd, l = train_one_step(model, opt, data)
        tf.print(dd)
        with tf.name_scope('train'):
            for k in dd:
                tf.summary.scalar(k, dd[k], step=step)
        if tf.equal(step % 10, 0):
            tf.print(dd)
    d.update(dd)
    return d
Run Code Online (Sandbox Code Playgroud)

根据示例,模型keras.model.Model使用@tf.function装饰call方法。

nes*_*uno 12

我分析了这种行为@tf.function在这里使用Python的本机类型

简而言之:的设计tf.function不会自动将Python本机类​​型装箱到tf.Tensor定义良好的对象dtype

如果您的函数接受一个tf.Tensor对象,则在第一次调用时将对该函数进行分析,然后将图形建立并与该函数关联。在每一个非第一次呼叫中,如果dtype所述的tf.Tensor对象匹配,图形被再利用。

但是在使用Python本机类​​型的情况下,每次以不同的值调用函数时都会构建graphg 。

简而言之:tf.Tensor如果打算使用,请设计代码以在所有地方使用,而不是在Python变量中使用@tf.function

tf.function不是可以神奇地加速在急切模式下运行良好的功能的包装器;是一个包装,需要设计eager函数(主体,输入参数,dytpes),以了解创建图形后将发生的情况,从而获得真正的提速效果。

  • 传递模型(即 keras 对象)、tf.data.dataset 或任何 tf.* 对象根本不是问题。仅当您传递 Python 本机类型时,性能才会下降 (3认同)
  • 这很棒...我猜应该是对文档的一个大警告。如果有一个,我肯定会错过的。 (2认同)
  • 谢谢,@nessuno。我知道您指的是本机*数字*类型,但我还想补充一点,即使列表也被视为Python中的本机类型,张量列表也可以正常工作。 (2认同)