小编CNu*_*ren的帖子

使用TensorFlow Dataset API和flat_map的并行线程

我正在将TensorFlow代码从旧队列接口更改为新的Dataset API.使用旧接口,我可以指定队列的num_threads参数tf.train.shuffle_batch.但是,控制数据集API中线程数量的唯一方法似乎是map使用num_parallel_calls参数的函数.但是,我正在使用该flat_map函数,它没有这样的参数.

问题:有没有办法控制flat_map函数的线程/进程数?或者是否可以map结合使用flat_map并仍然指定并行呼叫的数量?

请注意,并行运行多个线程至关重要,因为我打算在数据进入队列之前在CPU上运行大量预处理.

GitHub上有两个(这里这里)相关的帖子,但我不认为他们回答了这个问题.

这是我用例的最小代码示例:

with tf.Graph().as_default():
    data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
    input_tensors = (data,)

    def pre_processing_func(data_):
        # normally I would do data-augmentation here
        results = (tf.expand_dims(data_, axis=0),)
        return tf.data.Dataset.from_tensor_slices(results)

    dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
    dataset = dataset_source.flat_map(pre_processing_func)
    # do something with 'dataset'
Run Code Online (Sandbox Code Playgroud)

python tensorflow

17
推荐指数
1
解决办法
3894
查看次数

使用TensorFlow Dataset API的Epoch计数器

我正在将TensorFlow代码从旧队列接口更改为新的Dataset API.在我的旧代码中,我通过tf.Variable每次在队列中访问和处理新的输入张量时递增a来跟踪纪元数.我想用新的Dataset API来计算这个时代,但是我在使用它时遇到了一些麻烦.

由于我在预处理阶段生成了可变数量的数据项,因此在训练循环中递增(Python)计数器并不是一件简单的事情 - 我需要根据输入来计算epoch计数.队列或数据集.

我使用旧的队列系统模仿我以前所拥有的东西,这就是我最终得到的数据集API(简化示例):

with tf.Graph().as_default():

    data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
    input_tensors = (data,)

    epoch_counter = tf.Variable(initial_value=0.0, dtype=tf.float32,
                                trainable=False)

    def pre_processing_func(data_):
        data_size = tf.constant(0.1, dtype=tf.float32)
        epoch_counter_op = tf.assign_add(epoch_counter, data_size)
        with tf.control_dependencies([epoch_counter_op]):
            # normally I would do data-augmentation here
            results = (tf.expand_dims(data_, axis=0),)
            return tf.data.Dataset.from_tensor_slices(results)

    dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
    dataset = dataset_source.flat_map(pre_processing_func)
    dataset = dataset.repeat()
    # ... do something with 'dataset' and print
    # the value of 'epoch_counter' every once a while
Run Code Online (Sandbox Code Playgroud)

但是,这不起作用.它崩溃了一个神秘的错误信息:

 TypeError: …
Run Code Online (Sandbox Code Playgroud)

python tensorflow

9
推荐指数
1
解决办法
2090
查看次数

标签 统计

python ×2

tensorflow ×2