如何使用估计器获得用于混淆矩阵的类分数?

Shu*_*ang 5 python tensorflow

我目前正在尝试训练谷歌的草图识别模型,只是链接中的一个:Github.但是我最近遇到的问题一直困扰着我很久.

问题如下:我使用了链接中的代码和来自quickdraw的数据来完成培训.我现在有一个训练有素的模型有三个文件(.meta,.index,.data),现在我想计算345个类别的训练模型的混淆矩阵.但由于我从未使用过张量流的"估算器",我不知道如何将训练好的模型文件加载到代码中并进行测试(没有训练),以及如何在softmax层之后获得分类分数(用于计算混淆矩阵).

"estimator"API让我困惑了很长时间.请在链接中的代码下解决我的问题:

def create_estimator_and_specs(run_config):
    """Creates an Experiment configuration based on the estimator and input fn."""
    model_params = tf.contrib.training.HParams(
        num_layers=FLAGS.num_layers,
        num_nodes=FLAGS.num_nodes,
        batch_size=FLAGS.batch_size,
        num_conv=ast.literal_eval(FLAGS.num_conv),
        conv_len=ast.literal_eval(FLAGS.conv_len),
        num_classes=get_num_classes(),
        learning_rate=FLAGS.learning_rate,
        gradient_clipping_norm=FLAGS.gradient_clipping_norm,
        cell_type=FLAGS.cell_type,
        batch_norm=FLAGS.batch_norm,
        dropout=FLAGS.dropout)
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        config=run_config,
        params=model_params)
    train_spec = tf.estimator.TrainSpec(
        input_fn=get_input_fn(
            mode=tf.estimator.ModeKeys.TRAIN,
            tfrecord_pattern=FLAGS.training_data,
            batch_size=FLAGS.batch_size),
        max_steps=FLAGS.steps)
    eval_spec = tf.estimator.EvalSpec(
        input_fn=get_input_fn(
            mode=tf.estimator.ModeKeys.EVAL,
            tfrecord_pattern=FLAGS.eval_data,
            batch_size=FLAGS.batch_size)
        )
    return estimator, train_spec, eval_spec

def main(unused_args):
    estimator, train_spec, eval_spec = create_estimator_and_specs(
        run_config=tf.estimator.RunConfig(
            model_dir=FLAGS.model_dir,
            save_checkpoints_secs=300,
            save_summary_steps=100)
        )
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
Run Code Online (Sandbox Code Playgroud)

我想将训练好的模型加载到上面的代码中,并计算345个类别的混淆矩阵.

Arj*_*ava 1

您可以使用库函数tf.confusion_matrix

tf.confusion_matrix(
    labels,
    predictions,
    num_classes=None,
    dtype=tf.int32,
    name=None,
    weights=None
)
Run Code Online (Sandbox Code Playgroud)

根据预测和标签计算混淆矩阵。

tf.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
      [[0 0 0 0 0]
       [0 0 1 0 0]
       [0 0 1 0 0]
       [0 0 0 0 0]
       [0 0 0 0 1]]
Run Code Online (Sandbox Code Playgroud)

针对您的情况,以下代码可能会帮助您:

tf.confusion_matrix(
    labels,
    predictions,
    num_classes=None,
    dtype=tf.int32,
    name=None,
    weights=None
)
Run Code Online (Sandbox Code Playgroud)