将.tfrecords文件拆分为多个.tfrecords文件

chr*_*stk 3 python tensorflow tfrecord tensorflow-datasets

有什么方法可以直接将.tfrecords文件拆分为多个.tfrecords文件,而无需回写每个Dataset示例?

小智 13

在 tensorflow 2.0.0 中,这将起作用:

import tensorflow as tf

raw_dataset = tf.data.TFRecordDataset("input_file.tfrecord")

shards = 10

for i in range(shards):
    writer = tf.data.experimental.TFRecordWriter(f"output_file-part-{i}.tfrecord")
    writer.write(raw_dataset.shard(shards, i))
Run Code Online (Sandbox Code Playgroud)


jde*_*esa 6

您可以使用如下功能:

import tensorflow as tf

def split_tfrecord(tfrecord_path, split_size):
    with tf.Graph().as_default(), tf.Session() as sess:
        ds = tf.data.TFRecordDataset(tfrecord_path).batch(split_size)
        batch = ds.make_one_shot_iterator().get_next()
        part_num = 0
        while True:
            try:
                records = sess.run(batch)
                part_path = tfrecord_path + '.{:03d}'.format(part_num)
                with tf.python_io.TFRecordWriter(part_path) as writer:
                    for record in records:
                        writer.write(record)
                part_num += 1
            except tf.errors.OutOfRangeError: break
Run Code Online (Sandbox Code Playgroud)

例如,要将文件my_records.tfrecord分成100条记录的一部分,您可以执行以下操作:

split_tfrecord(my_records.tfrecord, 100)
Run Code Online (Sandbox Code Playgroud)

这将创建多个较小的记录文件my_records.tfrecord.000my_records.tfrecord.001等等。