我目前正在尝试训练谷歌的草图识别模型,只是链接中的一个:Github.但是我最近遇到的问题一直困扰着我很久.
问题如下:我使用了链接中的代码和来自quickdraw的数据来完成培训.我现在有一个训练有素的模型有三个文件(.meta,.index,.data),现在我想计算345个类别的训练模型的混淆矩阵.但由于我从未使用过张量流的"估算器",我不知道如何将训练好的模型文件加载到代码中并进行测试(没有训练),以及如何在softmax层之后获得分类分数(用于计算混淆矩阵).
"estimator"API让我困惑了很长时间.请在链接中的代码下解决我的问题:
def create_estimator_and_specs(run_config):
"""Creates an Experiment configuration based on the estimator and input fn."""
model_params = tf.contrib.training.HParams(
num_layers=FLAGS.num_layers,
num_nodes=FLAGS.num_nodes,
batch_size=FLAGS.batch_size,
num_conv=ast.literal_eval(FLAGS.num_conv),
conv_len=ast.literal_eval(FLAGS.conv_len),
num_classes=get_num_classes(),
learning_rate=FLAGS.learning_rate,
gradient_clipping_norm=FLAGS.gradient_clipping_norm,
cell_type=FLAGS.cell_type,
batch_norm=FLAGS.batch_norm,
dropout=FLAGS.dropout)
estimator = tf.estimator.Estimator(
model_fn=model_fn,
config=run_config,
params=model_params)
train_spec = tf.estimator.TrainSpec(
input_fn=get_input_fn(
mode=tf.estimator.ModeKeys.TRAIN,
tfrecord_pattern=FLAGS.training_data,
batch_size=FLAGS.batch_size),
max_steps=FLAGS.steps)
eval_spec = tf.estimator.EvalSpec(
input_fn=get_input_fn(
mode=tf.estimator.ModeKeys.EVAL,
tfrecord_pattern=FLAGS.eval_data,
batch_size=FLAGS.batch_size)
)
return estimator, train_spec, eval_spec
def main(unused_args):
estimator, train_spec, eval_spec = create_estimator_and_specs(
run_config=tf.estimator.RunConfig(
model_dir=FLAGS.model_dir,
save_checkpoints_secs=300,
save_summary_steps=100)
)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
Run Code Online (Sandbox Code Playgroud)
我想将训练好的模型加载到上面的代码中,并计算345个类别的混淆矩阵.
您可以使用库函数tf.confusion_matrix
tf.confusion_matrix(
labels,
predictions,
num_classes=None,
dtype=tf.int32,
name=None,
weights=None
)
Run Code Online (Sandbox Code Playgroud)
根据预测和标签计算混淆矩阵。
tf.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
[[0 0 0 0 0]
[0 0 1 0 0]
[0 0 1 0 0]
[0 0 0 0 0]
[0 0 0 0 1]]
Run Code Online (Sandbox Code Playgroud)
针对您的情况,以下代码可能会帮助您:
tf.confusion_matrix(
labels,
predictions,
num_classes=None,
dtype=tf.int32,
name=None,
weights=None
)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
308 次 |
| 最近记录: |