iza*_*zak 4 keras tensorflow tf.keras
我正在尝试将用 tensorflow 层编写的 CNN 转换为在 tensorflow 中使用 keras api(我使用的是 TF 1.x 提供的 keras api),并且在编写自定义损失函数以训练模型时遇到问题。
根据本指南,在定义损失函数时,它需要参数(y_true, y_pred)
https://www.tensorflow.org/guide/keras/train_and_evaluate#custom_losses
def basic_loss_function(y_true, y_pred):
return ...
Run Code Online (Sandbox Code Playgroud)
然而,在我见过的每个例子中,y_true
都以某种方式与模型直接相关(在简单的情况下,它是网络的输出)。在我的问题中,情况并非如此。如果我的损失函数依赖于一些与模型张量无关的训练数据,如何实现这一点?
具体来说,这是我的问题:
我正在尝试学习在成对图像上训练的图像嵌入。我的训练数据包括图像对和图像对之间匹配点的注释(图像坐标)。输入特征只是图像对,网络以连体配置进行训练。
我能够用张量流层成功地实现这一点,并用张量流估计器成功地训练它。我当前的实现从一个大型 tf 记录数据库构建了一个 tf 数据集,其中的特征是一个包含图像和匹配点数组的字典。在我可以轻松地将这些图像坐标数组提供给损失函数之前,但目前尚不清楚如何执行此操作。
我经常使用的一种技巧是通过Lambda
层来计算模型内的损失。(例如,当损失与真实数据无关,并且模型实际上没有要比较的输出时)
在函数式 API 模型中:
def loss_calc(x):
loss_input_1, loss_input_2 = x #arbirtray inputs, you choose
#according to what you gave to the Lambda layer
#here you use some external data that doesn't relate to the samples
externalData = K.constant(external_numpy_data)
#calculate the loss
return the loss
Run Code Online (Sandbox Code Playgroud)
使用模型本身的输出(损失中使用的张量)
loss = Lambda(loss_calc)([model_output_1, model_output_2])
Run Code Online (Sandbox Code Playgroud)
创建输出损失而不是输出的模型:
model = Model(inputs, loss)
Run Code Online (Sandbox Code Playgroud)
为编译创建一个虚拟的 keras 损失函数:
def dummy_loss(y_true, y_pred):
return y_pred #where y_pred is the loss itself, the output of the model above
model.compile(loss = dummy_loss, ....)
Run Code Online (Sandbox Code Playgroud)
使用任何关于训练样本数量正确大小的虚拟数组,它将被忽略:
model.fit(your_inputs, np.zeros((number_of_samples,)), ...)
Run Code Online (Sandbox Code Playgroud)
另一种方法是使用自定义训练循环。
不过,这需要做更多的工作。
尽管您正在使用TF1
,但您仍然可以在代码的最开始就开启急切执行,并像在TF2
. ( tf.enable_eager_execution()
)
按照自定义训练循环教程:https : //www.tensorflow.org/tutorials/customization/custom_training_walkthrough
在这里,您可以自己计算梯度,无论您想要什么结果。这意味着您无需遵循 Keras 培训标准。
最后,您可以使用您建议的方法model.add_loss
。在这种情况下,您可以像我在第一个答案中所做的那样计算损失精确度。并将这个损失张量传递给add_loss
。
您可能可以使用loss=None
then (不确定)编译模型,因为您将使用其他损失,而不是标准损失。
在这种情况下,您的模型的输出可能也是None
如此,您应该使用y=None
.
归档时间: |
|
查看次数: |
2704 次 |
最近记录: |