Keras Custom Metric用于单级精度

Chr*_*rry 15 python neural-network keras tensorflow

我正在构建一个自定义指标来测量培训期间我的多类数据集中一个类的准确性.我在选择课程时遇到了麻烦.

目标是一个热点(例如:0级标签是[1 0 0 0 0]:

from keras import backend as K

def single_class_accuracy(y_true, y_pred):
    idx = bool(y_true[:, 0])              # boolean mask for class 0 
    class_preds = y_pred[idx]
    class_true = y_true[idx]
    class_acc = K.mean(K.equal(K.argmax(class_true, axis=-1), K.argmax(class_preds, axis=-1)))  # multi-class accuracy  
    return class_acc
Run Code Online (Sandbox Code Playgroud)

麻烦的是,我们必须使用Keras函数来索引张量.如何为张量创建布尔掩码?谢谢.

jde*_*esa 19

请注意,在谈论一个班级的准确性时,可以参考以下任一项(不等同于)两个金额:

  • 精度,这对于类Ç,是标记有类实例比率Ç被预测为具有类Ç.
  • 召回,其中,类Ç,是的例子的比率预测为类的Ç,它们事实上标记类Ç.

您可以依靠屏蔽来计算,而不是进行复杂的索引.假设我们在这里谈论精确度(改为召回将是微不足道的).

from keras import backend as K

INTERESTING_CLASS_ID = 0  # Choose the class of interest

def single_class_accuracy(y_true, y_pred):
    class_id_true = K.argmax(y_true, axis=-1)
    class_id_preds = K.argmax(y_pred, axis=-1)
    # Replace class_id_preds with class_id_true for recall here
    accuracy_mask = K.cast(K.equal(class_id_preds, INTERESTING_CLASS_ID), 'int32')
    class_acc_tensor = K.cast(K.equal(class_id_true, class_id_preds), 'int32') * accuracy_mask
    class_acc = K.sum(class_acc_tensor) / K.maximum(K.sum(accuracy_mask), 1)
    return class_acc
Run Code Online (Sandbox Code Playgroud)

如果您想要更灵活,您还可以将感兴趣的类别参数化:

from keras import backend as K

def single_class_accuracy(interesting_class_id):
    def fn(y_true, y_pred):
        class_id_true = K.argmax(y_true, axis=-1)
        class_id_preds = K.argmax(y_pred, axis=-1)
        # Replace class_id_preds with class_id_true for recall here
        accuracy_mask = K.cast(K.equal(class_id_preds, interesting_class_id), 'int32')
        class_acc_tensor = K.cast(K.equal(class_id_true, class_id_preds), 'int32') * accuracy_mask
        class_acc = K.sum(class_acc_tensor) / K.maximum(K.sum(accuracy_mask), 1)
        return class_acc
    return fn
Run Code Online (Sandbox Code Playgroud)

并将其用作:

model.compile(..., metrics=[single_class_accuracy(INTERESTING_CLASS_ID)])
Run Code Online (Sandbox Code Playgroud)

  • Precision和Recall可以组合,这个度量称为F1得分.它是精度和召回的调和平均值,它是测试精度的度量. (4认同)
  • 重要的是要注意,尽管F1得分(以及准确性和召回率)并未考虑到真正的负面因素。高度依赖于人们实际选择合适的度量标准。 (2认同)