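To train an LSTM model in TensorFlow, I structured my data in the tf.train.SequenceExample format and stored it in a TFRecord file. I now want to use the new Dataset API to generate padded batches for training. The documentation has an example that uses padded_batch, but I can't figure out what value padded_shapes should take for my data.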
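For context: each parsed element here is a pair (inputs, labels), with inputs of shape [T, vectorSize] and labels of shape [T], where the sequence length T differs between examples. padded_shapes mirrors that structure, and a `None` dimension means "pad to the longest element in this batch", so a candidate value would presumably be `([None, vectorSize], [None])`. The plain-Python sketch below only illustrates that padding behaviour; `pad_batch` is a hypothetical helper, not TensorFlow's implementation, and the zero padding values are an assumption matching padded_batch's default.

```python
# Hypothetical helper: mimics in plain Python what padded_batch does with
# padded_shapes=([None, vector_size], [None]): the time dimension of every
# element is padded up to the longest sequence in the batch.
def pad_batch(elements, vector_size, input_pad=0.0, label_pad=0):
    """elements: list of (inputs, labels) pairs, where inputs is a list of
    T vectors of length vector_size and labels is a list of T ints."""
    max_len = max(len(inputs) for inputs, _ in elements)
    batch_inputs, batch_labels = [], []
    for inputs, labels in elements:
        pad = max_len - len(inputs)
        # Append `pad` all-padding vectors (fresh lists, no aliasing).
        batch_inputs.append(list(inputs) + [[input_pad] * vector_size for _ in range(pad)])
        # Labels are scalars per timestep, so pad with scalar label_pad.
        batch_labels.append(list(labels) + [label_pad] * pad)
    return batch_inputs, batch_labels

vectorSize = 40
a = ([[1.0] * vectorSize for _ in range(3)], [1, 2, 3])           # T = 3
b = ([[2.0] * vectorSize for _ in range(5)], [4, 5, 6, 7, 8])     # T = 5
x, y = pad_batch([a, b], vectorSize)
print(y)  # [[1, 2, 3, 0, 0], [4, 5, 6, 7, 8]]
```

Both components end up rectangular ([2, 5, 40] and [2, 5]); only the dimensions marked `None` grow, while the fixed vectorSize dimension is left alone.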
To read the TFRecord file into batches, I wrote the following Python code:
import math
import tensorflow as tf
import numpy as np
import struct
import sys
import array
if len(sys.argv) != 2:
    print("Usage: createbatches.py [TFRecord file]")
    sys.exit(1)  # non-zero exit status signals a usage error
vectorSize = 40
inFile = sys.argv[1]
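def parse_function_dataset(example_proto):
    # Per-timestep features: a [vectorSize] float vector and a scalar int label.
    sequence_features = {
        'inputs': tf.FixedLenSequenceFeature(shape=[vectorSize],
                                             dtype=tf.float32),
        'labels': tf.FixedLenSequenceFeature(shape=[],
                                             dtype=tf.int64)}
    # No context features are used, so the first return value is discarded.
    _, sequence = tf.parse_single_sequence_example(
        example_proto, sequence_features=sequence_features)
    # Number of timesteps T in this example (computed but not returned).
    length = tf.shape(sequence['inputs'])[0]
    # inputs has shape [T, vectorSize]; labels has shape [T].
    return sequence['inputs'], sequence['labels']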
sess = tf.InteractiveSession()
filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.contrib.data.TFRecordDataset(filenames) …
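parse_function_dataset computes `length` but does not return it. If it were returned as a third component (e.g. `return sequence['inputs'], sequence['labels'], length`), that component is a scalar, so its padded shape would be `[]` (no padding), presumably giving padded_shapes `([None, vectorSize], [None], [])`. The plain-Python sketch below, with hypothetical helpers `pad_with_lengths` and `mask_from_lengths`, illustrates keeping the true lengths next to the padded data and deriving a 0/1 mask from them (similar in spirit to tf.sequence_mask). It assumes each element's `length` equals its number of timesteps and is an illustration, not TensorFlow code.

```python
# Hypothetical sketch: batch (inputs, labels, length) triples the way
# padded_shapes=([None, vector_size], [None], []) would, keeping true lengths.
def pad_with_lengths(elements, vector_size):
    # Assumes length == len(inputs) == len(labels) for every element.
    max_len = max(length for _, _, length in elements)
    out_inputs, out_labels, out_lengths = [], [], []
    for inputs, labels, length in elements:
        pad = max_len - length
        out_inputs.append(list(inputs) + [[0.0] * vector_size for _ in range(pad)])
        out_labels.append(list(labels) + [0] * pad)
        out_lengths.append(length)  # scalar component (shape []): not padded
    return out_inputs, out_labels, out_lengths

def mask_from_lengths(lengths):
    # 1 marks a real timestep, 0 marks padding; useful for masking the loss.
    max_len = max(lengths)
    return [[1 if t < n else 0 for t in range(max_len)] for n in lengths]

e1 = ([[1.0] * 4 for _ in range(2)], [7, 8], 2)
e2 = ([[1.0] * 4 for _ in range(4)], [1, 2, 3, 4], 4)
_, labels, lengths = pad_with_lengths([e1, e2], 4)
print(lengths)                     # [2, 4]
print(mask_from_lengths(lengths))  # [[1, 1, 0, 0], [1, 1, 1, 1]]
```

Keeping the lengths alongside the batch lets an LSTM skip padded steps or lets the loss ignore padded labels, since a padding label of 0 may otherwise be indistinguishable from a real class.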