为了评估测试数据,对数据集进行单次传递的最佳方法是什么?我想避免在python中编写数据加载脚本并使用feed_dict
.相反,我想使用所有漂亮的TF基础设施进行排队,批处理等.
在cifar 示例中,测试示例的数量是硬编码的,并且代码采取num_test_examples/batch_size
步骤以进行评估.使用批处理基础架构似乎应该有更好的方法来实现这一点.
似乎标准模式是在捕获队列抛出的异常时停止运行.我已经尝试了一些东西,这样当没有更多的例子来填充队列时队列会抱怨(即生产者不能再生产).这不是你想要捕获的例外.当消费者没有任何东西要消耗时,你想要捕获,即队列是空的.我该怎么做呢?
此外,如果测试示例的数量不能被批量大小整除(例如,测试示例的数量是素数),您会怎么做?
附加信息:
在实践中,我们通常通过调用do_evaluation()
函数在学习期间多次评估测试数据.如果您只想处理测试数据一次,Yaroslav的答案很有用.理想情况下,每次调用do_evaluation都会在测试数据集中的每个示例上运行一次.我们需要一些机制来重置批处理器,以便您可以再次单次通过它.这是一些代码.不要使用该limit_epochs
命令.它需要一个不会随机播放的批处理器并指定测试集中的批处理数(如果设置的示例数不能被minibatchsize整除,则这不起作用).该函数返回一个新的操作,用于抓取数据,tf.errors.OutOfRangeError
当你在整个集合上运行时会抛出数据.第二个返回值是应该调用以重置计数器的操作.这应该是do_evaluation()
函数内的第一个调用.
def single_pass(source_batcher,num_batches):
zero = tf.constant(0, dtype=tf.int64)
batch_count = tf.Variable(zero, name="epochs", trainable=False)
limiter = tf.count_up_to(batch_count,num_batches)
with tf.control_dependencies([limiter]):
batcher = tf.identity(source_batcher)
reset = tf.assign(batch_count, zero)
return batcher, reset
Run Code Online (Sandbox Code Playgroud)
您可以为此使用 tf.Data API。就像这样。
import tensorflow as tf
graph = tf.Graph()
sess = tf.Session(graph=graph)
def build_dataset(train_or_test):
if train_or_test == 'test':
dataset = tf.data.Dataset.from_tensor_slices(tf.zeros([4, 2]))
elif train_or_test == 'train':
dataset = tf.data.Dataset.from_tensor_slices(tf.ones([10, 2]))
else:
raise ValueError('wrong name')
batch_size = 3
dataset = dataset.batch(batch_size)
return dataset
def build_inputs():
train_dataset = build_dataset('train')
test_dataset = build_dataset('test')
iterator = tf.data.Iterator.from_structure(
train_dataset.output_types,
train_dataset.output_shapes,)
data = iterator.get_next()
tf.identity(data, name='data')
iterator.make_initializer(train_dataset, name='train_init')
iterator.make_initializer(test_dataset, name='test_init')
def model(inputs):
return tf.add(inputs, 1, name='output')
def build_graph():
with graph.as_default():
build_inputs()
data = graph.get_tensor_by_name('data:0')
model(data)
def train():
train_init = graph.get_operation_by_name('train_init')
sess.run(train_init)
out = graph.get_tensor_by_name('output:0')
while True:
try:
network_out = sess.run(out)
print(network_out.shape)
print(network_out)
except tf.errors.OutOfRangeError:
break
def test():
test_init = graph.get_operation_by_name('test_init')
sess.run(test_init)
out = graph.get_tensor_by_name('output:0')
while True:
try:
network_out = sess.run(out)
print(network_out.shape)
print(network_out)
except tf.errors.OutOfRangeError:
break
def train_loop():
cur_epoch = 0
while cur_epoch < 1:
print('Test epoch')
test()
print('Train epoch')
train()
cur_epoch += 1
def initialise_graph():
with graph.as_default():
sess.run(tf.global_variables_initializer())
build_graph()
initialise_graph()
train_loop()
Run Code Online (Sandbox Code Playgroud)
这将输出:
Test epoch
(3, 2)
[[1. 1.]
[1. 1.]
[1. 1.]]
(1, 2)
[[1. 1.]]
Train epoch
(3, 2)
[[2. 2.]
[2. 2.]
[2. 2.]]
(3, 2)
[[2. 2.]
[2. 2.]
[2. 2.]]
(3, 2)
[[2. 2.]
[2. 2.]
[2. 2.]]
(1, 2)
[[2. 2.]]
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
1782 次 |
最近记录: |