How to create a tf.data.Dataset from directories of TFRecords?

Sia*_*ash 2 tensorflow tensorflow-datasets

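My dataset is split into directories, one per class. Each directory contains a different number of .tfrecords files. My first question: how do I sample 5 images from each directory (each .tfrecord file holds one image)? My second question: how do I sample 5 of these directories, and then sample 5 images from each of them?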

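I want to do this using only tf.data.Dataset. So I'd like a dataset from which I get an iterator, where iterator.next() gives me a batch of 25 images: 5 samples from each of 5 classes.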
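To pin down the target layout, here is a plain-Python sketch (not tf.data) of what one such batch should contain. It assumes the files have already been listed into a hypothetical dict `files_by_class` mapping each class directory to its .tfrecord paths; `plan_batch` is likewise a hypothetical helper, not part of any TensorFlow API.

```python
import random


def plan_batch(files_by_class, classes_per_batch=5, examples_per_class=5, rng=random):
    """Return (class_label, filename) pairs for one batch.

    Picks `classes_per_batch` distinct classes, then `examples_per_class`
    distinct files from each chosen class (sampling without replacement).
    """
    chosen = rng.sample(sorted(files_by_class), classes_per_batch)
    batch = []
    for label in chosen:
        for filename in rng.sample(files_by_class[label], examples_per_class):
            batch.append((label, filename))
    return batch


# Hypothetical layout: 8 class directories with 6 to 10 single-image files each.
files_by_class = {
    "class_%d" % c: ["class_%d/img_%d.tfrecord" % (c, i) for i in range(6 + c % 5)]
    for c in range(8)
}
batch = plan_batch(files_by_class, rng=random.Random(0))
print(len(batch))  # 25
```

Each class directory only needs at least 5 files for this to work; uneven directory sizes are otherwise fine.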

mrr*_*rry 10

Edit: If you have more than 5 classes, you can use the new tf.contrib.data.sample_from_datasets() API (currently available in tf-nightly, and coming in TensorFlow 1.9).

import tensorflow as tf  # TensorFlow 1.9+ for tf.contrib.data.sample_from_datasets()

directories = ["class_0/*", "class_1/*", "class_2/*", "class_3/*", ...]

CLASSES_PER_BATCH = 5
EXAMPLES_PER_CLASS_PER_BATCH = 5
BATCH_SIZE = CLASSES_PER_BATCH * EXAMPLES_PER_CLASS_PER_BATCH
NUM_CLASSES = len(directories)


# Build one dataset per class.
per_class_datasets = [
    tf.data.TFRecordDataset(tf.data.Dataset.list_files(d)) for d in directories]

# Next, build a dataset where each element is a vector of 5 classes to be chosen
# for a particular batch.
classes_per_batch_dataset = tf.contrib.data.Counter().map(
    lambda _: tf.random_shuffle(tf.range(NUM_CLASSES))[:CLASSES_PER_BATCH])

# Transform the dataset of per-batch class vectors into a dataset with one
# one-hot element per example (i.e. 25 examples per batch).
class_dataset = classes_per_batch_dataset.flat_map(
    lambda classes: tf.data.Dataset.from_tensor_slices(
        tf.one_hot(classes, NUM_CLASSES)).repeat(EXAMPLES_PER_CLASS_PER_BATCH))

# Use `tf.contrib.data.sample_from_datasets()` to select an example from the
# appropriate dataset in `per_class_datasets`.
example_dataset = tf.contrib.data.sample_from_datasets(
    per_class_datasets, class_dataset)

# Finally, combine 25 consecutive examples into a batch.
result = example_dataset.batch(BATCH_SIZE)
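As a rough sanity check of the element order this pipeline produces, the following plain-Python simulation mimics its three stages: picking 5 random classes per batch (like `tf.random_shuffle(tf.range(NUM_CLASSES))[:CLASSES_PER_BATCH]`), repeating that class vector 5 times (like the `flat_map` + `repeat`), and routing each one-hot weight vector to the matching per-class stream (like `sample_from_datasets` with `class_dataset` as its weights). It only emulates the semantics under these assumptions; `simulate_batches`, `record_stream` and `NUM_CLASSES = 8` are hypothetical, and no TensorFlow code is involved.

```python
import itertools
import random

NUM_CLASSES = 8  # hypothetical number of class directories
CLASSES_PER_BATCH = 5
EXAMPLES_PER_CLASS_PER_BATCH = 5
BATCH_SIZE = CLASSES_PER_BATCH * EXAMPLES_PER_CLASS_PER_BATCH


def one_hot(index, depth):
    """Python stand-in for tf.one_hot on a single index."""
    return [1.0 if i == index else 0.0 for i in range(depth)]


def record_stream(class_label):
    """Hypothetical stand-in for one per-class dataset: endless (label, n) pairs."""
    for n in itertools.count():
        yield (class_label, n)


def simulate_batches(num_batches, rng):
    streams = [record_stream(c) for c in range(NUM_CLASSES)]
    batches = []
    for _ in range(num_batches):
        # Like tf.random_shuffle(tf.range(NUM_CLASSES))[:CLASSES_PER_BATCH]:
        # 5 distinct classes in random order.
        classes = rng.sample(range(NUM_CLASSES), CLASSES_PER_BATCH)
        # Like from_tensor_slices(tf.one_hot(classes, ...)).repeat(5):
        # the whole list of 5 one-hot rows is repeated 5 times.
        weights = [one_hot(c, NUM_CLASSES) for c in classes] * EXAMPLES_PER_CLASS_PER_BATCH
        batch = []
        for w in weights:
            # Like sample_from_datasets: draw a stream index with probability
            # proportional to w; a one-hot w always selects the same index.
            chosen = rng.choices(range(NUM_CLASSES), weights=w)[0]
            batch.append(next(streams[chosen]))
        batches.append(batch)
    return batches


batches = simulate_batches(2, random.Random(0))
print([label for label, _ in batches[0]])
```

Because the class vector is repeated as a whole, a batch is ordered c1..c5, c1..c5, ... rather than in contiguous runs per class; each of the 5 chosen classes still appears exactly 5 times.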

If you only have 5 classes, you can define a nested dataset for each directory and combine them with Dataset.interleave():

import tensorflow as tf

# NOTE: We're assuming that the 0th directory contains elements from class 0, etc.
directories = ["class_0/*", "class_1/*", "class_2/*", "class_3/*", "class_4/*"]
directories = tf.data.Dataset.from_tensor_slices(directories)
directories = directories.apply(tf.contrib.data.enumerate_dataset())

# Define a function that maps each (class, directory) pair to the (shuffled)
# records in those files.
def per_directory_dataset(class_label, directory_glob):
  files = tf.data.Dataset.list_files(directory_glob, shuffle=True)
  records = tf.data.TFRecordDataset(files)
  # Zip the records with their class.
  # NOTE: This part might not be necessary if the records contain information about
  # their class that can be parsed from them.
  return tf.data.Dataset.zip(
      (records, tf.data.Dataset.from_tensors(class_label).repeat(None)))

# NOTE: The `cycle_length` and `block_length` here aren't strictly necessary,
# because the batch size is exactly `number of classes * images per class`.
# However, these arguments may be useful if you want to decouple these numbers.
merged_records = directories.interleave(per_directory_dataset,
                                        cycle_length=5, block_length=5)
merged_records = merged_records.batch(25)
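To see why `cycle_length=5, block_length=5` gives batches with 5 images from each class, here is a small plain-Python emulation of the element order of `Dataset.interleave`. It is a simplified sketch, not TensorFlow's implementation: it assumes no more inputs than `cycle_length` (as in the answer, 5 directories with `cycle_length=5`), so every input is open at once and none has to be swapped in. `interleave_like` and `directory_records` are hypothetical helpers.

```python
import itertools


def interleave_like(iterables, cycle_length, block_length):
    """Simplified emulation of the element order of Dataset.interleave.

    Assumes len(iterables) <= cycle_length, so all inputs are open at once and
    no input ever has to be swapped in; an exhausted input drops out of the cycle.
    """
    if len(iterables) > cycle_length:
        raise ValueError("this sketch only handles len(iterables) <= cycle_length")
    active = [iter(x) for x in iterables]
    while active:
        still_active = []
        for it in active:
            # Take up to `block_length` consecutive elements from this input.
            block = list(itertools.islice(it, block_length))
            yield from block
            if len(block) == block_length:
                still_active.append(it)
        active = still_active


def directory_records(class_label, num_files):
    """Hypothetical stand-in for per_directory_dataset: (record, label) pairs."""
    return [("record_%d" % i, class_label) for i in range(num_files)]


streams = [directory_records(label, 10) for label in range(5)]
merged = list(interleave_like(streams, cycle_length=5, block_length=5))

# Like merged_records.batch(25): chunk the interleaved stream into batches of 25.
batches = [merged[i:i + 25] for i in range(0, len(merged), 25)]
print([label for _, label in batches[0]])
```

If the directories hold different numbers of files, the smaller ones run out first and later batches no longer contain 5 images from each of 5 classes. Adding `.repeat()` to `records` inside `per_directory_dataset` is one way to keep every class available indefinitely.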