Dimension mismatch error when using tf.metrics.MeanIoU() with SparseCategoricalCrossEntropy loss in TensorFlow 2.2

Huc*_*inn 4 python tensorflow

Reference: https://github.com/tensorflow/tensorflow/issues/32875

The suggested fix is:

class UpdatedMeanIoU(tf.keras.metrics.MeanIoU):
    @tf.function
    def __call__(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.argmax(y_pred, axis=-1) # this is the fix
        return super().__call__(y_true, y_pred, sample_weight=sample_weight)

It works in TF 2.1 but breaks again in TF 2.2. Is there a way to pass y_pred = tf.argmax(y_pred, axis=-1) to this metric as y_pred other than by subclassing?

Huc*_*inn 7

This solved the problem:

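import tensorflow as tf

class UpdatedMeanIoU(tf.keras.metrics.MeanIoU):
  def __init__(self,
               y_true=None,  # unused; kept to preserve the original signature
               y_pred=None,  # unused; kept to preserve the original signature
               num_classes=None,
               name=None,
               dtype=None):
    super().__init__(num_classes=num_classes, name=name, dtype=dtype)

  def update_state(self, y_true, y_pred, sample_weight=None):
    # Collapse per-class scores (..., num_classes) to class indices so that
    # y_pred has the same shape as the sparse integer labels in y_true.
    y_pred = tf.math.argmax(y_pred, axis=-1)
    return super().update_state(y_true, y_pred, sample_weight)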
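To see why the argmax step removes the mismatch, here is a minimal NumPy sketch of the shapes involved. It is not the Keras implementation: the helper `mean_iou` and the toy arrays are made up for illustration, and it only assumes that the labels are sparse integers of shape (batch, height, width) while the model emits one score per class along the last axis.

```python
import numpy as np

def mean_iou(y_true, y_pred, num_classes):
    """Mean IoU from integer class maps of identical shape (illustrative helper)."""
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError(f"shape mismatch: {y_true.shape} vs {y_pred.shape}")
    # Confusion matrix: rows are true classes, columns are predicted classes.
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (y_true, y_pred), 1)
    intersection = np.diag(cm)
    union = cm.sum(axis=0) + cm.sum(axis=1) - intersection
    present = union > 0  # skip classes absent from both labels and predictions
    return float(np.mean(intersection[present] / union[present]))

num_classes = 3
# Sparse integer labels, shape (batch, height, width).
y_true = np.array([[[0, 1], [2, 1]]])
# Per-class scores, shape (batch, height, width, num_classes).
scores = np.zeros((1, 2, 2, num_classes))
scores[0, 0, 0, 0] = 1.0  # predicts class 0 (correct)
scores[0, 0, 1, 1] = 1.0  # predicts class 1 (correct)
scores[0, 1, 0, 2] = 1.0  # predicts class 2 (correct)
scores[0, 1, 1, 2] = 1.0  # predicts class 2 (wrong, the label is 1)

# argmax over the class axis yields class indices with the same shape as y_true.
y_pred = np.argmax(scores, axis=-1)
print(y_true.shape, scores.shape, y_pred.shape)
print(mean_iou(y_true, y_pred, num_classes))
```

Passing `scores` directly instead of `y_pred` raises the shape error in this sketch, because the flattened scores hold `num_classes` times as many elements as the labels.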