See https://github.com/tensorflow/tensorflow/issues/32875
The fix suggested there is:
import tensorflow as tf

class UpdatedMeanIoU(tf.keras.metrics.MeanIoU):
    @tf.function
    def __call__(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.argmax(y_pred, axis=-1)  # this is the fix
        return super().__call__(y_true, y_pred, sample_weight=sample_weight)
This works in TF 2.1 but crashes again in TF 2.2. Is there a way to apply y_pred = tf.argmax(y_pred, axis=-1) to y_pred for this metric other than through inheritance?
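Why the argmax is needed: with a softmax output layer, y_pred carries one score per class along its last axis, whereas MeanIoU builds its confusion matrix from integer class labels, like those in y_true. The following is a minimal pure-Python sketch of that reduction, using nested lists in place of tensors. argmax_last_axis is a hypothetical helper, not a TensorFlow function.

```python
def argmax_last_axis(scores):
    """Index of the highest score along the innermost (class) axis.

    scores is a nested list whose innermost lists hold per-class scores;
    this is a plain-Python stand-in for tf.argmax(y_pred, axis=-1).
    """
    if scores and isinstance(scores[0], list):
        # Not yet at the class axis: recurse into each sub-list.
        return [argmax_last_axis(s) for s in scores]
    # Innermost list: return the index of its largest score.
    return max(range(len(scores)), key=lambda c: scores[c])


# Softmax output for a 2x2 image with 3 classes, shape (2, 2, 3).
y_pred_scores = [
    [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]],
    [[0.2, 0.2, 0.6], [0.5, 0.3, 0.2]],
]
labels = argmax_last_axis(y_pred_scores)
print(labels)  # [[0, 1], [2, 0]], shape (2, 2) like y_true
```

Dropping the class axis leaves one label per pixel, so the predictions line up element by element with y_true.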
The following solves the problem by overriding update_state instead of __call__:
import tensorflow as tf

class UpdatedMeanIoU(tf.keras.metrics.MeanIoU):
    def __init__(self,
                 y_true=None,  # unused; pass num_classes by keyword
                 y_pred=None,  # unused
                 num_classes=None,
                 name=None,
                 dtype=None):
        super().__init__(num_classes=num_classes, name=name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Convert per-class scores to integer class labels before the
        # parent class accumulates its confusion matrix.
        y_pred = tf.math.argmax(y_pred, axis=-1)
        return super().update_state(y_true, y_pred, sample_weight)
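The answer's approach leaves the parent metric's state handling untouched and only converts y_pred before delegating to the parent update_state. That pattern can be illustrated without TensorFlow in the following pure-Python sketch. MeanIoUSketch and ArgmaxMeanIoUSketch are hypothetical stand-ins, not Keras classes, and inputs are flat lists rather than tensors. The base class accumulates a confusion matrix across batches and computes per-class IoU = TP / (TP + FP + FN), averaging over classes whose denominator is nonzero, which is, as far as I know, how tf.keras.metrics.MeanIoU computes its result.

```python
class MeanIoUSketch:
    """Streaming mean IoU over integer labels (hypothetical stand-in for MeanIoU)."""

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.reset_states()

    def reset_states(self):
        # total_cm[t][p] counts pixels of true class t predicted as class p.
        self.total_cm = [[0] * self.num_classes for _ in range(self.num_classes)]

    def update_state(self, y_true, y_pred):
        # Both arguments are flat lists of integer labels of equal length.
        for t, p in zip(y_true, y_pred):
            self.total_cm[t][p] += 1

    def result(self):
        ious = []
        for c in range(self.num_classes):
            tp = self.total_cm[c][c]
            fn = sum(self.total_cm[c]) - tp                  # true c, predicted other
            fp = sum(row[c] for row in self.total_cm) - tp   # predicted c, true other
            denom = tp + fp + fn
            if denom:  # skip classes absent from both labels and predictions
                ious.append(tp / denom)
        return sum(ious) / len(ious) if ious else 0.0


class ArgmaxMeanIoUSketch(MeanIoUSketch):
    """Same idea as UpdatedMeanIoU: turn scores into labels, then delegate."""

    def update_state(self, y_true, y_pred_scores):
        # y_pred_scores holds one list of per-class scores per pixel.
        y_pred = [max(range(len(s)), key=s.__getitem__) for s in y_pred_scores]
        return super().update_state(y_true, y_pred)


m = ArgmaxMeanIoUSketch(num_classes=3)
m.update_state([0, 1], [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])  # batch 1
m.update_state([2, 1], [[0.2, 0.2, 0.6], [0.5, 0.3, 0.2]])  # batch 2
print(m.result())  # 0.6666666666666666 (per-class IoUs 0.5, 0.5, 1.0)
```

Because only update_state is overridden, the accumulated state and result() are inherited unchanged.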