Keras实现定制损失功能,需要内部层输出作为标签

ljk*_*ece 8 loss keras

在keras中,我想自定义我的损失函数,它不仅需要(y_true,y_pred)作为输入,还需要使用网络内层的输出作为输出图层的标签.这张图片显示了网络布局

这里,内部输出是xn,它是一维特征向量.在右上角,输出是xn',这是xn的预测.换句话说,xn是xn'的标签.

虽然[Ax,Ay]传统上称为y_true,而[Ax',Ay']是y_pred.

我想将这两个损失组件合二为一,共同培训网络.

任何想法或想法都非常感谢!

ljk*_*ece 12

我找到了一条出路,如果有人正在寻找相同的,我在这里发布(根据这篇文章中给出的网络):

我们的想法是定义自定义损失函数并将其用作网络的输出.(符号:A是可变的正确标签A,并且A'是可变的预测值A)

def customized_loss(args):
    #A is from the training data
    #S is the internal state
    A, A', S, S' = args 
    #customize your own loss components
    loss1 = K.mean(K.square(A - A'), axis=-1)
    loss2 = K.mean(K.square(S - S'), axis=-1)
    #adjust the weight between loss components
    return 0.5 * loss1 + 0.5 * loss2

 def model():
     #define other inputs
     A = Input(...) # define input A
     #construct your model 
     cnn_model = Sequential()
     ...
     # get true internal state
     S = cnn_model(prev_layer_output0)
     # get predicted internal state output
     S' = Dense(...)(prev_layer_output1)
     # get predicted A output
     A' = Dense(...)(prev_layer_output2)
     # customized loss function
     loss_out = Lambda(customized_loss, output_shape=(1,), name='joint_loss')([A, A', S, S'])
     model = Model(input=[...], output=[loss_out])
     return model

  def train():
      m = model()
      opt = 'adam'
      model.compile(loss={'joint_loss': lambda y_true, y_pred:y_pred}, optimizer = opt)
      # train the model 
      ....
Run Code Online (Sandbox Code Playgroud)


Mat*_*gro 0

首先,您应该使用功能 API。然后,您应该将网络输出定义为输出加上内部层的结果,将它们合并为单个输出(通过连接),然后创建一个自定义损失函数,将合并的输出分成两部分并进行损失计算在其自己的。

就像是:

def customLoss(y_true, y_pred):
    #loss here
    internalLayer = Convolution2D()(inputs) #or other layers
    internalModel = Model(input=inputs, output=internalLayer)
    tmpOut = Dense(...)(internalModel)
    mergedOut = merge([tmpOut, mergedOut], mode = "concat", axis = -1)
    fullModel = Model(input=inputs, output=mergedOut)

    fullModel.compile(loss = customLoss, optimizer = "whatever")
Run Code Online (Sandbox Code Playgroud)