如何使用 tf.metrics 计算多标签分类的准确性?

tid*_*idy 5 evaluation tensorflow tensorflow-estimator

我想用张量流(tf.estimator.Estimator)训练多标签分类模型。我需要在评估时输出准确性。但它似乎不适用于以下代码:

accuracy = tf.metrics.accuracy(labels=labels, predictions=preds)
metrics = {'accuracy': accuracy}

if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)
Run Code Online (Sandbox Code Playgroud)

tf.metrics.accuracy不适用于多重结果。那么什么是多标签指标呢?

Ami*_*mir 3

实际上tf.metrics.accuracy也计算了多标签分类的准确性。请参阅下面的示例:

import tensorflow as tf

labels = tf.constant([[1, 0, 0, 1],
                      [0, 1, 1, 1],
                      [1, 1, 0, 0],
                      [0, 0, 0, 1],
                      [1, 1, 0, 0]])

preds = tf.constant([[1, 0, 1, 1],
                     [0, 1, 1, 1],
                     [1, 1, 0, 0],
                     [0, 0, 0, 1],
                     [1, 1, 0, 0]])

acc, acc_op = tf.metrics.accuracy(labels, preds)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(tf.global_variables_initializer())
    print(sess.run([acc, acc_op]))
    print(sess.run([acc]))
Run Code Online (Sandbox Code Playgroud)

正如您所看到的,我们总共有 20 个标签,只有第一行中的一个条目被错误标记,因此我们的准确度为 0.95%。