mat*_*cey 25 tensorflow tensorflow-datasets
我有一个非常简单的输入管道,from_generator
非常适合......
dataset = tf.data.Dataset.from_generator(complex_img_label_generator,
(tf.int32, tf.string))
dataset = dataset.batch(64)
iter = dataset.make_one_shot_iterator()
imgs, labels = iter.get_next()
Run Code Online (Sandbox Code Playgroud)
其中complex_img_label_generator
动态生成图像,并返回表示一个numpy的阵列(H, W, 3)
图像和一个简单的string
标签.处理不是我可以表示从文件和tf.image
操作中读取的内容.
我的问题是关于如何平衡发电机?我如何让N个这些生成器在自己的线程中运行.
一个想法是使用dataset.map
与num_parallel_calls
处理线程; 但是地图在张量上运行......另一个想法是创建多个生成器,每个生成器都有自己的,prefetch
并以某种方式加入它们,但我看不出我如何加入N个生成器流?
我可以遵循任何规范的例子吗?
mat*_*cey 21
事实证明,Dataset.map
如果我让生成器超级轻量级(仅生成元数据),然后将实际的重度照明移动到无状态函数,我可以使用.通过这种方式,我可以.map
使用a 来平行重型提升部件py_func
.
作品; 但感觉有点笨拙...很高兴能够添加num_parallel_calls
到from_generator
:)
def pure_numpy_and_pil_complex_calculation(metadata, label):
# some complex pil and numpy work nothing to do with tf
...
dataset = tf.data.Dataset.from_generator(lightweight_generator,
output_types=(tf.string, # metadata
tf.string)) # label
def wrapped_complex_calulation(metadata, label):
return tf.py_func(func = pure_numpy_and_pil_complex_calculation,
inp = (metadata, label),
Tout = (tf.uint8, # (H,W,3) img
tf.string)) # label
dataset = dataset.map(wrapped_complex_calulation,
num_parallel_calls=8)
dataset = dataset.batch(64)
iter = dataset.make_one_shot_iterator()
imgs, labels = iter.get_next()
Run Code Online (Sandbox Code Playgroud)
我正在from_indexable
为tf.data.Dataset
https://github.com/tensorflow/tensorflow/issues/14448设计
这样做的好处from_indexable
是可以并行化,而python生成器则不能并行化。
函数from_indexable
使一个tf.data.range
,将可索引的索引包装在一个泛型中tf.py_func
并调用map。
对于那些现在想要a的人from_indexable
,这里是lib代码
import tensorflow as tf
import numpy as np
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import nest
def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
def decorator(func):
def call(*args):
nonlocal output_shapes
flat_output_types = nest.flatten(output_types)
flat_values = tf.py_func(
func,
inp=args,
Tout=flat_output_types,
stateful=stateful, name=name
)
if output_shapes is not None:
# I am not sure if this is nessesary
output_shapes = nest.map_structure_up_to(
output_types, tensor_shape.as_shape, output_shapes)
flattened_shapes = nest.flatten_up_to(output_types, output_shapes)
for ret_t, shape in zip(flat_values, flattened_shapes):
ret_t.set_shape(shape)
return nest.pack_sequence_as(output_types, flat_values)
return call
return decorator
def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
ds = tf.data.Dataset.range(len(iterator))
@py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)
def index_to_entry(index):
return iterator[index]
return ds.map(index_to_entry, num_parallel_calls=num_parallel_calls)
Run Code Online (Sandbox Code Playgroud)
这是一个示例(注意:from_indexable
具有num_parallel_calls argument
)
class PyDataSet:
def __len__(self):
return 20
def __getitem__(self, item):
return np.random.normal(size=(item+1, 10))
ds = from_indexable(PyDataSet(), output_types=tf.float64, output_shapes=[None, 10])
it = ds.make_one_shot_iterator()
entry = it.get_next()
with tf.Session() as sess:
print(sess.run(entry).shape)
print(sess.run(entry).shape)
Run Code Online (Sandbox Code Playgroud)
更新 2018年6月10日:由于https://github.com/tensorflow/tensorflow/pull/15121被合并,因此代码from_indexable
简化为:
import tensorflow as tf
def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
def decorator(func):
def call(*args, **kwargs):
return tf.contrib.framework.py_func(
func=func,
args=args, kwargs=kwargs,
output_types=output_types, output_shapes=output_shapes,
stateful=stateful, name=name
)
return call
return decorator
def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
ds = tf.data.Dataset.range(len(iterator))
@py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)
def index_to_entry(index):
return iterator[index]
return ds.map(index_to_entry, num_parallel_calls=num_parallel_calls)
Run Code Online (Sandbox Code Playgroud)
将在 中完成的工作限制在generator
最低限度并使用 a 并行化昂贵的处理map
是明智的。
或者,您可以使用parallel_interleave
以下方法“加入”多个生成器:
定义生成器(n): # 返回第 n 个生成器函数 定义数据集(n): 返回 tf.data.Dataset.from_generator(generator(n)) ds = tf.data.Dataset.range(N).apply(tf.contrib.data.parallel_interleave(dataset,cycle_lenght=N)) # 其中 N 是您使用的生成器数量
归档时间: |
|
查看次数: |
9029 次 |
最近记录: |