如何使用 tf.data.Dataset.from_generator 进行批处理?我需要修改发电机吗

Tar*_*hra 6 tensorflow tensorflow-datasets tensorflow2.0

我正在使用该batch(8)函数,它修改形状并添加批次尺寸,但每批次仅获取一张图像。下面是我的代码:-

import cv2
import numpy as np
import os
import tensorflow as tf
import random

folder_path = "./real/"
files = os.listdir(folder_path)

def get_image():
    index = random.randint(0,len(files)-1)
    img = cv2.imread(folder_path+files[index])
    img = cv2.resize(img,(128,128))
    img = img/255.
    #More complex transformation
    yield img

dset = tf.data.Dataset.from_generator(get_image,(tf.float32)).batch(8)

for img in dset:
    print(img.shape)
    break
Run Code Online (Sandbox Code Playgroud)

即使使用batch(8),输出仍然是(1, 128, 128, 3)。我是否需要修改生成器来手动创建批次?另外,如何将其包装在tensorflow中的生成器中,使其运行得更快?

Edw*_*ong 8

这是因为你的收益只需要一张图像,你应该在循环中收益,这是一个例子

def get_image():
   for file in files:
      img = cv2.imread(folder_path + file)
      img = cv2.resize(img, (128, 128))
      img = img / 255.

      yield img # Your supposed to yield in a loop

dataset = tf.data.Dataset.from_generator(get_image, output_shapes=(128, 128), output_types=(tf.float32))

next(iter(dataset.batch(8))).shape

# TensorShape([8, 128, 128])
Run Code Online (Sandbox Code Playgroud)

  • 你好@EdwinCheong,这是有效的,但是如果数据集非常大,并且包含图像,似乎 get_image() 会立即在大数据集上运行,然后再进行批量处理并传递训练。那么准备好的数据岂不是RAM就满了? (2认同)