How to properly combine TensorFlow's Dataset API and Keras?

Jas*_*son 43 keras tensorflow

Keras's fit_generator() model method expects a generator yielding tuples of the form (inputs, targets), where both elements are NumPy arrays. The documentation seems to imply that if I simply wrap a Dataset iterator in a generator, and make sure to convert the Tensors to NumPy arrays, I should be good to go. This code, however, gives me an error:

import numpy as np
import os
import keras.backend as K
from keras.layers import Dense, Input
from keras.models import Model
import tensorflow as tf
from tensorflow.contrib.data import Dataset

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

with tf.Session() as sess:
    def create_data_generator():
        dat1 = np.arange(4).reshape(-1, 1)
        ds1 = Dataset.from_tensor_slices(dat1).repeat()

        dat2 = np.arange(5, 9).reshape(-1, 1)
        ds2 = Dataset.from_tensor_slices(dat2).repeat()

        ds = Dataset.zip((ds1, ds2)).batch(4)
        iterator = ds.make_one_shot_iterator()
        while True:
            next_val = iterator.get_next()
            yield sess.run(next_val)

    datagen = create_data_generator()

    input_vals = Input(shape=(1,))
    output = Dense(1, activation='relu')(input_vals)
    model = Model(inputs=input_vals, outputs=output)
    model.compile('rmsprop', 'mean_squared_error')
    model.fit_generator(datagen, steps_per_epoch=1, epochs=5,
                        verbose=2, max_queue_size=2)

Here is the error I get:

Using TensorFlow backend.
Epoch 1/5
Exception in thread Thread-1:
Traceback (most recent call last):
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 270, in __init__
    fetch, allow_tensor=True, allow_operation=True))
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2708, in as_graph_element
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2787, in _as_graph_element_locked
    raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/utils/data_utils.py", line 568, in data_generator_task
    generator_output = next(self._generator)
  File "./datagen_test.py", line 25, in create_data_generator
    yield sess.run(next_val)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 895, in run
    run_metadata_ptr)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1109, in _run
    self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 413, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 233, in for_fetch
    return _ListFetchMapper(fetch)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in __init__
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in <listcomp>
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 241, in for_fetch
    return _ElementFetchMapper(fetches, contraction_fn)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 277, in __init__
    'Tensor. (%s)' % (fetch, str(e)))
ValueError: Fetch argument <tf.Tensor 'IteratorGetNext:0' shape=(?, 1) dtype=int64> cannot be interpreted as a Tensor. (Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph.)

Traceback (most recent call last):
  File "./datagen_test.py", line 34, in <module>
    verbose=2, max_queue_size=2)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 87, in wrapper
    return func(*args, **kwargs)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 2011, in fit_generator
    generator_output = next(output_generator)
StopIteration

Strangely, adding a line containing next(datagen) directly after I initialize datagen causes the code to run fine, with no errors (a minimal sketch of this workaround is shown below).
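For reference, the workaround is a one-line change inside the `with tf.Session() as sess:` block; the rest of the script is unchanged:

    datagen = create_data_generator()
    next(datagen)  # pulling one batch in the main thread before fit_generator avoids the error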

Why doesn't my original code work? Why does it start working when I add that line? Is there a more efficient way to use TensorFlow's Dataset API with Keras that doesn't involve converting Tensors to NumPy arrays and back again?

Dat*_*yen 50

Update of June 9, 2018

  • As of TensorFlow 1.9, a tf.data.Dataset object can be passed directly to keras.Model.fit(), and it behaves like fit_generator.
  • A complete example can be found on this gist.
# Load mnist training data
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
training_set = tfdata_generator(x_train, y_train,is_training=True)

model = # your keras model here              
model.fit(
    training_set.make_one_shot_iterator(),
    steps_per_epoch=len(x_train) // 128,
    epochs=5,
    verbose = 1)
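The snippet above passes a one-shot iterator to fit(); per the first bullet, with TensorFlow >= 1.9 the tf.data.Dataset object itself can also be passed directly (a minimal sketch, assuming the same training_set and model as above):

model.fit(training_set,
          steps_per_epoch=len(x_train) // 128,
          epochs=5,
          verbose=1)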
  • tfdata_generator here is a function that returns an iterable tf.data.Dataset.
def tfdata_generator(images, labels, is_training, batch_size=128):
  '''Construct a data generator using `tf.Dataset`. '''

  def map_fn(image, label):
    '''Preprocess raw data to trainable input. '''
    x = tf.reshape(tf.cast(image, tf.float32), (28, 28, 1))
    y = tf.one_hot(tf.cast(label, tf.uint8), _NUM_CLASSES)
    return x, y

  dataset = tf.data.Dataset.from_tensor_slices((images, labels))

  if is_training:
    dataset = dataset.shuffle(1000)  # depends on sample size
  dataset = dataset.map(map_fn)
  dataset = dataset.batch(batch_size)
  dataset = dataset.repeat()
  dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)

  return dataset

Old solution:

In addition to @Yu-Yang's answer, you can also modify tf.data.Dataset so that it becomes a generator for fit_generator, as follows:

import tensorflow as tf
import keras.backend as K
from tensorflow.contrib.learn.python.learn.datasets import mnist

data   = mnist.load_mnist()
model  = # your Keras model
model.fit_generator(generator = tfdata_generator(data.train.images, data.train.labels),
                    steps_per_epoch=200,
                    workers = 0 , # This is important
                    verbose = 1)


def tfdata_generator(images, labels, batch_size=128, shuffle=True,):
    def map_func(image, label):
        '''A transformation function'''
        x_train = tf.reshape(tf.cast(image, tf.float32), image_shape)
        y_train = tf.one_hot(tf.cast(label, tf.uint8), num_classes)
        return [x_train, y_train]

    dataset  = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset  = dataset.map(map_func)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=1000)  # shuffle() requires an explicit buffer size
    dataset  = dataset.batch(batch_size).repeat()
    iterator = dataset.make_one_shot_iterator()

    next_batch = iterator.get_next()
    while True:
        yield K.get_session().run(next_batch)

  • The workers=0 line is very important. Basically, if workers > 0 you will end up with multithreading problems, because multiple threads are trying to evaluate the same generator. If you call the generator once (to initialize it) it will work, because the endpoints have already been created, but you may still get strange results since it is not thread-safe. (3 upvotes)
  • This is, AFAIK, the only way to feed validation data to Keras via fit_generator's validation_data argument (see the sketch after these comments). (2 upvotes)
  • The result of `K.get_session().run(next_batch)` will be a list of numpy arrays, won't it? I thought the idea was to avoid dropping back into the Python layer and to stay in TensorFlow... (2 upvotes)
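Following the second comment, a hypothetical sketch of supplying validation data through the same generator approach (assuming the contrib MNIST loader used above, whose return value also exposes data.validation; validation_steps is an illustrative value):

model.fit_generator(
    generator=tfdata_generator(data.train.images, data.train.labels),
    validation_data=tfdata_generator(data.validation.images, data.validation.labels),
    validation_steps=50,   # illustrative value
    steps_per_epoch=200,
    workers=0,             # keep 0 so both generators run in the main thread
    verbose=1)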

Yu-*_*ang 35

There is indeed a more efficient way to use a Dataset without having to convert the tensors into numpy arrays. However, it is not (yet?) in the official documentation. According to the release notes, it is a feature introduced in Keras 2.0.7, so you may have to install keras >= 2.0.7 in order to use it.

x = np.arange(4).reshape(-1, 1).astype('float32')
ds_x = Dataset.from_tensor_slices(x).repeat().batch(4)
it_x = ds_x.make_one_shot_iterator()

y = np.arange(5, 9).reshape(-1, 1).astype('float32')
ds_y = Dataset.from_tensor_slices(y).repeat().batch(4)
it_y = ds_y.make_one_shot_iterator()

input_vals = Input(tensor=it_x.get_next())
output = Dense(1, activation='relu')(input_vals)
model = Model(inputs=input_vals, outputs=output)
model.compile('rmsprop', 'mse', target_tensors=[it_y.get_next()])
model.fit(steps_per_epoch=1, epochs=5, verbose=2)

A few differences:

  1. A tensor argument is supplied to the Input layer. Keras will read values from this tensor and use them as the input to fit the model.
  2. A target_tensors argument is supplied to Model.compile().
  3. Remember to convert both x and y to float32. Under normal usage Keras does this conversion for you, but here you have to do it yourself.
  4. The batch size is specified during the construction of the Dataset. Use steps_per_epoch and epochs to control when model fitting stops.

In short, use Input(tensor=...), model.compile(target_tensors=...) and model.fit(x=None, y=None, ...) when your data are read from tensors.

  • It looks like you don't even need two separate iterators. You can zip the two datasets, create a single node like `next_val = it.get_next()`, and feed its output elements to the `Input()` and `Model.compile()` calls (see the sketch after these comments). (8 upvotes)
  • What about iterator initialization? Can I somehow tell Keras to initialize it at every epoch? Or do I still need to create the session and do it manually, and then only run one epoch at a time? (4 upvotes)
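A minimal sketch of the zipped-dataset variant suggested in the first comment (assuming the same imports and the float32 arrays x and y from the answer above):

ds = Dataset.zip((Dataset.from_tensor_slices(x), Dataset.from_tensor_slices(y)))
ds = ds.repeat().batch(4)
it = ds.make_one_shot_iterator()
next_x, next_y = it.get_next()  # a single get_next() node yields both elements

input_vals = Input(tensor=next_x)
output = Dense(1, activation='relu')(input_vals)
model = Model(inputs=input_vals, outputs=output)
model.compile('rmsprop', 'mse', target_tensors=[next_y])
model.fit(steps_per_epoch=1, epochs=5, verbose=2)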