评估过程中实验者的张量流混淆矩阵

Sam*_*mbo 3 python neural-network confusion-matrix tensorflow

在使用Tensorflow和Experimenter API进行模型评估期间,我遇到了一些麻烦.

我以前使用2级NN工作,但这次我设法训练4级,我需要弄清楚如何在这种情况下构建混淆矩阵.我尝试使用该tf.confusion_matrix功能,但它根本不起作用.

这是我使用的代码片段:

if mode == ModeKeys.EVAL:

    eval_metric_ops = {
        'accuracy' : metrics.streaming_accuracy(predictions=predicted_classes, labels=labels),

        # Other metrics...

        'confusion_matrix': tf.confusion_matrix(prediction=predicted_classes, label=labels, num_classes=4)
    }

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predicted_classes,
        loss=loss,
        eval_metric_ops=eval_metric_ops
    )
Run Code Online (Sandbox Code Playgroud)

这是我得到的错误:

TypeError: Values of eval_metric_ops must be (metric_value, update_op) tuples, given: (<tf.Operation 'test/group_deps' type=NoOp>, <tf.Tensor 'test/accuracy/value:0' shape=() dtype=float32>, <tf.Variable 'test/confusion:0' shape=(4, 4) dtype=int32_ref>) for key: confusion_matrix
Run Code Online (Sandbox Code Playgroud)

我读了关于在Tensorflow中创建混淆矩阵的其他答案,我理解了如何做,但我认为我的问题与Estimator/Experimenter API更相关.

mon*_*chi 5

您的代码不起作用,因为框架期望eval_metric_ops是包含操作名称和类型元组值的键的字典(结果张量,此张量的update_operation)

tf.confusion_matrix(prediction = predict_classes,label = labels,num_classes = 4)仅返回预期的张量.

您必须实现自己的度量操作,如下所示:

def eval_confusion_matrix(labels, predictions):
    with tf.variable_scope("eval_confusion_matrix"):
        con_matrix = tf.confusion_matrix(labels=labels, predictions=predictions, num_classes=4)

        con_matrix_sum = tf.Variable(tf.zeros(shape=(4,4), dtype=tf.int32),
                                            trainable=False,
                                            name="confusion_matrix_result",
                                            collections=[tf.GraphKeys.LOCAL_VARIABLES])


        update_op = tf.assign_add(con_matrix_sum, con_matrix)

        return tf.convert_to_tensor(con_matrix_sum), update_op



# Add evaluation metrics (for EVAL mode)
eval_metric_ops = {
    "accuracy": tf.metrics.accuracy(labels, predicted_classes),
    "conv_matrix": eval_confusion_matrix(
        labels, predicted_classes)
    }
Run Code Online (Sandbox Code Playgroud)