Keras使用tf.data.Dataset而不是numpy数组预测循环内存泄漏

Dav*_*rks 5 python keras tensorflow

predict当使用a tf.data.Dataset来馈送模型时,在遍历Keras模型函数时会遇到内存泄漏并降低性能,但是在用numpy数组馈入模型时却不会。

有谁了解导致此问题的原因和/或如何解决此问题?

最小的可复制代码段(可复制/粘贴可运行):

import tensorflow as tf
import numpy as np
import time

SIZE = 5000

inp = tf.keras.layers.Input(shape=(SIZE,), dtype='float32')
x = tf.keras.layers.Dense(units=SIZE)(inp)

model = tf.keras.Model(inputs=inp, outputs=x)

np_data = np.random.rand(1, SIZE)
ds = tf.data.Dataset.from_tensor_slices(np_data).batch(1).repeat()

debug_time = time.time()
while True:
    model.predict(x=ds, steps=1)
    print('Processing {:.2f}'.format(time.time() - debug_time))
    debug_time = time.time()
Run Code Online (Sandbox Code Playgroud)

结果:预测循环定时从每次迭代开始约0.04s,在一两分钟之内达到约0.5s,并且过程内存从几百MB继续增加到接近GB。


交换出tf.data.Dataset一个等效的numpy数组,运行时间始终为〜0.01s。

工作案例代码段(可复制/粘贴可运行):

import tensorflow as tf
import numpy as np
import time

SIZE = 5000

inp = tf.keras.layers.Input(shape=(SIZE,), dtype='float32')
x = tf.keras.layers.Dense(units=SIZE)(inp)

model = tf.keras.Model(inputs=inp, outputs=x)

np_data = np.random.rand(1, SIZE)

debug_time = time.time()
while True:
    model.predict(x=np_data)  # using numpy array directly
    print('Processing {:.2f}'.format(time.time() - debug_time))
    debug_time = time.time()
Run Code Online (Sandbox Code Playgroud)

相关讨论:


附加信息:

  • 通过传入迭代器而不是数据集对象,我可以将性能下降的速度降低大约10倍。我在training_utils.py:1314Keras代码中注意到,每个调用都要创建一个迭代器来进行预测。

TF 1.14.0

Dav*_*rks 3

问题的根源似乎是 Keras 在每个predict循环中创建数据集操作。请注意,training_utils.py:1314在每个预测循环中都会创建一个数据集迭代器。

通过传入迭代器可以降低问题的严重性,并且通过传入迭代器get_next()张量可以完全解决问题。

我已在 Tensorflow Github 页面上发布了该问题:https://github.com/tensorflow/tensorflow/issues/30448

这是解决方案,此示例使用 TF 数据集以恒定时间运行,只是不能传入数据集对象:

import tensorflow as tf
import numpy as np
import time

SIZE = 5000

inp = tf.keras.layers.Input(shape=(SIZE,), dtype='float32')
x = tf.keras.layers.Dense(units=SIZE)(inp)

model = tf.keras.Model(inputs=inp, outputs=x)

np_data = np.random.rand(1, SIZE)
ds = tf.data.Dataset.from_tensor_slices(np_data).batch(1).repeat()
it = tf.data.make_one_shot_iterator(ds)
tensor = it.get_next()

debug_time = time.time()
while True:
    model.predict(x=tensor, steps=1)
    print('Processing {:.2f}'.format(time.time() - debug_time))
    debug_time = time.time()
Run Code Online (Sandbox Code Playgroud)