我正在尝试将一个掩码(二进制,只有一个通道)应用于RGB图像(3个通道,标准化为[0,1]).我目前的解决方案是,将RGB图像分割成它的通道,将其与掩码相乘并再次连接这些通道:
with tf.variable_scope('apply_mask') as scope:
# Output mask is in range [-1, 1], bring to range [0, 1] first
zero_one_mask = (output_mask + 1) / 2
# Apply mask to all channels.
channels = tf.split(3, 3, output_img)
channels = [tf.mul(c, zero_one_mask) for c in channels]
output_img = tf.concat(3, channels)
Run Code Online (Sandbox Code Playgroud)
然而,这似乎效率很低,特别是因为根据我的理解,这些计算都不是就地完成的.有没有更有效的方法来做到这一点?