tim*_*xyz 6 python neural-network tensorflow
我正在按照教程进行操作,可以浏览代码,该代码可以训练神经网络并评估其准确性.
但我不知道如何在新的单个输入(字符串)上使用训练模型来预测其标签.
你能告诉我们如何做到这一点吗?
教程:
会话代码:
# Launch the graph
with tf.Session() as sess:
sess.run(init)
# Training cycle
for epoch in range(training_epochs):
avg_cost = 0.
total_batch = int(len(newsgroups_train.data)/batch_size)
# Loop over all batches
for i in range(total_batch):
batch_x,batch_y = get_batch(newsgroups_train,i,batch_size)
# Run optimization op (backprop) and cost op (to get loss value)
c,_ = sess.run([loss,optimizer], feed_dict={input_tensor: batch_x,output_tensor:batch_y})
# Compute average loss
avg_cost += c / total_batch
# Display logs per epoch step
if epoch % display_step == 0:
print("Epoch:", '%04d' % (epoch+1), "loss=", \
"{:.9f}".format(avg_cost))
print("Optimization Finished!")
# Test model
correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(output_tensor, 1))
# Calculate accuracy
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
total_test_data = len(newsgroups_test.target)
batch_x_test,batch_y_test = get_batch(newsgroups_test,0,total_test_data)
print("Accuracy:", accuracy.eval({input_tensor: batch_x_test, output_tensor: batch_y_test}))
Run Code Online (Sandbox Code Playgroud)
我有一些Python的经验,但基本上没有Tensorflow的经验.
首先我们需要将文本转换为数组:
def text_to_vector(text):
layer = np.zeros(total_words,dtype=float)
for word in text.split(' '):
layer[word2index[word.lower()]] += 1
return layer
# Convert text to vector so we can send it to our model
vector_txt = text_to_vector(text)
# Wrap vector like we do in get_batches()
input_array = np.array([vector_txt])
Run Code Online (Sandbox Code Playgroud)
我们可以保存和加载模型以供重用。我们首先创建一个 Saver 对象,然后保存会话(在模型训练之后):
saver = tf.train.Saver()
... train the model ...
save_path = saver.save(sess, "/tmp/model.ckpt")
Run Code Online (Sandbox Code Playgroud)
在示例模型中,模型架构中的最后一个“步骤”(即方法内完成的最后一件事multilayer_perceptron
)是:
'out': tf.Variable(tf.random_normal([n_classes]))
Run Code Online (Sandbox Code Playgroud)
因此,为了获得预测,我们获取该数组最大值的索引(预测类别):
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, "/tmp/model.ckpt")
print("Model restored.")
classification = sess.run(tf.argmax(prediction, 1), feed_dict={input_tensor: input_array})
print("Predicted category:", classification)
Run Code Online (Sandbox Code Playgroud)
您可以在此处查看完整代码:https ://github.com/dmesquita/understanding_tensorflow_nn
归档时间: |
|
查看次数: |
326 次 |
最近记录: |