Lau*_*Lau 6 image scipy tensorflow
我刚读过这篇文章.文章说,tensorflow的resize算法有一些bug.现在我想用scipy.misc.imresize
而不是tf.image.resize_images
.我想知道实现scipy resize算法的最佳方法是什么.
让我们考虑以下层:
def up_sample(input_tensor, new_height, new_width):
_up_sampled = tf.image.resize_images(input_tensor, [new_height, new_width])
_conv = tf.layers.conv2d(_up_sampled, 32, [3,3], padding="SAME")
return _conv
Run Code Online (Sandbox Code Playgroud)
如何在此图层中使用scipy算法?
编辑:
一个例子可以是这个功能:
input_tensor = tf.placeholder("float32", [10, 200, 200, 8])
output_shape = [32, 210, 210, 8]
def up_sample(input_tensor, output_shape):
new_array = np.zeros(output_shape)
for batch in range(input_tensor.shape[0]):
for channel in range(input_tensor.shape[-1]):
new_array[batch, :, :, channel] = misc.imresize(input_tensor[batch, :, :, channel], output_shape[1:3])
Run Code Online (Sandbox Code Playgroud)
但显然scipy会引发一个ValueError,即tf.Tensor对象的形状不正确.我读到在tf.Session期间,Tensors可以作为numpy数组访问.如何仅在会话期间使用scipy函数并在创建协议缓冲区时省略执行?
是否有比循环所有批次和渠道更快的方式?
一般来说,您需要的工具是tf.map_fn
和的组合tf.py_func
。
tf.py_func
允许您将标准 python 函数包装到插入到图形中的张量流操作中。tf.map_fn
当函数无法对整个批次 \xe2\x80\x94 进行操作(图像函数通常会出现这种情况)时,允许您在批次样本上重复调用该函数。在目前的情况下,我可能会建议在scipy.ndimage.zoom
它可以直接在 4D 张量上操作的基础上使用,这会让事情变得更简单。另一方面,它需要输入缩放系数,而不是大小,因此我们需要计算它们。
import tensorflow as tf\n\nsess = tf.InteractiveSession()\n\n# unimportant -- just a way to get an input tensor\nbatch_size = 13\nim_size = 7\nnum_channel=5\nx = tf.eye(im_size)[None,...,None] + tf.zeros((batch_size, 1, 1, num_channel))\nnew_size = 17\n\nfrom scipy import ndimage\nnew_x = tf.py_func(\n lambda a: ndimage.zoom(a, (1, new_size/im_size, new_size/im_size, 1)),\n [x], [tf.float32], stateful=False)[0]\nprint(new_x.eval().shape)\n# (13, 17, 17, 5)\n
Run Code Online (Sandbox Code Playgroud)\n\n您可以使用其他函数(例如 OpenCV's cv2.resize
、 Scikit-image's transform.image
、 Scipy's misc.imresize
),但没有一个可以直接对 4D 张量进行操作,因此使用起来更加冗长。如果您想要除此之外的插值,您可能仍然想使用它们zoom
如果您想要除\ 的基于样条线的插值
但是,请注意以下事项:
\n\nPython 函数在主机上执行。因此,如果您在图形卡等设备上执行图形,则需要停止,将张量复制到主机内存,调用您的函数,然后将结果复制回设备上。如果内存传输很重要,这可能会完全破坏您的计算时间。
梯度不通过 python 函数传递。如果您的节点用于网络的升级部分,上游层将不会收到任何梯度(或者只有部分梯度,如果您有跳过连接),这会损害您的训练。
出于这两个原因,我建议仅在 CPU 上进行预处理并且不使用梯度时对输入应用这种重采样。
\n\ntf.image.resize_image
如果您确实想使用这个高档节点在设备上进行训练,那么我认为除了坚持使用 buggy或编写自己的节点之外,别无选择。
归档时间: |
|
查看次数: |
361 次 |
最近记录: |