Tensorflow Dataset.from_tensor_slices 耗时太长

nik*_*iko 7 python numpy tensorflow tensorflow-datasets

我有以下代码:

data = np.load("data.npy")
print(data) # Makes sure the array gets loaded in memory
dataset = tf.contrib.data.Dataset.from_tensor_slices((data))
Run Code Online (Sandbox Code Playgroud)

该文件"data.npy"为 3.3 GB。使用 numpy 读取文件需要几秒钟,但是创建 tensorflow 数据集对象的下一行需要很长时间才能执行。这是为什么?它在幕后做什么?

Jul*_*yes 5

引用这个答案

np.loadof anpz只返回文件加载器,而不是实际数据。这是一个“惰性加载器”,仅在访问时加载特定数组。

这就是为什么它很快。

编辑 1:为了进一步扩展这个答案,来自tensorflow 文档的另一个引用:

如果所有输入数据都适合内存,Dataset从它们创建 的最简单方法是将它们转换为tf.Tensor对象并使用Dataset.from_tensor_slices().

这适用于小数据集,但会浪费内存——因为数组的内容将被多次复制——并且可能会遇到 tf.GraphDef 协议缓冲区的 2GB 限制。

该链接还显示了如何有效地做到这一点。