反射填充 Conv2D

Aki*_*iko 6 python padding keras zero-padding convolutional-neural-network

我正在使用 keras 构建一个用于图像分割的卷积神经网络,我想使用“反射填充”而不是“相同”填充,但我找不到在 keras 中做到这一点的方法。

inputs = Input((num_channels, img_rows, img_cols))
conv1=Conv2D(32,3,padding='same',kernel_initializer='he_uniform',data_format='channels_first')(inputs)
Run Code Online (Sandbox Code Playgroud)

有没有办法实现反射层并将其插入到 keras 模型中?

jee*_*a_v 7

上面接受的答案在当前的 Keras 版本中不起作用。这是有效的版本:

class ReflectionPadding2D(Layer):
    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def compute_output_shape(self, s):
        """ If you are using "channels_last" configuration"""
        return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])

    def call(self, x, mask=None):
        w_pad,h_pad = self.padding
        return tf.pad(x, [[0,0], [h_pad,h_pad], [w_pad,w_pad], [0,0] ], 'REFLECT')
Run Code Online (Sandbox Code Playgroud)

  • 这个实现很有帮助而且干净,感谢分享! (2认同)

Aki*_*iko 3

找到解决方案了!我们只需创建一个新类,将一个层作为输入,并使用张量流预定义函数来完成它。

import tensorflow as tf
from keras.engine.topology import Layer
from keras.engine import InputSpec

class ReflectionPadding2D(Layer):
    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def get_output_shape_for(self, s):
        """ If you are using "channels_last" configuration"""
        return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])

    def call(self, x, mask=None):
        w_pad,h_pad = self.padding
        return tf.pad(x, [[0,0], [h_pad,h_pad], [w_pad,w_pad], [0,0] ], 'REFLECT')

# a little Demo
inputs = Input((img_rows, img_cols, num_channels))
padded_inputs= ReflectionPadding2D(padding=(1,1))(inputs)
conv1 = Conv2D(32, 3, padding='valid', kernel_initializer='he_uniform',
               data_format='channels_last')(padded_inputs)
Run Code Online (Sandbox Code Playgroud)