CNu*_*ren 17 python tensorflow
我正在将TensorFlow代码从旧队列接口更改为新的Dataset API.使用旧接口,我可以指定队列的num_threads
参数tf.train.shuffle_batch
.但是,控制数据集API中线程数量的唯一方法似乎是map
使用num_parallel_calls
参数的函数.但是,我正在使用该flat_map
函数,它没有这样的参数.
问题:有没有办法控制flat_map
函数的线程/进程数?或者是否可以map
结合使用flat_map
并仍然指定并行呼叫的数量?
请注意,并行运行多个线程至关重要,因为我打算在数据进入队列之前在CPU上运行大量预处理.
GitHub上有两个(这里和这里)相关的帖子,但我不认为他们回答了这个问题.
这是我用例的最小代码示例:
with tf.Graph().as_default():
data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
input_tensors = (data,)
def pre_processing_func(data_):
# normally I would do data-augmentation here
results = (tf.expand_dims(data_, axis=0),)
return tf.data.Dataset.from_tensor_slices(results)
dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
dataset = dataset_source.flat_map(pre_processing_func)
# do something with 'dataset'
Run Code Online (Sandbox Code Playgroud)
GPh*_*ilo 11
据我所知,目前flat_map
不提供并行选项.鉴于大部分计算都已完成pre_processing_func
,您可以使用的解决方法是并行map
调用,然后进行一些缓冲,然后使用flat_map
带有标识lambda函数的调用来处理输出的扁平化.
在代码中:
NUM_THREADS = 5
BUFFER_SIZE = 1000
def pre_processing_func(data_):
# data-augmentation here
# generate new samples starting from the sample `data_`
artificial_samples = generate_from_sample(data_)
return atificial_samples
dataset_source = (tf.data.Dataset.from_tensor_slices(input_tensors).
map(pre_processing_func, num_parallel_calls=NUM_THREADS).
prefetch(BUFFER_SIZE).
flat_map(lambda *x : tf.data.Dataset.from_tensor_slices(x)).
shuffle(BUFFER_SIZE)) # my addition, probably necessary though
Run Code Online (Sandbox Code Playgroud)
由于pre_processing_func
从初始样本开始生成任意数量的新样本(以形状矩阵组织(?, 512)
),因此flat_map
需要调用将所有生成的矩阵转换为Dataset
包含单个样本的s(因此tf.data.Dataset.from_tensor_slices(x)
在lambda中),然后将所有这些数据集展平为一大Dataset
包含个别样本.
.shuffle()
这个数据集可能是一个好主意,或者生成的样本将被打包在一起.