工人在Keras的fit_generator中是什么意思?

W. *_*am 5 python pipeline keras

我有一个大型数据集存储在tfrecord333 之类的文件中用于训练,因此我将数据分成多个文件,例如 1024 tfrecords 文件而不是一个文件。我在 tf.Dataset Api 中使用了输入管道。喜欢:

ds= ds.TFRecordsDataset(files).shuffle().repeat().shuffle().repeat()
ds = ds.prefetch(1)
Run Code Online (Sandbox Code Playgroud)

而且我有自己的生成器,可以生成batch_x, batch_y.

我的问题是代码仅在我设置workers=0in时才有效fit_generator()

每当我将其设置为大于 0 时,都会出现以下错误

ValueError: Tensor("PrefetchDataset:0", shape=(), dtype=variant) 必须与 Tensor("Iterator:0", shape=(), dtype=resource) 来自同一图。

workers =他们说,如果0 还不够,则有关这意味着什么的文档

如果为 0,将在主线程上执行生成器。

我在 github here 中发现了类似的问题,但还没有解决方案。

这里发布另一种类似的问题,但我不同,因为我使用的是 Keras 而不是张量流,而且我没有使用 with tf.Graph().as_default(). 建议有两个图而不是一个图,因此解决方案是删除tf.Graph().as_default(). 当我检查图形时,我注意到与我的输入管道相关的所有映射函数在不同的图形(子图形)中,并且它不能附加到主图形中。像下面这样:

在此处输入图片说明

我应该提到,这是一个两阶段的培训。首先,我使用基于图像的数据集构建了一个网络,并且该网络在imagenet上进行了预训练,我刚刚训练了我的分类器。数据集在hdf5文件中,可以放入内存中。在第二阶段,我在第一阶段使用经过训练的网络并将一些块附加到它上面,我的数据集这里是tfrecod文件,这就是为什么我使用 atf.Dataset API作为我的输入管道的原因。所以这个新的输入管道不存在于第一阶段的第一个图中。但没关系,我只是将预处理过的网络用作基本模式,然后向其添加不同的块。所以它是全新的模型。

而我想换工人的主要原因,因为我的 GPU 利用率始终为零,这意味着 CPU 是瓶颈,这意味着 Cpu 需要花费大量时间来提取数据。我的 GPU 总是在等待。这就是为什么训练需要很长时间,比如一个 epoch 9 小时。

任何人都可以解释错误的含义吗?