Posts by use*_*012

How to map the output of an Iterator to the placeholders in the loss function when using the Dataset API in TensorFlow

The following code from the TensorFlow website uses the Dataset API to consume data from TFRecord files:

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)

iterator = dataset.make_one_shot_iterator()
next_example, next_label = iterator.get_next()

loss = model_function(next_example, next_label)
training_op = tf.train.AdagradOptimizer(...).minimize(loss)

with tf.train.MonitoredTrainingSession(...) as sess:
  while not sess.should_stop():
    sess.run(training_op)
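For reference, here is a minimal runnable sketch of the same pipeline, assuming TensorFlow 2.x, where `tf.contrib.data.TFRecordDataset` has moved to `tf.data.TFRecordDataset`. Since the real files and the `map(...)` function are not shown, it writes a small synthetic TFRecord file first; the feature names "x" and "y", the sizes, and the helpers `make_example` and `parse_example` are hypothetical.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

INPUT_SIZE = 4   # hypothetical feature width
OUTPUT_SIZE = 2  # hypothetical label width


def make_example(x, y):
    """Serialize one (x, y) pair as a tf.train.Example with float features."""
    feature = {
        "x": tf.train.Feature(float_list=tf.train.FloatList(value=x)),
        "y": tf.train.Feature(float_list=tf.train.FloatList(value=y)),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


def parse_example(serialized):
    """Parse one serialized Example back into fixed-size float32 tensors."""
    spec = {
        "x": tf.io.FixedLenFeature([INPUT_SIZE], tf.float32),
        "y": tf.io.FixedLenFeature([OUTPUT_SIZE], tf.float32),
    }
    parsed = tf.io.parse_single_example(serialized, spec)
    return parsed["x"], parsed["y"]


# Write 10 synthetic records to a temporary TFRecord file.
path = os.path.join(tempfile.mkdtemp(), "toy.tfrecord")
rng = np.random.default_rng(0)
with tf.io.TFRecordWriter(path) as writer:
    for _ in range(10):
        ex = make_example(rng.random(INPUT_SIZE).tolist(),
                          rng.random(OUTPUT_SIZE).tolist())
        writer.write(ex.SerializeToString())

# Same chain of transformations as the snippet above.
num_epochs = 2
dataset = tf.data.TFRecordDataset([path])
dataset = dataset.map(parse_example)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(4)
dataset = dataset.repeat(num_epochs)

# In eager mode the dataset is directly iterable; each item is (x_batch, y_batch).
batches = list(dataset)
```

Because `batch` comes before `repeat`, each epoch of 10 records yields batches of 4, 4 and 2 rows, so the last batch of every epoch is smaller.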

Normally I define my network like this:

x = tf.placeholder(tf.float32, [None, INPUT_SIZE], name='INPUT')
y_ = tf.placeholder(tf.float32, [None, OUTPUT_SIZE], name='OUTPUT')

w1 = tf.Variable(tf.truncated_normal([INPUT_SIZE, L1_SIZE], stddev=0.1))
b1 = tf.Variable(tf.constant(0.1, shape=[L1_SIZE]))
w2 = tf.Variable(tf.truncated_normal([L1_SIZE, L2_SIZE], stddev=0.1))
b2 = tf.Variable(tf.constant(0.1, shape=[L2_SIZE]))

w3 = tf.Variable(tf.truncated_normal([L2_SIZE, OUTPUT_SIZE], stddev=0.1))
b3 = tf.Variable(tf.constant(0.1, …
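One way to connect the two, sketched here under assumptions, is to keep `x` and `y_` but build them with `tf.compat.v1.placeholder_with_default`, using the iterator's `get_next()` tensors as defaults. This is a swapped-in technique, not something from the question: with no `feed_dict` the graph reads the next batch from the iterator, and feeding `x`/`y_` still overrides it. The sketch assumes TensorFlow 2.x running the TF 1.x-style graph through `tf.compat.v1`; the layer sizes, the two-layer network, the in-memory toy data and the mean-squared-error loss are hypothetical stand-ins.

```python
import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # graph mode, like the TF 1.x code above

INPUT_SIZE, L1_SIZE, OUTPUT_SIZE = 4, 8, 2  # hypothetical sizes

# Toy in-memory data standing in for the parsed TFRecords.
rng = np.random.default_rng(0)
features = rng.random((10, INPUT_SIZE)).astype(np.float32)
labels = rng.random((10, OUTPUT_SIZE)).astype(np.float32)

dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(5).repeat()
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
next_example, next_label = iterator.get_next()

# Defaults come from the iterator; a feed_dict can still replace them.
x = tf.compat.v1.placeholder_with_default(next_example, [None, INPUT_SIZE], name="INPUT")
y_ = tf.compat.v1.placeholder_with_default(next_label, [None, OUTPUT_SIZE], name="OUTPUT")

w1 = tf.Variable(tf.random.truncated_normal([INPUT_SIZE, L1_SIZE], stddev=0.1))
b1 = tf.Variable(tf.constant(0.1, shape=[L1_SIZE]))
w2 = tf.Variable(tf.random.truncated_normal([L1_SIZE, OUTPUT_SIZE], stddev=0.1))
b2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_SIZE]))

h1 = tf.nn.relu(tf.matmul(x, w1) + b1)
y = tf.matmul(h1, w2) + b2
loss = tf.reduce_mean(tf.square(y - y_))  # hypothetical MSE loss
train_op = tf.compat.v1.train.AdagradOptimizer(0.1).minimize(loss)

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    # No feed_dict: x and y_ pull each batch from the iterator.
    for _ in range(3):
        _, iter_loss = sess.run([train_op, loss])
    # Feeding overrides the iterator, e.g. to evaluate on specific rows.
    fed_loss = sess.run(loss, feed_dict={x: features[:2], y_: labels[:2]})
```

If you never need to feed values by hand, the placeholders can be dropped entirely and `next_example`/`next_label` passed straight into the first layer and the loss, which is what `model_function(next_example, next_label)` does in the snippet from the website.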

Tags: python, dataset, tensorflow
