tensorflow session.run() 的参数 - 你传递操作了吗?

arc*_*tom 5 python tensorflow

我正在关注tensorflow 的教程

我试图理解tf.session.run(). 我知道您必须在会话中的图形中运行操作。

train_step的,因为它封装了这个特殊的例子在网络的所有操作通过呢?我试图理解为什么我不需要将任何其他变量传递给会话,例如cross_entropy.

sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
Run Code Online (Sandbox Code Playgroud)

这是完整的代码:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 784])

W = tf.Variable(tf.zeros([784, 10]))

b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x, W) + b)

y_ = tf.placeholder(tf.float32, [None, 10])

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

sess = tf.InteractiveSession()

tf.global_variables_initializer().run()

for _ in range(10):
    batch_xs, batch_ys = mnist.train.next_batch(100)

    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
Run Code Online (Sandbox Code Playgroud)

Eka*_*ong 8

在 TensorFlow Session 中tf.Session,您希望运行(或执行)优化器操作(在本例中为train_step)。优化器最小化您的损失函数(在本例中为cross_entropy),该函数使用模型假设 进行评估或计算y

在级联方法中,cross_entropy损失函数最小化了计算时的误差y,因此它找到了Wx精确近似的权重的最佳值y

因此,tf.Sessionsess我们运行优化器时使用 TensorFlow Session 对象train_step,然后它会评估整个计算图。

sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
Run Code Online (Sandbox Code Playgroud)

因为级联方法最终会调用cross_entropywhich 使用占位符xand y,所以您必须使用feed_dict将数据传递给这些占位符。