Tensorflow 数据集 API - .from_tensor_slices() / .from_tensor() - 无法创建内容大于 2gb 的张量原型

Sol*_*MUC 5 python pipeline python-3.x tensorflow tensorflow-datasets

所以我想使用 Dataset API 来批处理我的大数据集(~8GB),因为我在使用我的 GPU 时遇到了大量空闲时间,因为我使用 feed_dict 将数据从 python 传递到 Tensorflow。

当我按照此处提到的教程进行操作时:

https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/5_DataManagement/tensorflow_dataset_api.py

运行我的简单代码时:

one_hot_dataset = np.load("one_hot_dataset.npy")
dataset = tf.data.Dataset.from_tensor_slices(one_hot_dataset)
Run Code Online (Sandbox Code Playgroud)

我收到 TensorFlow 1.8 和 Python 3.5 的错误消息:

Traceback (most recent call last):

  File "<ipython-input-17-412a606c772f>", line 1, in <module>
    dataset = tf.data.Dataset.from_tensor_slices((one_hot_dataset))

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 235, in from_tensor_slices
    return TensorSliceDataset(tensors)

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1030, in __init__
    for i, t in enumerate(nest.flatten(tensors))

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1030, in <listcomp>
    for i, t in enumerate(nest.flatten(tensors))

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1014, in convert_to_tensor
    as_ref=False)

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1104, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/constant_op.py", line 235, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/constant_op.py", line 214, in constant
    value, dtype=dtype, shape=shape, verify_shape=verify_shape))

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/tensor_util.py", line 496, in make_tensor_proto
    "Cannot create a tensor proto whose content is larger than 2GB.")

ValueError: Cannot create a tensor proto whose content is larger than 2GB.
Run Code Online (Sandbox Code Playgroud)

我该如何解决这个问题?我认为原因很明显,但是 tf 开发人员通过将输入数据限制为 2GB 是怎么想的?!?我真的无法理解这种理性以及处理更大数据集时的解决方法是什么?

我用谷歌搜索了很多,但找不到任何类似的错误消息。当我使用 numpy 数据集的 FITFH 时,上述步骤没有任何问题。

我不知何故需要告诉 TensorFlow 我实际上将逐批加载数据,并且可能想要预取几批以保持我的 GPU 忙碌。但它似乎试图一次加载整个 numpy 数据集。那么使用 Dataset API 有什么好处,因为我可以通过简单地尝试将我的 numpy 数据集作为 tf.constant 加载到 TensorFlow 图中来重现此错误,这显然不适合并且我收到 OOM 错误。

提示和故障排除提示表示赞赏!

iga*_*iga 4

tf.data用户指南 ( https://www.tensorflow.org/guide/datasets ) 的“使用 NumPy 数组”部分中解决了此问题。

基本上,创建一个dataset.make_initializable_iterator()迭代器并在运行时提供数据。

如果由于某种原因这不起作用,您可以将数据写入文件或从 Python 生成器创建数据集(https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator),您可以在其中输入任意 Python 代码,包括对 numpy 数组进行切片并生成切片。