Posts by Phi*_*ton

Unexpected keyword argument 'sample_weight' when subclassing a TensorFlow loss class (categorical_crossentropy) to create a weighted loss function

I'm struggling to get a subclassed loss function working in TensorFlow (2.2.0). I first tried this code (I know it has worked for others; see https://github.com/keras-team/keras/issues/2115#issuecomment-530762739):

import tensorflow.keras.backend as K
from tensorflow.keras.losses import CategoricalCrossentropy


class WeightedCategoricalCrossentropy(CategoricalCrossentropy):

    def __init__(self, cost_mat, name='weighted_categorical_crossentropy', **kwargs):
        assert(cost_mat.ndim == 2)
        assert(cost_mat.shape[0] == cost_mat.shape[1])

        super().__init__(name=name, **kwargs)
        self.cost_mat = K.cast_to_floatx(cost_mat)

    def __call__(self, y_true, y_pred):

        return super().__call__(
            y_true=y_true,
            y_pred=y_pred,
            sample_weight=get_sample_weights(y_true, y_pred, self.cost_mat),
        )

def get_sample_weights(y_true, y_pred, cost_m):
    num_classes = len(cost_m)

    y_pred.shape.assert_has_rank(2)
    y_pred.shape[1].assert_is_compatible_with(num_classes)
    y_pred.shape.assert_is_compatible_with(y_true.shape)

    y_pred = K.one_hot(K.argmax(y_pred), num_classes)

    y_true_nk1 = K.expand_dims(y_true, 2)
    y_pred_n1k = K.expand_dims(y_pred, 1)
    cost_m_1kk = K.expand_dims(cost_m, 0)

    sample_weights_nkk = cost_m_1kk * y_true_nk1 * y_pred_n1k
    sample_weights_n …
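To make the broadcasting in `get_sample_weights` easier to follow, here is a NumPy sketch of the same computation on a toy batch. It is not the poster's code: the listing above is cut off after `sample_weights_nkk`, so the final reduction here (summing over both class axes, which leaves each sample with the entry `cost_m[true_class, predicted_class]`) is an assumption, as are the function name `sample_weights_sketch` and the toy values.

```python
import numpy as np


def sample_weights_sketch(y_true, y_pred, cost_m):
    """Hypothetical NumPy mirror of get_sample_weights (final step assumed)."""
    num_classes = len(cost_m)
    assert y_pred.ndim == 2 and y_pred.shape[1] == num_classes
    assert y_pred.shape == y_true.shape

    # Replace predicted probabilities with a one-hot of the argmax class.
    y_pred = np.eye(num_classes)[np.argmax(y_pred, axis=1)]

    # Shapes (n, k, 1), (n, 1, k) and (1, k, k) broadcast to (n, k, k).
    y_true_nk1 = np.expand_dims(y_true, 2)
    y_pred_n1k = np.expand_dims(y_pred, 1)
    cost_m_1kk = np.expand_dims(cost_m, 0)

    # Per sample, only the (true class, predicted class) entry is nonzero.
    sample_weights_nkk = cost_m_1kk * y_true_nk1 * y_pred_n1k
    # Assumed final step: collapse both class axes to one weight per sample.
    return sample_weights_nkk.sum(axis=(1, 2))


cost = np.array([[1.0, 5.0],
                 [2.0, 1.0]])
y_true = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])
y_pred = np.array([[0.2, 0.8], [0.9, 0.1], [0.7, 0.3]])
print(sample_weights_sketch(y_true, y_pred, cost))  # prints [5. 2. 1.]
```

With this toy cost matrix, a true class-0 sample predicted as class 1 gets weight 5, a true class-1 sample predicted as class 0 gets weight 2, and a correct class-0 prediction keeps weight 1.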
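The error in the title is Python's standard `TypeError` for a call that passes a keyword argument the callee's signature does not declare. The override above narrows `__call__` to `(self, y_true, y_pred)`, so any caller that passes `sample_weight=` by keyword will fail; presumably the Keras training loop in TF 2.2 invokes the loss object that way. The plain-Python sketch below reproduces the mechanism with hypothetical stand-in classes (`BaseLoss`, `NarrowedLoss`, `WidenedLoss` and `keras_like_caller` are illustrative, not TensorFlow APIs) and shows that keeping an optional `sample_weight` parameter in the override avoids it.

```python
class BaseLoss:
    """Stand-in for a Keras Loss: __call__ accepts an optional sample_weight."""

    def __call__(self, y_true, y_pred, sample_weight=None):
        # Squared error per sample, optionally scaled, then averaged.
        losses = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
        if sample_weight is not None:
            losses = [loss * w for loss, w in zip(losses, sample_weight)]
        return sum(losses) / len(losses)


class NarrowedLoss(BaseLoss):
    # Mirrors the question: the override drops the sample_weight parameter.
    def __call__(self, y_true, y_pred):
        return super().__call__(y_true, y_pred, sample_weight=[2.0] * len(y_true))


class WidenedLoss(BaseLoss):
    # Keeps sample_weight optional and multiplies in any caller-supplied weights.
    def __call__(self, y_true, y_pred, sample_weight=None):
        weights = [2.0] * len(y_true)
        if sample_weight is not None:
            weights = [w * s for w, s in zip(weights, sample_weight)]
        return super().__call__(y_true, y_pred, sample_weight=weights)


def keras_like_caller(loss_obj, y_true, y_pred):
    # Hypothetical caller that, like a training loop, always passes the keyword.
    return loss_obj(y_true, y_pred, sample_weight=None)


try:
    keras_like_caller(NarrowedLoss(), [1.0, 0.0], [0.5, 0.5])
except TypeError as err:
    print(err)  # message names the unexpected keyword 'sample_weight'

print(keras_like_caller(WidenedLoss(), [1.0, 0.0], [0.5, 0.5]))  # prints 0.5
```

How to combine caller-supplied weights with the cost-based ones is a design choice; multiplying them, as `WidenedLoss` does, is only one option.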

Tags: subclassing, tensorflow2.0

Score: 7 · Answers: 1 · Views: 4681
