如何使用DataSet API在Tensorflow中为tf.train.SequenceExample数据创建填充批次?

Mar*_*gts 13 python lstm tensorflow tensorflow-datasets

为了在Tensorflow中训练LSTM模型,我将数据结构化为tf.train.SequenceExample格式并将其存储到TFRecord文件中.我现在想使用新的DataSet API来生成用于训练的填充批次.在文档中有一个使用padded_batch的例子,但对于我的数据,我无法弄清楚padded_shapes应该是什么值.

为了将TFrecord文件读入批处理,我编写了以下Python代码:

import math
import tensorflow as tf
import numpy as np
import struct
import sys
import array

if(len(sys.argv) != 2):
  print "Usage: createbatches.py [RFRecord file]"
  sys.exit(0)


vectorSize = 40
inFile = sys.argv[1]

def parse_function_dataset(example_proto):
  sequence_features = {
      'inputs': tf.FixedLenSequenceFeature(shape=[vectorSize],
                                           dtype=tf.float32),
      'labels': tf.FixedLenSequenceFeature(shape=[],
                                           dtype=tf.int64)}

  _, sequence = tf.parse_single_sequence_example(example_proto, sequence_features=sequence_features)

  length = tf.shape(sequence['inputs'])[0]
  return sequence['inputs'], sequence['labels']

sess = tf.InteractiveSession()

filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(parse_function_dataset)
# dataset = dataset.batch(1)
dataset = dataset.padded_batch(4, padded_shapes=[None])
iterator = dataset.make_initializable_iterator()

batch = iterator.get_next()

# Initialize `iterator` with training data.
training_filenames = [inFile]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})

print(sess.run(batch))
Run Code Online (Sandbox Code Playgroud)

如果我使用dataset = dataset.batch(1)(在这种情况下不需要填充),代码效果很好,但是当我使用padded_batch变量时,我收到以下错误:

TypeError:如果浅层结构是序列,则输入也必须是序列.输入有类型:.

你能帮我弄清楚我应该为padded_shapes参数传递什么吗?

(我知道有很多使用线程和队列的示例代码,但我宁愿在这个项目中使用新的DataSet API)

小智 9

你需要传递一个形状元组.在你的情况下你应该通过

dataset = dataset.padded_batch(4, padded_shapes=([vectorSize],[None]))
Run Code Online (Sandbox Code Playgroud)

或尝试

dataset = dataset.padded_batch(4, padded_shapes=([None],[None]))
Run Code Online (Sandbox Code Playgroud)

检查此代码以获取更多详细信息 我不得不调试这个方法来弄清楚它为什么不适合我.

  • 作为补充,`padded_shapes` 对嵌套结构的类型很敏感(如果数据集返回一个元组,padded_shapes 也应该是一个元组而不是一个列表) (3认同)