How to convert a Float array/list to TFRecord?

Thi*_*ien 3 python tensorflow

This is the code I use to convert the data to TFRecord:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.python_io, tf.parse_single_example)


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _floats_feature(value):
    # FloatList takes a flat (1-D) sequence of floats
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


with tf.python_io.TFRecordWriter("train.tfrecords") as writer:
    for row in train_data:  # each row: (prices, label, pip)
        prices, label, pip = row[0], row[1], row[2]
        # Flatten e.g. a (1, 288) array into 288 float32 values for FloatList
        prices = np.asarray(prices, dtype=np.float32).reshape(-1)
        example = tf.train.Example(features=tf.train.Features(feature={
            'prices': _floats_feature(prices),
            'label': _int64_feature(label[0]),
            'pip': _floats_feature(pip),
        }))
        writer.write(example.SerializeToString())
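Because `FloatList` is flat, the stored record keeps only the 288 price values and not their (1, 288) shape, which is why the reader has to declare the length (or allow a variable one) when parsing. A minimal sketch of that, assuming TensorFlow 2.x; the helper names `floats_feature`, `int64_feature` and `make_example` are hypothetical:

```python
import numpy as np
import tensorflow as tf


def floats_feature(values):
    # FloatList holds a flat list of floats, so flatten e.g. a (1, 288) array first.
    flat = np.asarray(values, dtype=np.float32).reshape(-1).tolist()
    return tf.train.Feature(float_list=tf.train.FloatList(value=flat))


def int64_feature(value):
    # Int64List also takes a list, so wrap the single label.
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))


def make_example(prices, label):
    # Hypothetical helper: one Example with a 'prices' float list and an int 'label'.
    return tf.train.Example(features=tf.train.Features(feature={
        "prices": floats_feature(prices),
        "label": int64_feature(label),
    }))


example = make_example(np.arange(288, dtype=np.float32).reshape(1, 288), 1)
# Only the 288 values are stored; the (1, 288) shape is not recorded anywhere.
num_prices = len(example.features.feature["prices"].float_list.value)
```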

The prices feature is an array of shape (1, 288). The conversion succeeds, but when I decode the data with this parse function and the Dataset API:

import tensorflow as tf


def parse_func(serialized_data):
    # shape [] declares a scalar: exactly one value per example for each key
    keys_to_features = {'prices': tf.FixedLenFeature([], tf.float32),
                        'label': tf.FixedLenFeature([], tf.int64)}

    parsed_features = tf.parse_single_example(serialized_data, keys_to_features)
    return parsed_features['prices'], tf.one_hot(parsed_features['label'], 2)

I get this error:

C:\tf_jenkins\workspace\rel-win\M\windows-gpu\PY\36\tensorflow\core\framework\op_kernel.cc:1202] OP_REQUIRES failed at example_parsing_ops.cc:240 : Invalid argument: Key: prices. Can't parse serialized Example.
2018-03-31 15:37:11.443073: W C:\tf_jenkins\workspace\rel-win\M\windows-gpu\PY\36\tensorflow\core\framework\op_kernel.cc:1202] OP_REQUIRES failed at example_parsing_ops.cc:240 : Invalid argument: Key: prices. Can't parse serialized Example.
2018-03-31 15:37:11.443313: W C:\tf_jenkins\workspace\rel-win\M\windows-gpu\PY\36\tensorflow\core\framework\op_kernel.cc:1202] OP_REQUIRES failed at example_parsing_ops.cc:240 : Invalid argument: Key: prices. Can't parse serialized Example.
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Key: prices. Can't parse serialized Example.
    [[Node: ParseSingleExample/ParseSingleExample = ParseSingleExample[Tdense=[DT_INT64, DT_FLOAT], dense_keys=["label", "prices"], dense_shapes=[[], []], num_sparse=0, sparse_keys=[], sparse_types=[]](arg0, ParseSingleExample/Const, ParseSingleExample/Const_1)]]
    [[Node: IteratorGetNext_1 = IteratorGetNext[output_shapes=[[?], [?,2]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"]]]
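The `[]` shape in `tf.FixedLenFeature([], tf.float32)` declares a scalar, i.e. exactly one float per example, while each record holds 288 floats; that mismatch is what "Key: prices. Can't parse serialized Example." reports. A minimal reproduction, assuming TensorFlow 2.x in eager mode, where the same classes are spelled `tf.io.FixedLenFeature` and `tf.io.parse_single_example`:

```python
import numpy as np
import tensorflow as tf

# Serialize one Example whose 'prices' feature holds 288 floats.
prices = np.zeros(288, dtype=np.float32).tolist()
serialized = tf.train.Example(features=tf.train.Features(feature={
    "prices": tf.train.Feature(float_list=tf.train.FloatList(value=prices)),
})).SerializeToString()

# shape [] asks for exactly one float per example, so 288 values cannot be parsed.
scalar_spec = {"prices": tf.io.FixedLenFeature([], tf.float32)}
try:
    tf.io.parse_single_example(serialized, scalar_spec)
    failed = False
except tf.errors.InvalidArgumentError as err:
    failed = True
    print(err.message)
```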

Thi*_*ien 10

I found the problem. Instead of using tf.FixedLenFeature to parse the array, use tf.FixedLenSequenceFeature.
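A minimal end-to-end sketch of this fix with the Dataset API, assuming TensorFlow 2.x (`tf.io.*` names; in TF 1.x the same classes are `tf.FixedLenSequenceFeature` and `tf.parse_single_example`). The temporary file path and the two dummy records are made up for illustration. Note that `parse_single_example` only accepts a `FixedLenSequenceFeature` built with `allow_missing=True`; without it, TensorFlow rejects the spec with a `ValueError`.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Write two dummy records, each with 288 prices and an int label.
path = os.path.join(tempfile.mkdtemp(), "train.tfrecords")
with tf.io.TFRecordWriter(path) as writer:
    for label in (0, 1):
        prices = np.full(288, float(label), dtype=np.float32).tolist()
        example = tf.train.Example(features=tf.train.Features(feature={
            "prices": tf.train.Feature(float_list=tf.train.FloatList(value=prices)),
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        }))
        writer.write(example.SerializeToString())


def parse_func(serialized_data):
    keys_to_features = {
        # Variable-length list of floats; allow_missing=True is required here.
        "prices": tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized_data, keys_to_features)
    return parsed["prices"], tf.one_hot(parsed["label"], 2)


dataset = tf.data.TFRecordDataset(path).map(parse_func)
records = [(p.numpy(), y.numpy()) for p, y in dataset]
```

Each parsed `prices` tensor comes back as a 1-D tensor of length 288; the one-hot label has length 2.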

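  • @SajadNorouzi has an answer below that seems more correct. I managed to get both approaches working. However, I'm not sure the documentation is as clear as he says. It may have been edited in the meantime, but it only seems to _imply_ that FixedLenSequenceFeature should be used only for 2 or more dimensions. It might be worth editing this answer to mention the other answer, or to work out which of the two approaches is actually correct, or whether both are. Cheers! (2 upvotes)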
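One way to weigh the two kinds of spec: if every record has the same, known number of prices, `FixedLenFeature` can be given that length instead of `[]`. A minimal sketch, assuming TensorFlow 2.x and the fixed length of 288 taken from the question (`NUM_PRICES` is a hypothetical constant). Both specs decode the same values; the difference is that only the fixed-length spec gives the pipeline a static shape.

```python
import numpy as np
import tensorflow as tf

NUM_PRICES = 288  # assumed fixed length of every 'prices' list

prices = np.linspace(0.0, 1.0, NUM_PRICES, dtype=np.float32).tolist()
serialized = tf.train.Example(features=tf.train.Features(feature={
    "prices": tf.train.Feature(float_list=tf.train.FloatList(value=prices)),
})).SerializeToString()

fixed_spec = {"prices": tf.io.FixedLenFeature([NUM_PRICES], tf.float32)}
seq_spec = {"prices": tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True)}

ds = tf.data.Dataset.from_tensors(serialized)
fixed_ds = ds.map(lambda s: tf.io.parse_single_example(s, fixed_spec)["prices"])
seq_ds = ds.map(lambda s: tf.io.parse_single_example(s, seq_spec)["prices"])

# Both specs decode the same 288 values...
fixed_values = next(iter(fixed_ds)).numpy()
seq_values = next(iter(seq_ds)).numpy()
# ...but only the fixed-length spec yields a static shape in the pipeline.
fixed_shape = fixed_ds.element_spec.shape.as_list()  # [288]
seq_shape = seq_ds.element_spec.shape.as_list()      # [None]
```

The fixed-length spec fails on any record with a different length, whereas the sequence spec accepts variable-length lists at the cost of an unknown static length.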