Emi*_*oss 4 keras tensorflow loss-function
我使用 TensorFlow 1.12 和 Keras 进行语义分割。我提供了一个权重向量(大小等于类的数量)来tf.keras.Model.fit()
使用它的class_weight
参数。我想知道这在内部是如何工作的。我使用自定义损失函数(骰子损失和焦点损失等),并且权重在输入损失函数之前不能与预测或单一热的基本事实相乘,因为这不会产生任何感觉。我的损失函数输出一个标量值,所以它也不能与函数输出相乘。那么在何处以及如何准确地考虑类权重?
我的自定义损失函数是:
def cross_entropy_loss(onehots_true, logits): # Inputs are [BATCH_SIZE, height, width, num_classes]
logits, onehots_true = mask_pixels(onehots_true, logits) # Removes pixels for which no ground truth exists, and returns shape [num_gt_pixels, num_classes]
return tf.losses.softmax_cross_entropy(onehots_true, logits)
Run Code Online (Sandbox Code Playgroud)
正如在Keras 官方文档中提到的,
class_weight
:可选字典将类索引(整数)映射到权重(浮点数)值,用于对损失函数进行加权(仅在训练期间)。这对于告诉模型“更多关注”来自代表性不足的类的样本很有用。
基本上,我们在类不平衡e 的情况下提供类权重。这意味着,训练样本在所有类中并不是均匀分布的。有些类的样本较少,而有些类的样本较多。
我们需要分类器更多地关注数量较少的类。一种方法可能是增加低样本类的损失值。更高的损失意味着更高的优化,从而导致有效的分类。
就 Keras 而言,我们将dict
映射类索引传递给它们的权重(损失值将乘以的因子)。我们举个例子,
class_weights = { 0 : 1.2 , 1 : 0.9 }
Run Code Online (Sandbox Code Playgroud)
在内部,0 类和 1 类的损失值将乘以它们相应的权重值。
weighed_loss_class0 = loss0 * class_weights[0]
weighed_loss_class1 = loss1 * class_weights[1]
Run Code Online (Sandbox Code Playgroud)
现在,the weighed_loss_class0
和weighed_loss_class1
将被用于反向传播。
可以参考github上的keras源码,如下代码:
class_sample_weight = np.asarray(
[class_weight[cls] for cls in y_classes if cls in class_weight])
if len(class_sample_weight) != len(y_classes):
# subtract the sets to pick all missing classes
existing_classes = set(y_classes)
existing_class_weight = set(class_weight.keys())
raise ValueError(
'`class_weight` must contain all classes in the data.'
' The classes %s exist in the data but not in '
'`class_weight`.' % (existing_classes - existing_class_weight))
if class_sample_weight is not None and sample_weight is not None:
# Multiply weights if both are provided.
return class_sample_weight * sample_weight
Run Code Online (Sandbox Code Playgroud)
正如您所看到的,首先class_weight
将其转换为 numpy 数组class_sample_weight
,然后将其与样本权重相乘。
来源: https: //github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/engine/training_utils.py
归档时间: |
|
查看次数: |
5640 次 |
最近记录: |