Ste*_*ano 5 python tensorflow tensorflow-datasets
我正在尝试从 numpy 数组开始在 tensorflow 1.14 中创建一个 Dataset 对象(我有一些无法为这个特定项目更改的遗留代码),但是每次我尝试时我都会将所有内容复制到我的图表上,因此当我创建了一个很大的事件日志文件(在这种情况下为 719 MB)。
最初我尝试使用这个函数“tf.data.Dataset.from_tensor_slices()”,但它不起作用,然后我读到这是一个常见问题,有人建议我尝试使用生成器,因此我尝试使用以下代码,但是我又得到了一个巨大的事件文件(又是 719 MB)
def fetch_batch(x, y, batch):
i = 0
while i < batch:
yield (x[i,:,:,:], y[i])
i +=1
train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train
images = images/255
training_dataset = tf.data.Dataset.from_generator(fetch_batch,
args=[images, np.int32(labels), batch_size], output_types=(tf.float32, tf.int32),
output_shapes=(tf.TensorShape(features_shape), tf.TensorShape(labels_shape)))
file_writer = tf.summary.FileWriter("/content", graph=tf.get_default_graph())
Run Code Online (Sandbox Code Playgroud)
我知道在这种情况下我可以使用 tensorflow_datasets API 并且它会更容易,但这是一个更普遍的问题,它涉及如何创建数据集,而不仅仅是使用 mnist 。你能向我解释我做错了什么吗?谢谢
我猜这是因为你正在args使用from_generator. 这肯定会将提供的内容放入args图表中。
您可以做的是定义一个函数,该函数将返回一个生成器,该生成器将迭代您的集合,例如(尚未测试):
def data_generator(images, labels):
def fetch_examples():
i = 0
while True:
example = (images[i], labels[i])
i += 1
i %= len(labels)
yield example
return fetch_examples
Run Code Online (Sandbox Code Playgroud)
这将在你的例子中给出:
train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train
images = images/255
training_dataset = tf.data.Dataset.from_generator(data_generator(images, labels), output_types=(tf.float32, tf.int32),
output_shapes=(tf.TensorShape(features_shape), tf.TensorShape(labels_shape))).batch(batch_size)
file_writer = tf.summary.FileWriter("/content", graph=tf.get_default_graph())
Run Code Online (Sandbox Code Playgroud)
请注意,我更改fetch_batch为,fetch_examples因为您可能想使用数据集实用程序进行批处理 ( .batch)。
| 归档时间: |
|
| 查看次数: |
1363 次 |
| 最近记录: |