Dim*_*ich 6 tensorflow tensorflow-datasets
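Consider the problem of creating a dataset that samples small random patches from a directory of high-resolution images. The TensorFlow Dataset API makes this very easy: build a dataset of image names, shuffle it, map it to loaded images, then map those to randomly cropped patches.
However, this naive implementation is very inefficient, because a separate high-resolution image is loaded and cropped to produce each patch. Ideally, an image would be loaded once and reused to generate many patches.
One simple approach that has been discussed before is to generate multiple patches from each image and flatten them. However, this has the unfortunate effect of biasing the data too much, since consecutive patches come from the same image. We want each training batch to come from different images.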
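The bias is easy to see in a toy model. The sketch below is plain Python rather than TensorFlow: integer ids stand in for images and patches, and the helper names `flattened_patches` and `batched` are made up for illustration. It flattens 100 patches per image and counts how many distinct images contribute to each batch of 10.

```python
from itertools import islice


def flattened_patches(num_images, patches_per_image):
    """Yield (image_id, patch_id) pairs in the order a flatten step would."""
    for image_id in range(num_images):
        for patch_id in range(patches_per_image):
            yield image_id, patch_id


def batched(items, batch_size):
    """Group an iterable into consecutive lists of at most batch_size items."""
    iterator = iter(items)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch


batches = list(batched(flattened_patches(3, 100), 10))
# Number of distinct source images in each batch.
images_per_batch = [len({image_id for image_id, _ in batch}) for batch in batches]
print(images_per_batch)
```

Because 100 is a multiple of 10, every batch in this toy setup is drawn from exactly one image, which is the situation to avoid.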
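Ideally, what I would like is a "random caching filter" transformation that takes an underlying dataset and caches N of its elements in memory. Its iterator would return a random element from the cache. In addition, at a predefined frequency, it would replace a random element of the cache with a new element from the underlying dataset. This filter would allow faster data access at the cost of less randomization and higher memory consumption.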
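To pin down the intended semantics, here is a minimal plain-Python sketch of such a filter. It is not a TensorFlow API: the name `random_cache` and the parameters `cache_size` and `replace_every` are hypothetical, and stopping once the source is exhausted is just one possible design choice.

```python
import random
from itertools import islice


def random_cache(source, cache_size, replace_every, rng=None):
    """Serve random elements from a bounded in-memory cache of `source`.

    After every `replace_every` served elements, one random cache slot is
    overwritten with the next element of `source`. The generator stops when
    a replacement is due and `source` has nothing left (assumed behaviour).
    """
    rng = rng or random.Random()
    source = iter(source)
    cache = list(islice(source, cache_size))  # initial fill
    if not cache:
        return
    served = 0
    while True:
        yield rng.choice(cache)
        served += 1
        if served % replace_every == 0:
            try:
                fresh = next(source)
            except StopIteration:
                return
            cache[rng.randrange(len(cache))] = fresh


# Integers stand in for loaded images.
samples = list(islice(
    random_cache(range(1000), cache_size=8, replace_every=4,
                 rng=random.Random(0)),
    40))
```

In this run only 8 + 9 = 17 source elements are ever read to serve 40 samples, which is the access-cost saving the filter is meant to provide. A smaller `replace_every` moves it closer to plain shuffling.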
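Does such functionality exist?
If not, should it be implemented as a new dataset transformation or simply as a new iterator? It seems that only a new iterator is needed. Any pointers on how to create a new dataset iterator, ideally in C++?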
Oli*_*rot 14
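You should be able to use tf.data.Dataset.shuffle to achieve what you want. Here is a quick summary of the objectives:
- load very large images and serve batches of small random crops taken from them
- make the pipeline efficient by extracting several patches from each image once it is loaded
- shuffle enough that the patches in a batch come from different images
- avoid holding too many large images in memory
You can achieve all of this with the tf.data API by doing the following steps:
1. shuffle the filenames of the big images
2. read the big images
3. generate multiple patches from each image
4. shuffle all these patches again with a large enough buffer; the buffer size is a trade-off between good shuffling and the memory taken by cached patches
5. batch them
6. prefetch one batch
Here is the relevant code, written against the TensorFlow 1.x API: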
import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib and tf.Session)

filenames = ...  # filenames containing the big images
num_samples = len(filenames)

# Parameters
num_patches = 100                 # number of patches to extract from each image
patch_size = 32                   # size of the patches
buffer_size = 50 * num_patches    # shuffle patches from 50 different big images
num_parallel_calls = 4            # number of threads
batch_size = 10                   # size of the batch

get_patches_fn = lambda image: get_patches(image, num_patches=num_patches, patch_size=patch_size)

# Create a Dataset serving batches of random patches in our images
dataset = (tf.data.Dataset.from_tensor_slices(filenames)
    .shuffle(buffer_size=num_samples)  # step 1: all the filenames into the buffer ensures good shuffling
    .map(parse_fn, num_parallel_calls=num_parallel_calls)  # step 2
    .map(get_patches_fn, num_parallel_calls=num_parallel_calls)  # step 3
    .apply(tf.contrib.data.unbatch())  # unbatch the patches we just produced
    .shuffle(buffer_size=buffer_size)  # step 4
    .batch(batch_size)  # step 5
    .prefetch(1)  # step 6: make sure you always have one batch ready to serve
)

iterator = dataset.make_one_shot_iterator()
patches = iterator.get_next()  # shape [None, patch_size, patch_size, 3]

sess = tf.Session()
res = sess.run(patches)
The functions parse_fn and get_patches are defined as follows:
def parse_fn(filename):
    """Decode the jpeg image from the filename and convert to [0, 1]."""
    image_string = tf.read_file(filename)

    # Don't use tf.image.decode_image, or the output shape will be undefined
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)

    # This will convert to float values in [0, 1]
    image = tf.image.convert_image_dtype(image_decoded, tf.float32)

    return image


def get_patches(image, num_patches=100, patch_size=16):
    """Get `num_patches` random crops from the image"""
    patches = []
    for i in range(num_patches):
        patch = tf.random_crop(image, [patch_size, patch_size, 3])
        patches.append(patch)

    patches = tf.stack(patches)
    assert patches.get_shape().dims == [num_patches, patch_size, patch_size, 3]

    return patches
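To get a feel for how well the second shuffle (step 4) mixes images, here is a small plain-Python simulation. It assumes the documented behaviour of Dataset.shuffle (fill a buffer of buffer_size elements, then repeatedly emit a random one and refill its slot from the input). The helpers `shuffle_buffer` and `patch_stream` are hypothetical, and an image id stands in for each patch.

```python
import random


def patch_stream(num_images, num_patches):
    """Yield the source image id of each patch, in unbatched order."""
    for image_id in range(num_images):
        for _ in range(num_patches):
            yield image_id


def shuffle_buffer(stream, buffer_size, rng):
    """Buffer-based shuffle: emit a random buffered item, refill its slot."""
    buffer = []
    for item in stream:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        index = rng.randrange(buffer_size)
        yield buffer[index]
        buffer[index] = item
    rng.shuffle(buffer)  # input exhausted: drain what is left in random order
    yield from buffer


num_patches, batch_size = 100, 10
buffer_size = 50 * num_patches  # same setting as in the pipeline above
shuffled = list(shuffle_buffer(patch_stream(200, num_patches),
                               buffer_size, random.Random(0)))
batches = [shuffled[i:i + batch_size]
           for i in range(0, len(shuffled), batch_size)]
mean_distinct = sum(len(set(batch)) for batch in batches) / len(batches)
print(mean_distinct)  # average number of distinct images per batch of 10
```

Without the second shuffle every batch of 10 would come from a single image; with a buffer spanning roughly 50 images, most patches in a batch come from different images. The price is memory: 5000 float32 patches of 32x32x3 are about 61 MB of raw pixel data, before any TensorFlow overhead.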