解析csv时升级到tf.dataset无法正常工作

ree*_*106 8 tensorflow google-cloud-ml tensorflow-datasets

我有一个GCMLE实验,我正在尝试升级我input_fn以使用新tf.data功能.我已根据此示例创建了以下input_fn

def input_fn(...):
    dataset = tf.data.Dataset.list_files(filenames).shuffle(num_shards) # shuffle up the list of input files
    dataset = dataset.interleave(lambda filename: # mix together records from cycle_length number of shards
                tf.data.TextLineDataset(filename).skip(1).map(lambda row: parse_csv(row, hparams)), cycle_length=5) 
    if shuffle:
      dataset = dataset.shuffle(buffer_size = 10000)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    features = iterator.get_next()

    labels = features.pop(LABEL_COLUMN)

    return features, labels
Run Code Online (Sandbox Code Playgroud)

parse_csv和我之前使用的相同,但目前还没有.我可以解决一些问题,但我不完全理解为什么我遇到这些问题.这是我的parse_csv()函数的开始

def parse_csv(..):
    columns = tf.decode_csv(rows, record_defaults=CSV_COLUMN_DEFAULTS)
    raw_features = dict(zip(FIELDNAMES, columns))

    words = tf.string_split(raw_features['sentences']) # splitting words
    vocab_table = tf.contrib.lookup.index_table_from_file(vocabulary_file = hparams.vocab_file,
                default_value = 0)

....
Run Code Online (Sandbox Code Playgroud)
  1. 马上tf.string_split()停止工作,错误是ValueError: Shape must be rank 1 but is rank 0 for 'csv_preprocessing/input_sequence_generation/StringSplit' (op: 'StringSplit') with input shapes: [], [].- 这很容易通过打包raw_features['sentences']到张量来解决,[raw_features['sentences']]但我不明白为什么这种方法需要这个dataset?旧版本如何才能正常工作?对于与我的模型的其余部分匹配的形状,我最终需要在末尾删除这个额外的维度,words = tf.squeeze(words, 0)因为我将这个"不必要的"维度添加到张量.

  2. 无论出于何种原因,我也得到一个错误,表没有初始化,tensorflow.python.framework.errors_impl.FailedPreconditionError: Table not initialized.但是,这个代码与我的旧代码完全正常input_fn()(见下文)所以我不知道为什么我现在需要初始化表?我还没有想出这个部分的解决方案.tf.contrib.lookup.index_table_from_file在我的parse_csv函数中是否有任何我无法使用的东西?

作为参考,这是我的旧input_fn()仍然可以工作:

def input_fn(...):
    filename_queue = tf.train.string_input_producer(tf.train.match_filenames_once(filenames), 
                num_epochs=num_epochs, shuffle=shuffle, capacity=32)
    reader = tf.TextLineReader(skip_header_lines=skip_header_lines)

    _, rows = reader.read_up_to(filename_queue, num_records=batch_size)

    features = parse_csv(rows, hparams)


        if shuffle:
            features = tf.train.shuffle_batch(
                features,
                batch_size,
                min_after_dequeue=2 * batch_size + 1,
                capacity=batch_size * 10,
                num_threads=multiprocessing.cpu_count(), 
                enqueue_many=True,
                allow_smaller_final_batch=True
            )
        else:
            features = tf.train.batch(
                features,
                batch_size,
                capacity=batch_size * 10,
                num_threads=multiprocessing.cpu_count(),
                enqueue_many=True,
                allow_smaller_final_batch=True
            )

labels = features.pop(LABEL_COLUMN)

return features, labels
Run Code Online (Sandbox Code Playgroud)

更新TF 1.7

我正在重新审视TF 1.7(它应该具有@mrry中提到的所有TF 1.6功能)但我仍然无法复制该行为.对于我的老年人,input_fn()我能够以13步/秒为单位.我正在使用的新功能如下:

