在 Keras 中编写自定义 MSE 损失函数

Nat*_*ens 3 python loss mse python-2.7 keras

我正在尝试在 Keras 中创建图像去噪 ConvNet,并且我想创建自己的损失函数。我希望它将噪声图像作为输入并将噪声作为输出。这个损失函数非常类似于 MSE 损失,但这将使我的网络学会去除干净的图像,而不是从输入的噪声图像中去除噪声。

我想用 y 噪声图像、x 干净图像和 R(y) 预测图像来实现损失函数:

我要实现的损失函数

我尝试自己制作它,但我不知道如何让丢失的图像访问我的嘈杂图像,因为它一直在变化。

def residual_loss(noisy_img):
  def loss(y_true, y_pred):
    return np.mean(np.square(y_pred - (noisy_img - y_true), axis=-1)
return loss
Run Code Online (Sandbox Code Playgroud)

基本上,我需要做的是这样的:

input_img = Input(shape=(None,None,3))

c1 = Convolution2D(64, (3, 3))(input_img)
a1 = Activation('relu')(c1)

c2 = Convolution2D(64, (3, 3))(a1)
a2 = Activation('relu')(c2)

c3 = Convolution2D(64, (3, 3))(a2)
a3 = Activation('relu')(c3)

c4 = Convolution2D(64, (3, 3))(a3)
a4 = Activation('relu')(c4)

c5 = Convolution2D(3, (3, 3))(a4)
out = Activation('relu')(c5)

model = Model(input_img, out)
model.compile(optimizer='adam', loss=residual_loss(input_img))
Run Code Online (Sandbox Code Playgroud)

但如果我尝试这个,我会得到:

 IndexError: tuple index out of range
Run Code Online (Sandbox Code Playgroud)

我能做些什么 ?

Dan*_*ler 5

由于在损失函数中使用“输入”是很不寻常的(它不是为此目的),我认为值得一提的是:

损失函数的作用并不是分离噪声。损失函数只是衡量“你离正确有多远”。

您的模型将分离事物,并且您期望从模型中得到的结果是y_true

您应该使用常规损失X_training = noisy imagesY_training = noises


那是说...

您可以在损失函数之外创建一个张量noisy_img并将其存储。损失函数内的所有操作都必须是张量函数,因此请使用keras 后端

import keras.backend as K

noisy_img = K.variable(X_training) #you must do this for each bach
Run Code Online (Sandbox Code Playgroud)

但是您必须考虑批次大小,该变量位于损失函数之外,需要您每个 epoch 只适合一个批次

def loss(y_true,y_pred):
    return K.mean(K.square(y_pred-y_true) - K.square(y_true-noisy_img))
Run Code Online (Sandbox Code Playgroud)

每个时期训练一批:

for batch in range(0,totalSamples,size):
    noisy_img = K.variable(X_training[batch:size])
    model.fit(X_training[batch:size],Y_training[batch:size], batch_size=size)
Run Code Online (Sandbox Code Playgroud)

如果仅使用均方误差,请按如下方式组织数据:

originalImages = loadYourImages() #without noises
Y_training = createRandomNoises() #without images

X_training = addNoiseToImages(originalImages,Y_training)
Run Code Online (Sandbox Code Playgroud)

现在你只需使用“mse”,或任何其他内置损失。

model.fit(X_training,Y_training,....)
Run Code Online (Sandbox Code Playgroud)