小编Mar*_*gts的帖子

如何使用DataSet API在Tensorflow中为tf.train.SequenceExample数据创建填充批次?

为了在Tensorflow中训练LSTM模型,我将数据结构化为tf.train.SequenceExample格式并将其存储到TFRecord文件中.我现在想使用新的DataSet API来生成用于训练的填充批次.在文档中有一个使用padded_batch的例子,但对于我的数据,我无法弄清楚padded_shapes应该是什么值.

为了将TFrecord文件读入批处理,我编写了以下Python代码:

import math
import tensorflow as tf
import numpy as np
import struct
import sys
import array

if(len(sys.argv) != 2):
  print "Usage: createbatches.py [RFRecord file]"
  sys.exit(0)


vectorSize = 40
inFile = sys.argv[1]

def parse_function_dataset(example_proto):
  sequence_features = {
      'inputs': tf.FixedLenSequenceFeature(shape=[vectorSize],
                                           dtype=tf.float32),
      'labels': tf.FixedLenSequenceFeature(shape=[],
                                           dtype=tf.int64)}

  _, sequence = tf.parse_single_sequence_example(example_proto, sequence_features=sequence_features)

  length = tf.shape(sequence['inputs'])[0]
  return sequence['inputs'], sequence['labels']

sess = tf.InteractiveSession()

filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.contrib.data.TFRecordDataset(filenames) …
Run Code Online (Sandbox Code Playgroud)

python lstm tensorflow tensorflow-datasets

13
推荐指数
1
解决办法
9514
查看次数

标签 统计

lstm ×1

python ×1

tensorflow ×1

tensorflow-datasets ×1