Men*_*rel 5 deep-learning keras tensorflow semantic-segmentation
给定批量 RGB 图像作为输入,shape=(batch_size, width, height, 3)
多类目标表示为 one-hot,shape=(batch_size, width, height, n_classes)
以及最后一层具有 softmax 激活的模型(Unet、DeepLab)。
我正在寻找 kera/tensorflow 中的加权分类交叉熵损失函数。
class_weight中的论点似乎fit_generator不起作用,我在这里或https://github.com/keras-team/keras/issues/2115中没有找到答案。
def weighted_categorical_crossentropy(weights):
# weights = [0.9,0.05,0.04,0.01]
def wcce(y_true, y_pred):
# y_true, y_pred shape is (batch_size, width, height, n_classes)
loos = ?...
return loss
return wcce
Run Code Online (Sandbox Code Playgroud)
我来回答我的问题:
def weighted_categorical_crossentropy(weights):
# weights = [0.9,0.05,0.04,0.01]
def wcce(y_true, y_pred):
Kweights = K.constant(weights)
if not K.is_tensor(y_pred): y_pred = K.constant(y_pred)
y_true = K.cast(y_true, y_pred.dtype)
return K.categorical_crossentropy(y_true, y_pred) * K.sum(y_true * Kweights, axis=-1)
return wcce
Run Code Online (Sandbox Code Playgroud)
用法:
loss = weighted_categorical_crossentropy(weights)
optimizer = keras.optimizers.Adam(lr=0.01)
model.compile(optimizer=optimizer, loss=loss)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
8690 次 |
| 最近记录: |