tf.keras 指标中的 reset_states() 和 update_state() 的含义是什么?

Lee*_*evo 1 python metrics keras tensorflow tensorflow2.0

我正在检查非常简单的指标对象,tensorflow.keras例如BinaryAccuracyAUC。他们都有reset_states()自己的update_state()观点,但我发现他们的文档不充分且不清楚。

你能解释一下它们的意思吗?

Nic*_*ais 7

update_state测量指标(均值、auc、准确度),并将它们存储在对象中,以便稍后可以通过以下方式检索result

import tensorflow as tf

mean_object = tf.metrics.Mean()

values = [1, 2, 3, 4, 5]

for ix, val in enumerate(values):
    mean_object.update_state(val)
    print(mean_object.result().numpy(), 'is the mean of', values[:ix+1])
Run Code Online (Sandbox Code Playgroud)
1.0 is the mean of [1]
1.5 is the mean of [1, 2]
2.0 is the mean of [1, 2, 3]
2.5 is the mean of [1, 2, 3, 4]
3.0 is the mean of [1, 2, 3, 4, 5]
Run Code Online (Sandbox Code Playgroud)

reset_states将指标重置为零:

mean_object.reset_states()
mean_object.result().numpy()
Run Code Online (Sandbox Code Playgroud)
0.0
Run Code Online (Sandbox Code Playgroud)

我不确定我说得比文档更清楚,在我看来,它已经解释得很好了。

例如,调用该对象mean_object([1, 2, 3, 4])将更新指标,返回result.

import tensorflow as tf

mean_object = tf.metrics.Mean()

values = [1, 2, 3, 4, 5]

print(mean_object.result())
returned_mean = mean_object(values)
print(mean_object.result())
print(returned_mean)
Run Code Online (Sandbox Code Playgroud)
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(3.0, shape=(), dtype=float32)
tf.Tensor(3.0, shape=(), dtype=float32)
Run Code Online (Sandbox Code Playgroud)