我正在尝试学习结合 VGG 和Adrian Ung Triplet Loss的 Paris6k 图像的嵌入。问题是,经过少量迭代后,在第一个 epoch 中,损失变为 nan,然后准确率和验证准确率增长到 1。
我已经尝试过降低学习率、增加批量大小(由于内存的原因仅增加到 16)、更改优化器(Adam 和 RMSprop)、检查数据集上是否有 None 值、将数据格式从“float32”更改为“ float64',为它们添加一点偏差并简化模型。
这是我的代码:
base_model = VGG16(include_top = False, input_shape = (512, 384, 3))
input_images = base_model.input
input_labels = Input(shape=(1,), name='input_label')
embeddings = Flatten()(base_model.output)
labels_plus_embeddings = concatenate([input_labels, embeddings])
model = Model(inputs=[input_images, input_labels], outputs=labels_plus_embeddings)
batch_size = 16
epochs = 2
embedding_size = 64
opt = Adam(lr=0.0001)
model.compile(loss=tl.triplet_loss_adapted_from_tf, optimizer=opt, metrics=['accuracy'])
label_list = np.vstack(label_list)
x_train = image_list[:2500]
x_val = image_list[2500:]
y_train = label_list[:2500]
y_val = label_list[2500:]
dummy_gt_train = np.zeros((len(x_train), embedding_size + 1))
dummy_gt_val = np.zeros((len(x_val), embedding_size + 1))
H = model.fit(
x=[x_train,y_train],
y=dummy_gt_train,
batch_size=batch_size,
epochs=epochs,
validation_data=([x_val, y_val], dummy_gt_val),callbacks=callbacks_list)
Run Code Online (Sandbox Code Playgroud)
这些图像为 3366,其值在 [0, 1] 范围内缩放。网络采用虚拟值,因为它试图以相同类别的图像应具有较小的距离,而不同类别的图像应具有较高的距离,并且比真实类别是训练的一部分的方式从图像中学习嵌入。
我注意到我之前进行了错误的类划分(并保留了应该丢弃的图像),并且我没有出现纳米损失问题。
我应该尝试做什么?
预先感谢并抱歉我的英语。
小智 5
在某些情况下,随机 NaN 损失可能是由您的数据引起的,因为如果您的批次中没有正对,您将得到 NaN 损失。
正如您在 Adrian Ung 的笔记本中看到的那样(或在张量流插件三元组损失中;它是相同的代码):
semi_hard_triplet_loss_distance = math_ops.truediv(
math_ops.reduce_sum(
math_ops.maximum(
math_ops.multiply(loss_mat, mask_positives), 0.0)),
num_positives,
name='triplet_semihard_loss')
Run Code Online (Sandbox Code Playgroud)
除以正数对的数量 ( num_positives),这可能会导致 NaN。
我建议您尝试检查您的数据管道,以确保每个批次中至少有一对阳性。(例如,您可以调整 中的一些代码来triplet_loss_adapted_from_tf获取num_positives批次的 ,并检查它是否大于 0)。
小智 5
尝试增加批量大小。这也发生在我身上。正如前面的答案中提到的,网络无法找到任何 num_positives。我上了 250 节课,一开始就损失惨重。我把它增加到128/256,然后就没有问题了。
我看到Paris6k有15个班或者12个班。将批量大小增加到 32,如果 GPU 内存不足,您可以尝试使用参数较少的模型。您可以从高效 B0 模型开始。与具有 138M 参数的 VGG16 相比,它有 5.3M。
| 归档时间: |
|
| 查看次数: |
1957 次 |
| 最近记录: |