TensorFlow将图像张量调整为动态形状

mac*_*ion 7 python image shape tensorflow

我试图用TensorFlow读取图像分类问题的一些图像输入.

当然,我这样做tf.image.decode_jpeg(...).我的图像大小可变,因此我无法为图像张量指定固定的形状.

但我需要根据实际尺寸来缩放图像.具体来说,我想以保持纵横比的方式将短边缩放到固定值和长边.

我可以通过获得某个图像的实际形状shape = tf.shape(image).我也能够为新的长边做计算

shape = tf.shape(image)
height = shape[0]
width = shape[1]
new_shorter_edge = 400
if height <= width:
    new_height = new_shorter_edge
    new_width = ((width / height) * new_shorter_edge)
else:
    new_width = new_shorter_edge
    new_height = ((height / width) * new_shorter_edge)
Run Code Online (Sandbox Code Playgroud)

我现在的问题是,我无法通过new_height,并new_widthtf.image.resize_images(...)因为其中一个是张量和resize_images预计整数作为高度和宽度的投入.

有没有办法"拉出"张量的整数,还是有其他方法可以用TensorFlow完成我的任务?

提前致谢.


编辑

因为我也有一些其他的问题tf.image.resize_images,这里是为我工作的代码:

shape = tf.shape(image)
height = shape[0]
width = shape[1]
new_shorter_edge = tf.constant(400, dtype=tf.int32)

height_smaller_than_width = tf.less_equal(height, width)
new_height_and_width = tf.cond(
    height_smaller_than_width,
    lambda: (new_shorter_edge, _compute_longer_edge(height, width, new_shorter_edge)),
    lambda: (_compute_longer_edge(width, height, new_shorter_edge), new_shorter_edge)
)

image = tf.expand_dims(image, 0)
image = tf.image.resize_bilinear(image, tf.pack(new_height_and_width))
image = tf.squeeze(image, [0])
Run Code Online (Sandbox Code Playgroud)

mrr*_*rry 6

这样做的方法是使用(目前是实验性的,但在下一版本中可用)tf.cond()*运算符.此运算符能够测试在运行时计算的值,并根据该值执行两个分支之一.

shape = tf.shape(image)
height = shape[0]
width = shape[1]
new_shorter_edge = 400
height_smaller_than_width = tf.less_equal(height, width)

new_shorter_edge = tf.constant(400)
new_height, new_width = tf.cond(
    height_smaller_than_width,
    lambda: new_shorter_edge, (width / height) * new_shorter_edge,
    lambda: new_shorter_edge, (height / width) * new_shorter_edge)
Run Code Online (Sandbox Code Playgroud)

现在您有Tensor值,new_height并且new_width将在运行时获取适当的值.


*要在当前发布的版本中访问操作员,您需要导入以下内容:

from tensorflow.python.ops import control_flow_ops
Run Code Online (Sandbox Code Playgroud)

...然后用control_flow_ops.cond()而不是tf.cond().