如何在 tf.data API 中使用 Keras 生成器

sib*_*iby 6 python keras tensorflow tensorflow-datasets

我正在尝试使用 Keras 预处理库中的生成器。我想对此进行试验,因为 Keras 为图像增强提供了很好的功能。但是,我不确定这是否真的可能。

以下是我从 Keras 生成器制作 tf 数据集的方法:

def make_generator():
    train_datagen = ImageDataGenerator(rescale=1. / 255)
    train_generator = 
    train_datagen.flow_from_directory(train_dataset_folder,target_size=(224, 224), class_mode='categorical', batch_size=32)
    return train_generator

train_dataset = tf.data.Dataset.from_generator(make_generator,(tf.float32, tf.float32)).shuffle(64).repeat().batch(32)
Run Code Online (Sandbox Code Playgroud)

请注意,如果您尝试直接将其train_generator作为参数提供给tf.data.Dataset.from_generator,则会出现错误。但是,上述方法不会产生错误。

当我在会话中运行它以检查数据集的输出时,我收到以下错误。

iterator = train_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
for i in range(100):
    sess.run(next_element)
Run Code Online (Sandbox Code Playgroud)

找到属于 2 个类别的 1000 张图像。-------------------------------------------------- ------------------------- InvalidArgumentError Traceback (最近一次调用最后一次) /usr/local/lib/python3.6/dist-packages/tensorflow/ python/client/session.py in _do_call(self, fn, *args) 1291 try: -> 1292 return fn(*args) 1293 除了 errors.OpError as e:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata) 1276 return self._call_tf_sessionrun( -> 1277 options, feed_dict,第 1278 章

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, run_metadata) 1366 self._session, options, feed_dict, fetch_list,目标列表,-> 1367 运行元数据)1368

InvalidArgumentError:无法批量处理组件 0 中具有不同形状的张量。第一个元素的形状为 [32,224,224,3],元素 29 的形状为 [8,224,224,3]。[[{{node IteratorGetNext_2}} = IteratorGetNextoutput_shapes=[, ], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"]]

在处理上述异常的过程中,又发生了一个异常:

如果有人对此有任何经验或知道任何其他方法,请告诉我。

更新

使用 JEK 的建议后,我能够解决问题

train_dataset = tf.data.Dataset.from_generator(make_generator,(tf.float32, tf.float32))
Run Code Online (Sandbox Code Playgroud)

但是,当我train_dataset使用 Keras.fit方法时,出现以下错误。

model_regular.fit(train_dataset,steps_per_epoch=1000,epochs=2)
Run Code Online (Sandbox Code Playgroud)

-------------------------------------------------- ------------------------- ValueError Traceback (最近一次调用最后一次) in () ----> 1 model_regular.fit(train_dataset,steps_per_epoch= 1000,epochs=2)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight , sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs) 1507 steps_name='steps_per_epoch', 1508 steps=steps_per_epoch, -> 1509 validation_split=validation_split) 1510 1511 # 准备验证数据。

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in _standardize_user_data(self,x,y,sample_weight,class_weight,batch_size,check_steps,steps_name,steps,validation_split)948 x = self._dataset_iterator_cache[x] 949 else: --> 950 iterator = x.make_initializable_iterator() 951 self._dataset_iterator_cache[x] = iterator 952 x = iterator

/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py in make_initializable_iterator(self, shared_name) 119 with ops.colocate_with(iterator_resource): 120 initializer = gen_dataset_ops.make_iterator(self) ._as_variant_tensor(), --> 121 iterator_resource) 122 return iterator_ops.Iterator(iterator_resource, initializer, 123 self.output_types, self.output_shapes,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_dataset_ops.py in make_iterator(dataset, iterator, name) 2542 if _ctx is None or not _ctx._eager_context.is_eager: 2543 _, _ , _op = _op_def_lib._apply_op_helper( -> 2544 "MakeIterator", dataset=dataset, iterator=iterator, name=name) 2545 return _op 2546 _result = None

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords) 348 # 需要将所有参数展平成一个列表。349 # pylint: disable=protected-access --> 350 g = ops._get_graph_from_inputs(_Flatten(keywords.values())) 351 # pylint: enable=protected-access 352 除了 AssertionError 作为 e:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _get_graph_from_inputs(op_input_list, graph) 5659 graph = graph_element.graph 5660 elif original_graph_element 不是无:-> 5661 _assertoriginal_same_graph( , graph_element) 5662 elif graph_element.graph 不是图:
5663 raise ValueError("%s 不是来自传入的图。" % graph_element)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _assert_same_graph(original_item, item) 5595 如果 original_item.graph 不是 item.graph: 5596 raise ValueError("%s must来自与 %s 相同的图表。” % (item, -> 5597 original_item)) 5598 5599

ValueError: Tensor("IteratorV2:0", shape=(), dtype=resource) 必须与 Tensor("FlatMapDataset:0", shape=(), dtype=variant) 来自同一图。

这是一个错误还是 Keras fit 方法不应该以这种方式使用?

J.E*_*E.K 2

我尝试用一​​个简单的示例重现您的结果,我发现当在生成器函数和 中使用批处理时,您会得到不同的输出形状tf.data

Keras 函数train_datagen.flow_from_directory(batch_size=32)已经返回 shape 的数据[batch_size, width, height, depth]。如果使用tf.data.Dataset().batch(32)输出数据,则会再次将其批处理为 shape [batch_size, batch_size, width, height, depth]

这可能由于某种原因导致了您的问题。