如何访问生成器提供的 Keras 自定义损失函数中的样本权重?

ely*_*ely 5 python deep-learning keras tensorflow loss-function

我有一个生成器函数,它在某些图像目录上无限循环并输出 3 元组的形式

[img1, img2], label, weight
Run Code Online (Sandbox Code Playgroud)

其中img1img2batch_size x M x N x 3张量,并且labelweight各自batch_size的X 1张量。

fit_generator在用 Keras 训练模型时向函数提供了这个生成器。

对于这个模型,我有一个自定义的余弦对比损失函数,

def cosine_constrastive_loss(y_true, y_pred):
    cosine_distance = 1 - y_pred
    margin = 0.9
    cdist = y_true * y_pred + (1 - y_true) * keras.backend.maximum(margin - y_pred, 0.0)
    return keras.backend.mean(cdist)
Run Code Online (Sandbox Code Playgroud)

从结构上讲,我的模型一切正常。没有错误,它正在按预期消耗来自生成器的输入和标签。

但现在我正在寻求直接使用每个批次的权重参数,并cosine_contrastive_loss根据特定于样本的权重在内部执行一些自定义逻辑。

如何在执行损失函数时从一批样本的结构中访问此参数?

请注意,由于它是一个无限循环的生成器,因此无法预先计算权重或动态计算它们以将权重归入损失函数或生成它们。

它们具有一致地产生具有所产生的样品,并确有定制逻辑在我的数据生成器,从性能动态地确定的权重img1img2并且label在此刻它们用于分批生成。

Dan*_*ler 5

手动训练循环替代

我唯一能想到的是手动训练循环,您可以自己获得重量。

有一个权重张量和一个不可变的批量大小:

weights = K.variable(np.zeros((batch_size,)))
Run Code Online (Sandbox Code Playgroud)

在您的自定义损失中使用它们:

def custom_loss(true, pred):
    return someCalculation(true, pred, weights)
Run Code Online (Sandbox Code Playgroud)

对于“生成器”:

for e in range(epochs):
    for s in range(steps_per_epoch):
        x, y, w = next(generator) #or generator.next(), not sure
        K.set_value(weights, w)

        model.train_on_batch(x, y)
Run Code Online (Sandbox Code Playgroud)

对于keras.utils.Sequence

for e in range(epochs):
    for s in range(len(generator)):
        x,y,w = generator[s]

        K.set_value(weights, w)
        model.train_on_batch(x,y)
Run Code Online (Sandbox Code Playgroud)

我知道这个答案不是最优的,因为它不会并行从生成器获取数据,因为它发生在fit_generator. 但这是我能想到的最好的简单解决方案。Keras 没有公开权重,它们会自动应用在一些隐藏的源代码中。


让模型计算权重替代方案

如果可以从x和计算权重y,则可以将此任务委托给损失函数本身。

这有点hacky,但可能有效:

input1 = Input(shape1)
input2 = Input(shape2)

# .... model creation .... #

model = Model([input1, input2], outputs)
Run Code Online (Sandbox Code Playgroud)

让损失可以访问input1input2

def custom_loss(y_true, y_pred):
    w = calculate_weights(input1, input2, y_pred)
    # .... rest of the loss .... #
Run Code Online (Sandbox Code Playgroud)

这里的问题是您是否可以根据输入将权重计算为张量。