小编bal*_*lto的帖子

读取张量流中的大型列车/验证/测试数据集

将多个大数据集加载到张量流中的正确方法是什么?

我有三个大数据集(文件),分别用于训练,验证和测试.我可以通过tf.train.string_input_producer成功加载训练集,并将其输入到tf.train.shuffle_batch对象中.然后我可以迭代地获取批量数据以优化我的模型.

但是,当我尝试以同样的方式加载我的验证集时,我遇到了困难,程序一直说"OutOfRange Error",即使我没有在string_input_producer中设置num_epochs.

任何人都可以点亮它吗?除此之外,我还在考虑在tensorflow中进行训练/验证的正确方法是什么?实际上,我没有看到任何在大数据集上进行训练和测试的例子(我经常搜索).这对我来说太奇怪了......

下面的代码片段.

def extract_validationset(filename, batch_size):
  with tf.device("/cpu:0"):
    queue = tf.train.string_input_producer([filename])
    reader = tf.TextLineReader()
    _, line = reader.read(queue)

    line = tf.decode_csv(...)
    label = line[0]
    feature = tf.pack(list(line[1:]))

    l, f = tf.train.batch([label, feature], batch_size=batch_size, num_threads=8)
    return l, f

def extract_trainset(train, batch_size):
  with tf.device("/cpu:0"):
    train_files = tf.train.string_input_producer([train])
    reader = tf.TextLineReader()
    _, train_line = reader.read(train_files)

    train_line = tf.decode_csv(...)

    l, f = tf.train.shuffle_batch(...,
  batch_size=batch_size, capacity=50000, min_after_dequeue=10000,  num_threads=8)
  return l, f

....

label_batch, feature_batch = extract_trainset("train", batch_size)
label_eval, feature_eval = extract_validationset("test", batch_size) …
Run Code Online (Sandbox Code Playgroud)

tensorflow

9
推荐指数
1
解决办法
2674
查看次数

标签 统计

tensorflow ×1