我正在将TensorFlow代码从旧队列接口更改为新的Dataset API.使用旧接口,我可以指定队列的num_threads
参数tf.train.shuffle_batch
.但是,控制数据集API中线程数量的唯一方法似乎是map
使用num_parallel_calls
参数的函数.但是,我正在使用该flat_map
函数,它没有这样的参数.
问题:有没有办法控制flat_map
函数的线程/进程数?或者是否可以map
结合使用flat_map
并仍然指定并行呼叫的数量?
请注意,并行运行多个线程至关重要,因为我打算在数据进入队列之前在CPU上运行大量预处理.
GitHub上有两个(这里和这里)相关的帖子,但我不认为他们回答了这个问题.
这是我用例的最小代码示例:
with tf.Graph().as_default():
data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
input_tensors = (data,)
def pre_processing_func(data_):
# normally I would do data-augmentation here
results = (tf.expand_dims(data_, axis=0),)
return tf.data.Dataset.from_tensor_slices(results)
dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
dataset = dataset_source.flat_map(pre_processing_func)
# do something with 'dataset'
Run Code Online (Sandbox Code Playgroud) 我正在将TensorFlow代码从旧队列接口更改为新的Dataset API.在我的旧代码中,我通过tf.Variable
每次在队列中访问和处理新的输入张量时递增a来跟踪纪元数.我想用新的Dataset API来计算这个时代,但是我在使用它时遇到了一些麻烦.
由于我在预处理阶段生成了可变数量的数据项,因此在训练循环中递增(Python)计数器并不是一件简单的事情 - 我需要根据输入来计算epoch计数.队列或数据集.
我使用旧的队列系统模仿我以前所拥有的东西,这就是我最终得到的数据集API(简化示例):
with tf.Graph().as_default():
data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
input_tensors = (data,)
epoch_counter = tf.Variable(initial_value=0.0, dtype=tf.float32,
trainable=False)
def pre_processing_func(data_):
data_size = tf.constant(0.1, dtype=tf.float32)
epoch_counter_op = tf.assign_add(epoch_counter, data_size)
with tf.control_dependencies([epoch_counter_op]):
# normally I would do data-augmentation here
results = (tf.expand_dims(data_, axis=0),)
return tf.data.Dataset.from_tensor_slices(results)
dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
dataset = dataset_source.flat_map(pre_processing_func)
dataset = dataset.repeat()
# ... do something with 'dataset' and print
# the value of 'epoch_counter' every once a while
Run Code Online (Sandbox Code Playgroud)
但是,这不起作用.它崩溃了一个神秘的错误信息:
TypeError: …
Run Code Online (Sandbox Code Playgroud)