如何将图像从像素转换为单热编码?

Kri*_*ves 5 python tensorflow

我有一个正在使用Tensorflow加载的PNG图像:

image = tf.io.decode_png(tf.io.read_file(path), channels=3)
Run Code Online (Sandbox Code Playgroud)

图像包含与查找匹配的像素,如下所示:

image_colors = [
    (0, 0, 0),        # black
    (0.5, 0.5, 0.5),  # grey
    (1, 0.5, 1),      # pink
]
Run Code Online (Sandbox Code Playgroud)

如何转换它,以便输出将像素映射为一热编码,其中热成分将是匹配颜色?

rvi*_*nas 5

为了方便起见,让我假设其中的所有值image_colors都在[0, 255]

image_colors = [
    (0, 0, 0),  # black
    (127, 127, 127),  # grey
    (255, 127, 255),  # pink
]
Run Code Online (Sandbox Code Playgroud)

我的方法将像素映射为一键热值,如下所示:

# Create a "color reference" tensor from image_colors
color_reference = tf.cast(tf.constant(image_colors), dtype=tf.uint8)

# Load the image and obtain tensor with one-hot values
image = tf.io.decode_png(tf.io.read_file(path), channels=3)
comp = tf.equal(image[..., None, :], color_reference)
one_hot = tf.cast(tf.reduce_all(comp, axis=-1), dtype=tf.float32)
Run Code Online (Sandbox Code Playgroud)

请注意,您可以轻松添加新颜色,image_colors而无需更改TF实现。此外,这还假设中的所有像素image都在中image_colors。如果不是这种情况,则可以为每种颜色定义一个范围,然后使用其他tf.math操作(例如tf.greatertf.less)代替tf.equal