gol*_*enk 31 python tensorflow tfrecord
您如何一次性阅读TFRecords中的所有示例?
我一直在使用tf.parse_single_example类似于fully_connected_reader示例中的方法read_and_decode中给出的代码来读出单个示例.但是,我想立即针对我的整个验证数据集运行网络,因此希望完全加载它们.
我不完全确定,但文档似乎建议我可以使用tf.parse_example而不是tf.parse_single_example一次加载整个TFRecords文件.我似乎无法让这个工作.我猜它与我如何指定功能有关,但我不确定在功能规范中如何说明有多个例子.
换句话说,我尝试使用类似的东西:
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_example(serialized_example, features={
'image_raw': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
})
Run Code Online (Sandbox Code Playgroud)
不起作用,我认为这是因为这些功能不会同时出现多个例子(但同样,我不确定).[导致错误ValueError: Shape () must have rank 1]
这是一次读取所有记录的正确方法吗?如果是这样,我需要更改什么来实际读取记录?非常感谢!
And*_*rno 20
为了清楚起见,我在一个.tfrecords文件中有几千张图像,它们是720 x 720 rgb png文件.标签是0,1,2,3之一.
我也尝试使用parse_example并且无法使其工作,但此解决方案适用于parse_single_example.
缺点是现在我必须知道每个.tf记录中有多少项,这有点令人失望.如果我找到更好的方法,我会更新答案.另外,小心超出.tfrecords文件中记录数量的界限,如果循环遍历最后一条记录,它将从第一条记录开始
诀窍是让队列运行器使用协调器.
我在这里留下了一些代码,以便在读取图像时保存图像,以便您可以验证图像是否正确.
from PIL import Image
import numpy as np
import tensorflow as tf
def read_and_decode(filename_queue):
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
# Defaults are not specified since both keys are required.
features={
'image_raw': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'depth': tf.FixedLenFeature([], tf.int64)
})
image = tf.decode_raw(features['image_raw'], tf.uint8)
label = tf.cast(features['label'], tf.int32)
height = tf.cast(features['height'], tf.int32)
width = tf.cast(features['width'], tf.int32)
depth = tf.cast(features['depth'], tf.int32)
return image, label, height, width, depth
def get_all_records(FILE):
with tf.Session() as sess:
filename_queue = tf.train.string_input_producer([ FILE ])
image, label, height, width, depth = read_and_decode(filename_queue)
image = tf.reshape(image, tf.pack([height, width, 3]))
image.set_shape([720,720,3])
init_op = tf.initialize_all_variables()
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(2053):
example, l = sess.run([image, label])
img = Image.fromarray(example, 'RGB')
img.save( "output/" + str(i) + '-train.png')
print (example,l)
coord.request_stop()
coord.join(threads)
get_all_records('/path/to/train-0.tfrecords')
Run Code Online (Sandbox Code Playgroud)
syg*_*ygi 12
要只读取一次所有数据,您需要传递num_epochs给string_input_producer.当读取所有记录时,阅读.read器的方法将抛出错误,您可以捕获.简化示例:
import tensorflow as tf
def read_and_decode(filename_queue):
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'image_raw': tf.FixedLenFeature([], tf.string)
})
image = tf.decode_raw(features['image_raw'], tf.uint8)
return image
def get_all_records(FILE):
with tf.Session() as sess:
filename_queue = tf.train.string_input_producer([FILE], num_epochs=1)
image = read_and_decode(filename_queue)
init_op = tf.initialize_all_variables()
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
while True:
example = sess.run([image])
except tf.errors.OutOfRangeError, e:
coord.request_stop(e)
finally:
coord.request_stop()
coord.join(threads)
get_all_records('/path/to/train-0.tfrecords')
Run Code Online (Sandbox Code Playgroud)
要使用tf.parse_example(比这更快tf.parse_single_example),您需要首先批处理示例:
batch = tf.train.batch([serialized_example], num_examples, capacity=num_examples)
parsed_examples = tf.parse_example(batch, feature_spec)
Run Code Online (Sandbox Code Playgroud)
不幸的是,这样你需要事先了解一些例子.
Sal*_*ali 12
如果您需要立即从TFRecord读取所有数据,您可以使用tf_record_iterator在几行代码中编写更简单的解决方案:
从TFRecords文件读取记录的迭代器.
要做到这一点,你只需:
这是一个解释如何阅读每种类型的示例.
example = tf.train.Example()
for record in tf.python_io.tf_record_iterator(<tfrecord_file>):
example.ParseFromString(record)
f = example.features.feature
v1 = f['int64 feature'].int64_list.value[0]
v2 = f['float feature'].float_list.value[0]
v3 = f['bytes feature'].bytes_list.value[0]
# for bytes you might want to represent them in a different way (based on what they were before saving)
# something like `np.fromstring(f['img'].bytes_list.value[0], dtype=np.uint8
# Now do something with your v1/v2/v3
Run Code Online (Sandbox Code Playgroud)
小智 9
您还可以使用tf.python_io.tf_record_iterator手动迭代a中的所有示例TFRecord.
我用下面的插图代码测试它:
import tensorflow as tf
X = [[1, 2],
[3, 4],
[5, 6]]
def _int_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def dump_tfrecord(data, out_file):
writer = tf.python_io.TFRecordWriter(out_file)
for x in data:
example = tf.train.Example(
features=tf.train.Features(feature={
'x': _int_feature(x)
})
)
writer.write(example.SerializeToString())
writer.close()
def load_tfrecord(file_name):
features = {'x': tf.FixedLenFeature([2], tf.int64)}
data = []
for s_example in tf.python_io.tf_record_iterator(file_name):
example = tf.parse_single_example(s_example, features=features)
data.append(tf.expand_dims(example['x'], 0))
return tf.concat(0, data)
if __name__ == "__main__":
dump_tfrecord(X, 'test_tfrecord')
print('dump ok')
data = load_tfrecord('test_tfrecord')
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
Y = sess.run([data])
print(Y)
Run Code Online (Sandbox Code Playgroud)
当然,你必须使用自己的feature规范.
缺点是我不知道如何以这种方式使用多线程.但是,我们阅读所有示例的最多时机是我们评估验证数据集时,通常不是很大.所以我认为效率可能不是瓶颈.
当我测试这个问题时,我还有另一个问题,那就是我必须指定特征长度.而不是tf.FixedLenFeature([], tf.int64),我必须写tf.FixedLenFeature([2], tf.int64),否则,InvalidArgumentError发生.我不知道如何避免这种情况.
Python:3.4
Tensorflow:0.12.0
| 归档时间: |
|
| 查看次数: |
31712 次 |
| 最近记录: |