Sam*_*amy 2 conv-neural-network tensorflow
我看过一些关于恢复模型和导出图表的文档页面的帖子,但我想我错过了一些东西.TF
Google
我使用此Gist中的代码来保存模型以及定义模型的此utils文件
现在我想恢复它并运行以前看不见的测试数据,如下所示:
def evaluate(X_data, y_data):
num_examples = len(X_data)
total_accuracy = 0
total_loss = 0
sess = tf.get_default_session()
acc_steps = len(X_data) // BATCH_SIZE
for i in range(acc_steps):
batch_x, batch_y = next_batch(X_val, Y_val, BATCH_SIZE)
loss, accuracy = sess.run([loss_value, acc], feed_dict={
images_placeholder: batch_x,
labels_placeholder: batch_y,
keep_prob: 0.5
})
total_accuracy += (accuracy * len(batch_x))
total_loss += (loss * len(batch_x))
return (total_accuracy / num_examples, total_loss / num_examples)
## re-execute the code that defines the model
# Image Tensor
images_placeholder = tf.placeholder(tf.float32, shape=[None, 32, 32, 3], name='x')
gray = tf.image.rgb_to_grayscale(images_placeholder, name='gray')
gray /= 255.
# Label Tensor
labels_placeholder = tf.placeholder(tf.float32, shape=(None, 43), name='y')
# dropout Tensor
keep_prob = tf.placeholder(tf.float32, name='drop')
# construct model
logits = inference(gray, keep_prob)
# calculate loss
loss_value = loss(logits, labels_placeholder)
# training
train_op = training(loss_value, 0.001)
# accuracy
acc = accuracy(logits, labels_placeholder)
with tf.Session() as sess:
loader = tf.train.import_meta_graph('gtsd.meta')
loader.restore(sess, tf.train.latest_checkpoint('./'))
sess.run(tf.initialize_all_variables())
test_accuracy = evaluate(X_test, y_test)
print("Test Accuracy = {:.3f}".format(test_accuracy[0]))
Run Code Online (Sandbox Code Playgroud)
我的测试精度只有3%.但是,如果我没有关闭笔记本并在训练模型后立即运行测试代码,我的准确率为95%.
这让我相信我没有正确加载模型?
这两个问题产生了问题:
loader.restore(sess, tf.train.latest_checkpoint('./'))
sess.run(tf.initialize_all_variables())
Run Code Online (Sandbox Code Playgroud)
第一行从检查点加载已保存的模型.第二行重新初始化模型中的所有变量(例如权重矩阵,卷积滤波器和偏置矢量),通常是随机数,并覆盖加载的值.
解决方案很简单:删除第二行(sess.run(tf.initialize_all_variables())
),评估将继续使用从检查点加载的训练值.
PS.这种变化很可能会给你一个关于"未初始化变量"的错误.在这种情况下,您应该执行sess.run(tf.initialize_all_variables())
以初始化在执行之前未保存在检查点中的任何变量loader.restore(sess, tf.train.latest_checkpoint('./'))
.
归档时间: |
|
查看次数: |
3936 次 |
最近记录: |