def input_fn(...):
    files = tf.data.Dataset.list_files(filenames).shuffle(num_shards)
    dataset = files.apply(tf.contrib.data.parallel_interleave(lambda filename: tf.data.TextLineDataset(filename).skip(1), cycle_length=num_shards))
    dataset = dataset.apply(tf.contrib.data.map_and_batch(lambda row:
            parse_csv_dataset(row, hparams = hparams), 
            batch_size = batch_size, 
            num_parallel_batches = multiprocessing.cpu_count())) 
    dataset = dataset.prefetch(1)
    if shuffle:
        dataset = dataset.shuffle(buffer_size = 10000)
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_initializable_iterator()
    features = iterator.get_next()
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)

    labels = {key: features.pop(key) for key in LABEL_COLUMNS}

    return features, labels 
Run Code Online (Sandbox Code Playgroud)

我相信我遵循所有的性能guildines,例如1)使用prefetch 2)使用map_and_batch和num_parallel_batches = cores 3)使用parallel_interleave 4)在重复之前应用shuffle.我没有使用的唯一步骤是缓存建议,但是期望它真的只能帮助第一个以外的时期以及"先应用交错,预取和随机播放". - 但是我发现在map_and_batch之后有预取和shuffle加速~10%.

BUFFER ISSUE 我注意到的第一个性能问题是我的旧版本input_fn()花了大约13个挂钟分钟来完成20k步骤,然而即使buffer_size为10,000(我认为我们要等到我们有10,000个批次)处理)我还在等待超过40分钟的缓冲区才能满了.这么长时间有意义吗?如果我知道我的GCS碎片化的.csv的已经随机的,是可以接受的这种洗牌/缓冲区的大小更小?我试图从tf.train.shuffle_batch()复制行为 - 然而,似乎在最坏的情况下,为了填满缓冲区,它需要花费相同的13分钟达到10k步骤?

步/秒

即使缓冲区已经填满,全局步数/秒也会在同一模型上达到大约3步/秒(通常低至2步/秒),前一个input_fn()达到~13步/秒.

SLOPPY INTERLEAVE 我试图用sloppy_interleave()替换parallel_interleave(),因为这是@mrry的另一个建议.当我切换到sloppy_interleave时,我得到了14步/秒!我知道这意味着它是不确定的,但应该真的只是意味着它是不是从一个运行(或纪元)确定下一个?或者对此有更大的影响?我应该关注旧shuffle_batch()方法和sloppy_interleave 之间的任何真正区别吗?事实上,这导致4-5倍的改善表明之前的阻塞因素是什么?

ree*_*106 2

在 TF 1.4(目前是与 GCMLE 配合使用的 TF 的最新版本)中,您将无法使用make_one_shot_iterator()查找表(请参阅相关帖子),您需要使用查找表Dataset.make_initializable_iterator(),然后iterator.initalizer使用默认值进行初始化TABLES_INITIALIZER(来自这篇文章)。应该是这样input_fn()的:

def input_fn(...):
  dataset = tf.data.Dataset.list_files(filenames).shuffle(num_shards)

  # Define `vocab_table` outside the map function and use it in `parse_csv()`.
  vocab_table = tf.contrib.lookup.index_table_from_file(
      vocabulary_file=hparams.vocab_file, default_value=0)

  dataset = dataset.interleave(
      lambda filename: (tf.data.TextLineDataset(filename)
                        .skip(1)
                        .map(lambda row: parse_csv(row, hparams),
                             num_parallel_calls=multiprocessing.cpu_count())),
      cycle_length=5) 

  if shuffle:
    dataset = dataset.shuffle(buffer_size=10000)
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)
  iterator = dataset.make_initializable_iterator()
  features = iterator.get_next()

  # add iterator.intializer to be handled by default table initializers
  tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer) 

  labels = features.pop(LABEL_COLUMN)

  return features, labels
Run Code Online (Sandbox Code Playgroud)