如何使用 Dataset API 读取包含变长列表的 TFRecords 文件?

Lio*_*Lai 7 python tensorflow tfrecord

我想使用 TensorFlow 的 Dataset API 来读取包含变长列表的 TFRecords 文件.以下是我的代码.

def _int64_feature(value):
    """Wrap a collection of integers in a tf.train.Feature (int64_list).

    `value` may be any array-like of integers: a Python list or tuple, or a
    NumPy array of any rank. A scalar becomes a one-element list. It is
    flattened to 1-D first because an Int64List stores only a flat list of
    values.
    """
    return tf.train.Feature(
        int64_list=tf.train.Int64List(value=np.asarray(value).ravel()))
def main1():
    """Write variable-length int64 lists to a TFRecord file and read them back.

    Each row of `a` is written as one tf.train.Example with an 'i' feature
    (the row index) and a 'data' feature (the row's values). Every record is
    then printed via the low-level record iterator. Finally, the first record
    is read back through the tf.data Dataset API.
    """
    # Rows of different lengths. Keep this as a plain list of lists: building
    # a NumPy array from ragged rows gives an object array on old NumPy and
    # raises ValueError on NumPy >= 1.24.
    a = [[0, 54, 91, 153, 177],
         [0, 50, 89, 147, 196],
         [0, 38, 79, 157],
         [0, 49, 89, 147, 177],
         [0, 32, 73, 145]]

    # The with-block closes the writer even if serialization or a write fails.
    with tf.python_io.TFRecordWriter('file') as writer:
        for i, row in enumerate(a):  # i = 0 ~ 4
            feature = {'i': _int64_feature(np.array([i])),
                       'data': _int64_feature(np.array(row, dtype=np.int64))}
            # Create an example protocol buffer and write it serialized.
            example = tf.train.Example(
                features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())

    # Check the TFRecord file by decoding every record directly.
    for string_record in tf.python_io.tf_record_iterator(path='file'):
        example = tf.train.Example()
        example.ParseFromString(string_record)
        i = example.features.feature['i'].int64_list.value
        data = example.features.feature['data'].int64_list.value
        print(i, data)

    # Use the Dataset API to read the TFRecord file.
    def _parse_function(example_proto):
        # 'i' always holds exactly one value, so a scalar FixedLenFeature works.
        # 'data' has a different length in every record, so it must be a
        # VarLenFeature: FixedLenFeature([]) fails with "Number of int64
        # values != expected". VarLenFeature yields a SparseTensor, which the
        # function passed to Dataset.map cannot return (TF 1.4), so densify it.
        keys_to_features = {'i': tf.FixedLenFeature([], tf.int64),
                            'data': tf.VarLenFeature(tf.int64)}
        parsed_features = tf.parse_single_example(example_proto,
                                                  keys_to_features)
        return (parsed_features['i'],
                tf.sparse_tensor_to_dense(parsed_features['data']))

    ds = tf.data.TFRecordDataset('file')
    iterator = ds.map(_parse_function).make_one_shot_iterator()
    i, data = iterator.get_next()
    with tf.Session() as sess:
        # Fetch both tensors in ONE run call. Evaluating them separately
        # (i.eval(); data.eval()) advances the iterator twice, which pairs the
        # index of one record with the data of the next.
        i_value, data_value = sess.run([i, data])
        print(i_value)
        print(data_value)
Run Code Online (Sandbox Code Playgroud)

检查TFRecord文件

[0] [0, 54, 91, 153, 177]
[1] [0, 50, 89, 147, 196]
[2] [0, 38, 79, 157]
[3] [0, 49, 89, 147, 177]
[4] [0, 32, 73, 145]
Run Code Online (Sandbox Code Playgroud)

但是当我尝试使用Dataset API读取TFRecord文件时,它显示以下错误.

tensorflow.python.framework.errors_impl.InvalidArgumentError: Name: , Key: data, Index: 0. Number of int64 values != expected. Values size: 5 but output shape: []

谢谢.
更新: 我尝试用下面两段代码通过 Dataset API 读取该 TFRecord 文件,但两段都失败了.

def _parse_function(example_proto):
    """Parse one record into a scalar 'i' and a variable-length 'data'.

    NOTE(review): 'data' is returned exactly as VarLenFeature produces it, i.e.
    as a SparseTensor. The post reports that this attempt fails. The likely
    cause is that TF 1.4's Dataset.map cannot turn a SparseTensor return value
    into a Tensor, the same TypeError reported for the attempt that declares
    'i' as VarLenFeature.
    """
    spec = {
        'i': tf.FixedLenFeature([], tf.int64),
        'data': tf.VarLenFeature(tf.int64),
    }
    features = tf.parse_single_example(example_proto, spec)
    index, values = features['i'], features['data']
    return index, values


ds = tf.data.TFRecordDataset('file')
parsed_ds = ds.map(_parse_function)
iterator = parsed_ds.make_one_shot_iterator()
i, data = iterator.get_next()
with tf.Session() as sess:
    fetched = sess.run([i, data])
    print(fetched)
Run Code Online (Sandbox Code Playgroud)

或者

def _parse_function(example_proto):
    """Parse one record, declaring both 'i' and 'data' as VarLenFeature.

    NOTE: both returned values are SparseTensors. On TF 1.4, Dataset.map cannot
    convert them to Tensors, so ds.map() below raises
    "TypeError: Failed to convert object of type ... SparseTensor ... to
    Tensor". Densifying with tf.sparse_tensor_to_dense avoids this.
    """
    spec = {name: tf.VarLenFeature(tf.int64) for name in ('i', 'data')}
    features = tf.parse_single_example(example_proto, spec)
    return tuple(features[name] for name in ('i', 'data'))


ds = tf.data.TFRecordDataset('file')
mapped = ds.map(_parse_function)  # the TypeError is raised here on TF 1.4
iterator = mapped.make_one_shot_iterator()
i, data = iterator.get_next()
with tf.Session() as sess:
    result = sess.run([i, data])
    print(result)
Run Code Online (Sandbox Code Playgroud)

错误信息如下:

Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/tensor_util.py", line 468, in make_tensor_proto
    str_values = [compat.as_bytes(x) for x in proto_values]
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/tensor_util.py", line 468, in <listcomp>
    str_values = [compat.as_bytes(x) for x in proto_values]
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/util/compat.py", line 65, in as_bytes
    (bytes_or_text,))
TypeError: Expected binary or unicode string, got <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x...>

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "2tfrecord.py", line 126, in <module>
    main1()
  File "2tfrecord.py", line 72, in main1
    iterator = ds.map(_parse_function).make_one_shot_iterator()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 712, in map
    return MapDataset(self, map_func)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1385, in __init__
    self._map_func.add_to_graph(ops.get_default_graph())
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 486, in add_to_graph
    self._create_definition_if_needed()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 321, in _create_definition_if_needed
    self._create_definition_if_needed_impl()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 338, in _create_definition_if_needed_impl
    outputs = self._func(*inputs)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1376, in tf_map_func
    flattened_ret = [ops.convert_to_tensor(t) for t in nest.flatten(ret)]
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1376, in <listcomp>
    flattened_ret = [ops.convert_to_tensor(t) for t in nest.flatten(ret)]
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 836, in convert_to_tensor
    as_ref=False)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 926, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/constant_op.py", line 229, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/constant_op.py", line 208, in constant
    value, dtype=dtype, shape=shape, verify_shape=verify_shape))
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/tensor_util.py", line 472, in make_tensor_proto
    "supported type." % (type(values), values))
TypeError: Failed to convert object of type <class 'tensorflow.python.framework.sparse_tensor.SparseTensor'> to Tensor. Contents: SparseTensor(indices=Tensor("ParseSingleExample/Slice_Indices_i:0", shape=(?, 1), dtype=int64), values=Tensor("ParseSingleExample/ParseExample/ParseExample:3", shape=(?,), dtype=int64), dense_shape=Tensor("ParseSingleExample/Squeeze_Shape_i:0", shape=(1,), dtype=int64)). Consider casting elements to a supported type.

Python版本:3.5.2
Tensorflow版本:1.4.1

Lio*_*Lai 12

经过几个小时的搜索和尝试,我想我找到答案了.以下是我的代码.

def _int64_feature(value):
    """Wrap a collection of integers in a tf.train.Feature (int64_list).

    `value` may be a NumPy array of any rank, or a plain list or tuple of
    integers. The previous `value.flatten()` raised AttributeError for lists.
    The input is flattened to 1-D because an Int64List stores only a flat
    list of values.
    """
    return tf.train.Feature(
        int64_list=tf.train.Int64List(value=np.asarray(value).ravel()))

# Write variable-length int64 lists to a TFRecord file.
# `a` is a plain list of lists rather than np.array(...): building a NumPy
# array from ragged rows gives an object array on old NumPy and raises
# ValueError on NumPy >= 1.24.
a = [[0, 54, 91, 153, 177],
     [0, 50, 89, 147, 196],
     [0, 38, 79, 157],
     [0, 49, 89, 147, 177],
     [0, 32, 73, 145]]

# The with-block closes the writer even if serialization or a write fails.
with tf.python_io.TFRecordWriter('file') as writer:
    for i, row in enumerate(a):  # i = 0 ~ 4
        x_train = np.array(row)
        feature = {'i': _int64_feature(np.array([i])),
                   'data': _int64_feature(x_train)}

        # Create an example protocol buffer
        example = tf.train.Example(features=tf.train.Features(feature=feature))

        # Serialize to string and write on the file
        writer.write(example.SerializeToString())

# Check the TFRecord file by decoding every record directly.
for string_record in tf.python_io.tf_record_iterator(path='file'):
    example = tf.train.Example()
    example.ParseFromString(string_record)

    i = example.features.feature['i'].int64_list.value
    data = example.features.feature['data'].int64_list.value
    print(i, data)

# Read the TFRecord file back through the tf.data Dataset API.
filenames = ['file']
dataset = tf.data.TFRecordDataset(filenames=filenames)
def _parse_function(example_proto):
    """Parse one serialized tf.train.Example into dense 'i' and 'data' tensors.

    Both features are declared as VarLenFeature, which yields SparseTensors.
    The function handed to Dataset.map cannot return SparseTensors in TF 1.4,
    so each one is densified before being returned.
    """
    spec = {name: tf.VarLenFeature(tf.int64) for name in ('i', 'data')}
    parsed = tf.parse_single_example(example_proto, spec)
    dense_i = tf.sparse_tensor_to_dense(parsed['i'])
    dense_data = tf.sparse_tensor_to_dense(parsed['data'])
    return dense_i, dense_data
# Parse each serialized record into dense (i, data) tensors.
dataset = dataset.map(_parse_function)
# Shuffle the dataset. shuffle() samples from a buffer of `buffer_size`
# elements, so buffer_size=1 would leave the order unchanged. A buffer that
# holds every record gives a full shuffle.
dataset = dataset.shuffle(buffer_size=len(a))
# Repeat the input indefinitely.
dataset = dataset.repeat()
# Generate batches. 'data' has a different length in every record, so plain
# batch() only works for a batch size of 1. padded_batch() zero-pads each
# component to the longest one in the batch, so BATCH_SIZE can be raised.
BATCH_SIZE = 1
dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=([None], [None]))
# Create a one-shot iterator.
iterator = dataset.make_one_shot_iterator()
i, data = iterator.get_next()
with tf.Session() as sess:
    # Fetch i and data in one run call so both come from the same batch.
    for _ in range(3):
        print(sess.run([i, data]))
Run Code Online (Sandbox Code Playgroud)

有几点需要注意.
1. 这个 SO 问题对我帮助很大.
2. tf.VarLenFeature 返回的是 SparseTensor,而(在 TF 1.4 中)传给 Dataset.map 的函数不能直接返回 SparseTensor,因此必须先用 tf.sparse_tensor_to_dense 把它转换为密集张量.
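3. 在我的代码中,parse_single_example() 不能换成 parse_example(),这个问题困扰了我一整天.我不知道为什么 parse_example() 不起作用,如果有人知道原因,请指教.(注:parse_example() 期望的输入是一批序列化的 Example,即一维字符串向量;而在 batch 之前调用 map 时,每个元素都是单个标量字符串.若想使用 parse_example(),应先 batch 再 map.)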