如何正确使用tf.metrics.accuracy?

Tho*_*aud 25 tensorflow

我使用accuracy函数来tf.metrics解决多重分类问题,并将logits作为输入.

我的模型输出如下:

logits = [[0.1, 0.5, 0.4],
          [0.8, 0.1, 0.1],
          [0.6, 0.3, 0.2]]
Run Code Online (Sandbox Code Playgroud)

我的标签是一个热门的编码载体:

labels = [[0, 1, 0],
          [1, 0, 0],
          [0, 0, 1]]
Run Code Online (Sandbox Code Playgroud)

当我尝试做类似的tf.metrics.accuracy(labels, logits)事情从来没有给出正确的结果.我显然做错了什么,但我无法弄清楚它是什么.

vij*_*y m 61

TL; DR

准确度函数tf.metrics.accuracy根据它创建的两个局部变量计算预测与标签匹配的频率:total并且count,用于计算logits匹配的频率labels.

acc, acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1), 
                                  predictions=tf.argmax(logits,1))

print(sess.run([acc, acc_op]))
print(sess.run([acc]))
# Output
#[0.0, 0.66666669]
#[0.66666669]
Run Code Online (Sandbox Code Playgroud)
  • acc(准确性):简单地使用total和返回指标count,并不更新指标.
  • acc_op(更新):更新指标.

要了解acc返回的原因0.0,请查看以下详细信息.


细节使用一个简单的例子:

logits = tf.placeholder(tf.int64, [2,3])
labels = tf.Variable([[0, 1, 0], [1, 0, 1]])

acc, acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1),   
                                  predictions=tf.argmax(logits,1))
Run Code Online (Sandbox Code Playgroud)

初始化变量:

由于metrics.accuracy创建了两个局部变量totalcount,我们需要调用local_variables_initializer()初始化它们.

sess = tf.Session()

sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())

stream_vars = [i for i in tf.local_variables()]
print(stream_vars)

#[<tf.Variable 'accuracy/total:0' shape=() dtype=float32_ref>,
# <tf.Variable 'accuracy/count:0' shape=() dtype=float32_ref>]
Run Code Online (Sandbox Code Playgroud)

了解更新操作和准确度计算:

print('acc:',sess.run(acc, {logits:[[0,1,0],[1,0,1]]}))
#acc: 0.0

print('[total, count]:',sess.run(stream_vars)) 
#[total, count]: [0.0, 0.0]
Run Code Online (Sandbox Code Playgroud)

尽管给出了匹配的输入,但上面的精度为0.0,total并且count为零.

print('ops:', sess.run(acc_op, {logits:[[0,1,0],[1,0,1]]})) 
#ops: 1.0

print('[total, count]:',sess.run(stream_vars)) 
#[total, count]: [2.0, 2.0]
Run Code Online (Sandbox Code Playgroud)

使用新输入时,将在调用更新操作时计算精度.注意:由于所有的logits和标签都匹配,我们得到1.0的准确度和局部变量total,count实际给出total correctly predictedtotal comparisons made.

现在我们accuracy使用新输入(而不是更新操作)调用:

print('acc:', sess.run(acc,{logits:[[1,0,0],[0,1,0]]}))
#acc: 1.0
Run Code Online (Sandbox Code Playgroud)

准确性调用不会使用新输入更新度量标准,它只使用两个局部变量返回值.注意:在这种情况下,logits和标签不匹配.现在再次调用update ops:

print('op:',sess.run(acc_op,{logits:[[0,1,0],[0,1,0]]}))
#op: 0.75 
print('[total, count]:',sess.run(stream_vars)) 
#[total, count]: [3.0, 4.0]
Run Code Online (Sandbox Code Playgroud)

指标更新为新输入


有关如何在培训期间使用指标以及如何在验证期间重置指标的更多信息,请参见此处.

  • 你想在最后一个维度得到最大值,所以它应该是`tf.argmax(logits,1)`和`tf.argmax(labels,1)` (4认同)
  • 我发现此[link](http://ronny.rest/blog/post_2017_09_11_tf_metrics/)对理解** tf.metrics.accuracy()**的实际作用非常有帮助。 (2认同)

小智 5

在 TF 2.0 上,如果您使用 tf.keras API,您可以定义一个继承自 tf.keras.metrics.Accuracy 的自定义类 myAccuracy,并重写更新方法,如下所示:

# imports
# ...
class myAccuracy(tf.keras.metrics.Accuracy):
    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.argmax(y_true,1)
        y_pred = tf.argmax(y_pred,1)
        return super(myAccuracy,self).update_state(y_true,y_pred,sample_weight)
Run Code Online (Sandbox Code Playgroud)

然后,在编译模型时,您可以按照通常的方式添加指标。

from my_awesome_models import discriminador

discriminador.compile(tf.keras.optimizers.Adam(),
                      loss=tf.nn.softmax_cross_entropy_with_logits,
                      metrics=[myAccuracy()])

from my_puzzling_datasets import train_dataset,test_dataset

discriminador.fit(train_dataset.shuffle(70000).repeat().batch(1000), 
                  epochs=1,steps_per_epoch=1, 
                  validation_data=test_dataset.shuffle(70000).batch(1000), 
                  validation_steps=1)

# Train for 1 steps, validate for 1 steps
# 1/1 [==============================] - 3s 3s/step - loss: 0.1502 - accuracy: 0.9490 - val_loss: 0.1374 - val_accuracy: 0.9550
Run Code Online (Sandbox Code Playgroud)

或者在整个数据集上评估您的模型

discriminador.evaluate(test_dataset.batch(TST_DSET_LENGTH))
#> [0.131587415933609, 0.95354694]
Run Code Online (Sandbox Code Playgroud)