Creating RLE (run-length encoded) masks with TensorFlow Datasets

Ale*_*lex 5 run-length-encoding tensorflow tensorflow-datasets

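I've been trying to use TensorFlow Datasets, but I can't figure out how to create RLE masks efficiently. For reference, I'm using the data from Kaggle's Airbus Ship Detection Challenge: https://www.kaggle.com/c/airbus-ship-detection/data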

I know my RLE decoding function works. I borrowed it from one of the kernels:

import numpy as np

def rle_decode(mask_rle, shape=(768, 768)):
    '''
    mask_rle: run-length as string formatted (start length)
    shape: (height,width) of array to return
    Returns numpy array, 1 - mask, 0 - background
    '''
    # Non-string input (e.g. a missing value) decodes to an empty mask
    if not isinstance(mask_rle, str):
        img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
        return img.reshape(shape).T

    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1  # RLE start positions are 1-based
    ends = starts + lengths
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    # Pixels are numbered column-major, hence the final transpose
    return img.reshape(shape).T
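To check a decoder against known inputs, it helps to have the inverse operation. Below is a minimal NumPy sketch of an encoder. rle_encode is a hypothetical helper, not part of the kernel, and it assumes the convention that rle_decode undoes: 1-based start positions with pixels counted in column-major order (top to bottom, then left to right), which is why rle_decode ends with a transpose.

```python
import numpy as np

def rle_encode(mask):
    """Encode a binary mask as a 'start length start length ...' string.

    Assumes 1-based starts and column-major pixel order, i.e. the
    convention that rle_decode reverses with its final transpose.
    """
    # Transposing before flattening gives the column-major pixel order.
    pixels = mask.T.flatten()
    # Zero padding makes runs that touch either end still produce edges.
    padded = np.concatenate([[0], pixels, [0]])
    # Positions where the value changes; +1 turns them into 1-based indices.
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Edges alternate start, end (exclusive); convert each end to a length.
    edges[1::2] -= edges[::2]
    return ' '.join(str(x) for x in edges)

mask = np.array([[1, 0, 0, 1],
                 [1, 0, 0, 0],
                 [0, 1, 1, 0],
                 [1, 1, 1, 0],
                 [1, 1, 1, 0],
                 [1, 1, 1, 0]], dtype=np.uint8)
print(rle_encode(mask))  # → 1 2 4 3 9 4 15 5
```

For the square 768×768 competition images, rle_decode(rle_encode(mask)) should give back the original mask, which makes a quick round-trip check possible.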

...but it doesn't seem to work well inside the pipeline:

list_ds = tf.data.Dataset.list_files(train_paths_abs)
ds = list_ds.map(parse_img)
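One thing worth knowing about this setup: Dataset.map traces the mapped function into a TensorFlow graph, so the function receives symbolic tf.Tensor values, not Python strings. If the NumPy rle_decode is called on such a tensor directly, isinstance(mask_rle, str) is false and it returns an empty mask without ever reading the RLE string. If you want to keep the NumPy decoder, one common workaround is tf.py_function, which runs a Python function eagerly from inside the pipeline. The sketch below assumes the RLE strings are already elements of the dataset; decode_mask is a hypothetical helper name.

```python
import numpy as np
import tensorflow as tf

def rle_decode(mask_rle, shape=(768, 768)):
    # The question's NumPy decoder, condensed: 1-based, column-major runs.
    if not isinstance(mask_rle, str):
        return np.zeros(shape[0] * shape[1], dtype=np.uint8).reshape(shape).T
    s = mask_rle.split()
    starts = np.asarray(s[0::2], dtype=int) - 1
    lengths = np.asarray(s[1::2], dtype=int)
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, starts + lengths):
        img[lo:hi] = 1
    return img.reshape(shape).T

def decode_mask(mask_rle, shape=(768, 768)):
    # Hypothetical wrapper: tf.py_function passes an EagerTensor, whose
    # .numpy() value is bytes, so it is decoded to str before calling NumPy code.
    def _decode(rle):
        return rle_decode(rle.numpy().decode('utf-8'), shape)
    mask = tf.py_function(_decode, [mask_rle], tf.uint8)
    # py_function drops static shape information; rle_decode returns the
    # transposed array, so its shape is (shape[1], shape[0]).
    mask.set_shape((shape[1], shape[0]))
    return mask

ds = tf.data.Dataset.from_tensor_slices(['1 2 4 3 9 4 15 5', ''])
masks = [m.numpy() for m in ds.map(lambda r: decode_mask(r, shape=(4, 6)))]
```

Because tf.py_function executes ordinary Python code under the GIL, it typically does not parallelize well even with num_parallel_calls; a decoder written purely in TensorFlow ops avoids that limitation.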

With the following parsing function, everything works fine:

import tensorflow as tf

def parse_img(file_path,new_size=[128,128]):
    img_content = tf.io.read_file(file_path)
    img = tf.image.decode_jpeg(img_content)
    img = tf.image.convert_image_dtype(img, tf.float32)
    img = tf.image.resize(img,new_size)
    return img
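To decode masks inside the pipeline, each image path first needs its RLE string. A sketch of doing that lookup in plain Python, before the dataset is built, is shown here. It assumes the competition's segmentation CSV has an ImageId column and an EncodedPixels column, with one row per ship and an empty EncodedPixels field for images without ships; load_rle_by_image and pair_paths_with_rle are hypothetical names, and the in-memory CSV only stands in for the real file.

```python
import csv
import io
import os
from collections import defaultdict

def load_rle_by_image(csv_file):
    """Map each ImageId to one RLE string holding all of its ships' runs."""
    runs = defaultdict(list)
    for row in csv.DictReader(csv_file):
        parts = runs[row['ImageId']]  # creates an entry even for ship-less images
        pixels = (row['EncodedPixels'] or '').strip()
        if pixels:
            parts.append(pixels)
    # Assumes ship masks never overlap, so their runs can simply be concatenated.
    return {image_id: ' '.join(parts) for image_id, parts in runs.items()}

def pair_paths_with_rle(paths, rle_by_image):
    # Look up each file by its basename; unknown images get an empty string.
    return [rle_by_image.get(os.path.basename(p), '') for p in paths]

# Small in-memory stand-in for the competition CSV (assumed layout).
sample_csv = io.StringIO('ImageId,EncodedPixels\n'
                         'a.jpg,1 2\n'
                         'a.jpg,9 4\n'
                         'b.jpg,\n')
rle_by_image = load_rle_by_image(sample_csv)
paths = ['/data/train/a.jpg', '/data/train/b.jpg']
rles = pair_paths_with_rle(paths, rle_by_image)
print(rles)  # → ['1 2 9 4', '']
```

The resulting paths and RLE strings could then be zipped into a single dataset with tf.data.Dataset.from_tensor_slices((paths, rles)), so no CSV lookup has to happen inside map.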

But as soon as I add the mask (decoded with the rle_decode function above), things go badly wrong.


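Even though my parse_img function works fine (I've checked it on a single sample, and each run takes 271 µs ± 67.9 µs), the list_ds.map step takes forever (more than 5 minutes) and then hangs. I don't know what's wrong, and it's driving me crazy! Any ideas?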

jde*_*esa 5

You can rewrite the function rle_decode using TensorFlow operations like this (here I don't apply the final transposition, to keep it more general, but you can do it afterwards):

import tensorflow as tf

def rle_decode_tf(mask_rle, shape):
    shape = tf.convert_to_tensor(shape, tf.int64)
    size = tf.math.reduce_prod(shape)
    # Split string
    s = tf.strings.split(mask_rle)
    s = tf.strings.to_number(s, tf.int64)
    # Get starts and lengths
    starts = s[::2] - 1
    lens = s[1::2]
    # Make ones to be scattered
    total_ones = tf.reduce_sum(lens)
    ones = tf.ones([total_ones], tf.uint8)
    # Make scattering indices
    r = tf.range(total_ones)
    lens_cum = tf.math.cumsum(lens)
    s = tf.searchsorted(lens_cum, r, 'right')
    idx = r + tf.gather(starts - tf.pad(lens_cum[:-1], [(1, 0)]), s)
    # Scatter ones into flattened mask
    mask_flat = tf.scatter_nd(tf.expand_dims(idx, 1), ones, [size])
    # Reshape into mask
    return tf.reshape(mask_flat, shape)
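The least obvious step in rle_decode_tf is how it builds the scatter indices without a Python loop. The NumPy mirror below (scatter_indices is a hypothetical helper written for illustration, not part of the answer) reproduces just that step so the intermediate values can be inspected. Each of the total_ones output slots r is assigned to a run by a right-sided searchsorted on the cumulative lengths; subtracting where that run begins within r and adding the run's start then gives the flat pixel index.

```python
import numpy as np

def scatter_indices(mask_rle):
    """NumPy mirror of how rle_decode_tf computes its scatter indices."""
    s = np.array(mask_rle.split(), dtype=np.int64)
    starts = s[::2] - 1                 # 0-based start of each run
    lens = s[1::2]                      # length of each run
    r = np.arange(lens.sum())           # one slot per pixel that becomes 1
    lens_cum = np.cumsum(lens)          # exclusive end of each run within r
    # Run that each slot belongs to (tf.searchsorted(..., 'right') in the answer).
    run = np.searchsorted(lens_cum, r, side='right')
    # Per-run shift from a slot in r to a flat pixel index.
    shift = starts - np.concatenate([[0], lens_cum[:-1]])
    return r + shift[run]

print(scatter_indices('1 2 4 3 9 4 15 5').tolist())
# → [0, 1, 3, 4, 5, 8, 9, 10, 11, 14, 15, 16, 17, 18]
```

These are exactly the flat positions covered by the runs [0, 2), [3, 6), [8, 12) and [14, 19), i.e. the ones that tf.scatter_nd writes into the flattened mask.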

A small test (TensorFlow 2.0):

mask_rle = '1 2 4 3 9 4 15 5'
shape = [4, 6]
# Original NumPy function
print(rle_decode(mask_rle, shape))
# [[1 0 0 1]
#  [1 0 0 0]
#  [0 1 1 0]
#  [1 1 1 0]
#  [1 1 1 0]
#  [1 1 1 0]]
# TensorFlow function (transposing is done out of the function)
tf.print(tf.transpose(rle_decode_tf(mask_rle, shape)))
# [[1 0 0 1]
#  [1 0 0 0]
#  [0 1 1 0]
#  [1 1 1 0]
#  [1 1 1 0]
#  [1 1 1 0]]
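Because rle_decode_tf uses only TensorFlow ops, it can be called directly inside Dataset.map, including with num_parallel_calls. The following sketch combines it with the question's parse_img steps. parse_img_and_mask and decode_or_empty are hypothetical names; the sketch assumes the dataset yields (file path, RLE string) pairs and includes a lightly adapted copy of rle_decode_tf (the ones are built with tf.ones_like) so the block runs on its own. As a precaution, empty RLE strings (images without ships) skip the decoder via tf.cond. The mask is resized with nearest-neighbour interpolation so it stays binary, and the tiny synthetic JPEG exists only to make the example runnable.

```python
import os
import tempfile

import tensorflow as tf

def rle_decode_tf(mask_rle, shape):
    # Lightly adapted copy of the decoder above (no final transpose).
    shape = tf.convert_to_tensor(shape, tf.int64)
    size = tf.math.reduce_prod(shape)
    s = tf.strings.to_number(tf.strings.split(mask_rle), tf.int64)
    starts = s[::2] - 1
    lens = s[1::2]
    r = tf.range(tf.reduce_sum(lens))
    ones = tf.ones_like(r, dtype=tf.uint8)
    lens_cum = tf.math.cumsum(lens)
    run = tf.searchsorted(lens_cum, r, side='right')
    idx = r + tf.gather(starts - tf.pad(lens_cum[:-1], [(1, 0)]), run)
    mask_flat = tf.scatter_nd(tf.expand_dims(idx, 1), ones, [size])
    return tf.reshape(mask_flat, shape)

def decode_or_empty(mask_rle, shape):
    # Images without ships get an all-zero mask instead of sending a
    # zero-length run list through the decoder.
    shape = tf.convert_to_tensor(shape, tf.int64)
    has_runs = tf.strings.length(tf.strings.strip(mask_rle)) > 0
    return tf.cond(has_runs,
                   lambda: rle_decode_tf(mask_rle, shape),
                   lambda: tf.zeros(shape, tf.uint8))

def parse_img_and_mask(file_path, mask_rle, mask_shape=(768, 768), new_size=(128, 128)):
    # Image branch: same steps as the question's parse_img.
    img = tf.image.decode_jpeg(tf.io.read_file(file_path), channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)
    img = tf.image.resize(img, new_size)
    # Mask branch: decode, then transpose to undo the column-major numbering.
    mask = tf.transpose(decode_or_empty(mask_rle, mask_shape))
    # Nearest-neighbour resizing keeps the mask binary and uint8.
    mask = tf.image.resize(mask[..., tf.newaxis], new_size, method='nearest')
    return img, mask

# Tiny synthetic example so the sketch runs end to end.
path = os.path.join(tempfile.mkdtemp(), 'example.jpg')
tf.io.write_file(path, tf.io.encode_jpeg(tf.zeros([6, 4, 3], tf.uint8)))
ds = tf.data.Dataset.from_tensor_slices(([path, path], ['1 2 4 3 9 4 15 5', '']))
ds = ds.map(lambda p, r: parse_img_and_mask(p, r, mask_shape=(4, 6), new_size=(6, 4)),
            num_parallel_calls=tf.data.experimental.AUTOTUNE)
results = list(ds)
```

For the real data, the defaults mask_shape=(768, 768) and new_size=(128, 128) would apply, with the (path, RLE string) pairs coming from the training CSV.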