骰子系数大于 1

nao*_*ity 8 image-segmentation deep-learning keras

当我训练 UNET 时,dice coef 和 iou 有时会变得大于 1 和iou > dice,然后经过几个批次后它们又会恢复正常。如图所示。

我将它们定义如下:

def dice_coef(y_true, y_pred, smooth=1):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def iou(y_true, y_pred, smooth=1):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    union = K.sum(y_true_f) + K.sum(y_pred_f) - intersection
    return (intersection + smooth) / (union + smooth)

def dice_loss(y_true, y_pred):
    return 1. - dice_coef(y_true, y_pred)
Run Code Online (Sandbox Code Playgroud)

我尝试添加K.abs()y_pred 但这会导致性能更差。我觉得既然输出是 sigmoid 激活的,是否添加K.abs()应该给出相同的结果?另外,正如你所看到的,我的准确性很奇怪,我一直依靠骰子来判断我的模型性能,如果有人能指出这个问题,那就更好了。

Dan*_*ler 10

我相信你的y_true图像可能不在 0 到 1 之间的范围内......你确定它们不在 0 到 255 之间吗?或者他们只有一个通道(而不是 3 个通道?)

这不应该是原因,但您使用的是批量骰子,您应该使用图像骰子:

def dice_coef(y_true, y_pred, smooth=1):
    y_true_f = K.batch_flatten(y_true)
    y_pred_f = K.batch_flatten(y_pred)

    intersection = K.sum(y_true_f * y_pred_f, axis=-1)
    sums = K.sum(y_true_f, axis=-1) + K.sum(y_pred_f, axis=-1)

    return (2. * intersection + smooth) / (sums + smooth)
Run Code Online (Sandbox Code Playgroud)

通常,我用K.epsilon()“平滑”(非常小的东西)。

这同样适用于iou

def iou(y_true, y_pred, smooth=1):
    y_true_f = K.batch_flatten(y_true)
    y_pred_f = K.batch_flatten(y_pred)

    intersection = K.sum(y_true_f * y_pred_f, axis=-1)
    union = K.sum(y_true_f, axis=-1) + K.sum(y_pred_f, axis=-1) - intersection
    return (intersection + smooth) / (union + smooth)
Run Code Online (Sandbox Code Playgroud)

通道骰子的示例:

#considering shape (batch, classes, image_size, image_size)
def dice_coef(y_true, y_pred, smooth=1):

    intersection = K.sum(y_true * y_pred, axis=[2,3])
    sums = K.sum(y_true, axis=[2,3]) + K.sum(y_pred, axis=[2,3])

    dice = (2. * intersection + smooth) / (sums + smooth)
    return K.mean(dice, axis=-1)
Run Code Online (Sandbox Code Playgroud)