Keras的中心损失

sle*_*l56 7 keras

我想实现在Keras [ http://ydwen.github.io/papers/WenECCV16.pdf]中解释的中心损失

我开始创建一个包含2个输出的网络,例如:

inputs = Input(shape=(100,100,3))
...
fc = Dense(100)(#previousLayer#)
softmax = Softmax(fc)
model = Model(input, output=[softmax, fc])
model.compile(optimizer='sgd', 
              loss=['categorical_crossentropy', 'center_loss'],
              metrics=['accuracy'], loss_weights=[1., 0.2])
Run Code Online (Sandbox Code Playgroud)

首先,这样做,这是继续进行的好方法吗?

其次,我不知道如何在keras中实现center_loss.Center_loss看起来像均方误差,但它不是将值与固定标签进行比较,而是将值与每次迭代时更新的数据进行比较.

谢谢您的帮助

pit*_*all 3

对我来说,您可以按照以下步骤实现该层:

  1. 编写一个自定义ComputeCenter

    • 需要两个输入:i)。groudtruth 标签y_true(不是 one-hot 编码,而只是整数)和 ii)。预测成员资格y_pred

    • W包含一个大小数组的查找表num_classes x num_feats作为可训练权重(请参阅 BatchNormalization Layer),W[j] 是第 j 类特征的移动平均值的占位符。

    • 计算论文中指定的中心损耗。

    • 输出结果距离数组D
  2. 要计算中心损失,您需要

    • 我)。更新W[j]使用y_pred[k]根据y_true[k]=j,
    • ii). c_true[k]=W[j]检索样本的中心特征,y_pred[k]y_true[k]=j
    • y_prediii) 计算和之间的距离c_true
    • 这里c_true[k] = W[j], 和k是样本索引,j是 y_pred[k] 的真实标签。
  3. 用于model.add_loss()计算该损失。请注意,不要将此损失添加到 中model.compile( loss = ... )

最后,如果需要,您可以在中心损失中添加一些损失系数。