我应该直接返回数据集还是应该使用 one_shot 迭代器?

Val*_*Val 5 python iterator pipeline tensorflow tensorflow-datasets

我正在使用 Dataset API 构建数据管道,但是当我训练多个 GPU 并返回dataset.make_one_shot_iterator().get_next()我的输入函数时,我得到

ValueError: dataset_fn() must return a tf.data.Dataset when using a tf.distribute.Strategy
Run Code Online (Sandbox Code Playgroud)

我可以按照错误消息直接返回数据集,但我不明白iterator().get_next()它在单 GPU 和多 GPU 上训练的目的和工作原理。

...

    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size = batch_size)
    dataset = dataset.cache()

    dataset = dataset.prefetch(buffer_size=None)

    return dataset.make_one_shot_iterator().get_next()

return _input_fn
Run Code Online (Sandbox Code Playgroud)

rac*_*lim 3

tf.data与分发策略一起使用时(可以与 Keras 和tf.Estimators 一起使用),您的输入 fn 应返回tf.data.Dataset

def input_fn():
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size = batch_size)
  dataset = dataset.cache()

  dataset = dataset.prefetch(buffer_size=None)
  return dataset

...use input_fn...
Run Code Online (Sandbox Code Playgroud)

请参阅有关分发策略的文档

dataset.make_one_shot_iterator()在分发策略/更高级别的库之外非常有用,例如,如果您正在使用较低级别的库,或者调试/测试数据集。例如,您可以像这样迭代数据集的所有元素:

dataset = ...
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
with tf.Session() as sess:
  while True:
    print(sess.run(get_next))
  except tf.errors.OutOfRangeError:
    break
Run Code Online (Sandbox Code Playgroud)