TensorFlow/Keras 的 fit() 函数的 class_weight 参数如何工作?

Emi*_*oss 4 keras tensorflow loss-function

我使用 TensorFlow 1.12 和 Keras 进行语义分割。我提供了一个权重向量(大小等于类的数量)来tf.keras.Model.fit()使用它的class_weight参数。我想知道这在内部是如何工作的。我使用自定义损失函数(骰子损失和焦点损失等),并且权重在输入损失函数之前不能与预测或单一热的基本事实相乘,因为这不会产生任何感觉。我的损失函数输出一个标量值,所以它也不能与函数输出相乘。那么在何处以及如何准确地考虑类权重?

我的自定义损失函数是:

def cross_entropy_loss(onehots_true, logits): # Inputs are [BATCH_SIZE, height, width, num_classes]
    logits, onehots_true = mask_pixels(onehots_true, logits) # Removes pixels for which no ground truth exists, and returns shape [num_gt_pixels, num_classes]
    return tf.losses.softmax_cross_entropy(onehots_true, logits)
Run Code Online (Sandbox Code Playgroud)

Shu*_*hal 5

正如在Keras 官方文档中提到的,

class_weight:可选字典将类索引(整数)映射到权重(浮点数)值,用于对损失函数进行加权(仅在训练期间)。这对于告诉模型“更多关注”来自代表性不足的类的样本很有用。

基本上,我们在类不平衡e 的情况下提供类权重。这意味着,训练样本在所有类中并不是均匀分布的。有些类的样本较少,而有些类的样本较多。

我们需要分类器更多地关注数量较少的类。一种方法可能是增加低样本类的损失值。更高的损失意味着更高的优化,从而导致有效的分类。

就 Keras 而言,我们将dict映射类索引传递给它们的权重(损失值将乘以的因子)。我们举个例子,

class_weights = { 0 : 1.2 , 1 : 0.9 }
Run Code Online (Sandbox Code Playgroud)

在内部,0 类和 1 类的损失值将乘以它们相应的权重值。

weighed_loss_class0 = loss0 * class_weights[0]
weighed_loss_class1 = loss1 * class_weights[1]
Run Code Online (Sandbox Code Playgroud)

现在,the weighed_loss_class0weighed_loss_class1将被用于反向传播。

看到这个这个


eug*_*gen 3

可以参考github上的keras源码,如下代码:

    class_sample_weight = np.asarray(
        [class_weight[cls] for cls in y_classes if cls in class_weight])

    if len(class_sample_weight) != len(y_classes):
      # subtract the sets to pick all missing classes
      existing_classes = set(y_classes)
      existing_class_weight = set(class_weight.keys())
      raise ValueError(
          '`class_weight` must contain all classes in the data.'
          ' The classes %s exist in the data but not in '
          '`class_weight`.' % (existing_classes - existing_class_weight))

  if class_sample_weight is not None and sample_weight is not None:
    # Multiply weights if both are provided.
    return class_sample_weight * sample_weight
Run Code Online (Sandbox Code Playgroud)

正如您所看到的,首先class_weight将其转换为 numpy 数组class_sample_weight,然后将其与样本权重相乘。

来源: https: //github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/engine/training_utils.py