Avoiding tf.data.Dataset.from_tensor_slices with the estimator API

Mar*_*yer 5 python tensorflow

I am trying to figure out the recommended way to use the dataset API together with the estimator API. Everything I have seen online is some variation of this:

def train_input_fn():
   dataset = tf.data.Dataset.from_tensor_slices((features, labels))
   return dataset

which can then be passed to the estimator's train function:

 classifier.train(
    input_fn=train_input_fn,
    #...
 )

The dataset guide warns that:

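The above code snippet will embed the features and labels arrays in your TensorFlow graph as tf.constant() operations. This works well for a small dataset, but wastes memory (because the contents of the array will be copied multiple times) and can run into the 2GB limit for the tf.GraphDef protocol buffer.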
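To get a feel for when that limit matters, here is a rough back-of-the-envelope sketch. It assumes the "2GB" limit means 2**31 bytes and counts only the raw array payload, so it is a lower bound on the serialized graph size. The helper names and the example shapes are hypothetical and not part of TensorFlow.

```python
import math

# Assumption: treat the 2GB GraphDef limit as 2**31 bytes.
GRAPHDEF_LIMIT_BYTES = 2 ** 31


def array_bytes(shape, itemsize):
    """Raw payload size of an array with the given shape and bytes per element."""
    return math.prod(shape) * itemsize


def exceeds_graphdef_limit(*sizes):
    """True if the combined payloads alone already reach the assumed limit."""
    return sum(sizes) >= GRAPHDEF_LIMIT_BYTES


# Hypothetical float32 feature matrices (4 bytes per element).
small_features = array_bytes((60_000, 784), 4)
large_features = array_bytes((1_000_000, 784), 4)
# Hypothetical int64 label vector (8 bytes per element).
large_labels = array_bytes((1_000_000,), 8)

print(exceeds_graphdef_limit(small_features))                # small set fits
print(exceeds_graphdef_limit(large_features, large_labels))  # large set does not
```

With NumPy arrays, `features.nbytes + labels.nbytes` gives the same payload figure directly.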

The guide then describes an approach that involves defining placeholders, which are then populated with feed_dict:

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
iterator = dataset.make_initializable_iterator()

sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                          labels_placeholder: labels})

However, if you are using the estimator API, you do not run the session manually. So how do you use the dataset API with estimators while avoiding the problems associated with from_tensor_slices()?

Oli*_*ene 4

To use either initializable or reinitializable iterators, you must create a class that inherits from tf.train.SessionRunHook. A hook has access to the session at several points during the training and evaluation steps.

You can then use this new class to initialize the iterator as you would normally do in a classic setting. You simply need to pass the newly created hook to the training/evaluation functions or to the corresponding train spec.

Here is a quick example that you can adapt to your needs:

import tensorflow as tf  # TensorFlow 1.x API


class IteratorInitializerHook(tf.train.SessionRunHook):
    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None # Will be set in the input_fn

    def after_create_session(self, session, coord):
        # Initialize the iterator with the data feed_dict
        self.iterator_initializer_func(session) 


def get_inputs(X, y):
    iterator_initializer_hook = IteratorInitializerHook()

    def input_fn():
        X_pl = tf.placeholder(X.dtype, X.shape)
        y_pl = tf.placeholder(y.dtype, y.shape)

        dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl))
        dataset = ...
        ...

        iterator = dataset.make_initializable_iterator()
        next_example, next_label = iterator.get_next()


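        # Bind the initializer now that the placeholders and iterator exist;
        # the hook calls it once the session has been created.
        iterator_initializer_hook.iterator_initializer_func = (
            lambda sess: sess.run(iterator.initializer,
                                  feed_dict={X_pl: X, y_pl: y}))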

        return next_example, next_label

    return input_fn, iterator_initializer_hook

...

train_input_fn, train_iterator_initializer_hook = get_inputs(X_train, y_train)
test_input_fn, test_iterator_initializer_hook = get_inputs(X_test, y_test)

...

estimator.train(input_fn=train_input_fn,
                hooks=[train_iterator_initializer_hook])  # Don't forget to pass the hook!
estimator.evaluate(input_fn=test_input_fn,
                   hooks=[test_iterator_initializer_hook])
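The reason iterator_initializer_func can start out as None is the order in which things happen: the estimator calls input_fn while building the graph, and only afterwards creates the session and fires after_create_session. The following is a toy sketch of that ordering without TensorFlow. FakeSession, InitializerHook and fake_train are hypothetical stand-ins, and the strings stand in for real ops and placeholders.

```python
class FakeSession:
    """Hypothetical stand-in for a session; it only records what was run."""

    def __init__(self):
        self.calls = []

    def run(self, op, feed_dict=None):
        self.calls.append((op, feed_dict))


class InitializerHook:
    """Same shape as IteratorInitializerHook, minus the TensorFlow base class."""

    def __init__(self):
        self.iterator_initializer_func = None  # set later, inside input_fn

    def after_create_session(self, session, coord):
        self.iterator_initializer_func(session)


def get_inputs(data):
    hook = InitializerHook()

    def input_fn():
        # Graph-building time: the "ops" exist now, so bind the initializer.
        hook.iterator_initializer_func = lambda sess: sess.run(
            "iterator.initializer", feed_dict={"X_pl": data})
        return "next_example"

    return input_fn, hook


def fake_train(input_fn, hooks):
    """Toy call order: build the graph, create the session, notify the hooks."""
    input_fn()                                    # 1. input pipeline is built
    session = FakeSession()                       # 2. session is created
    for hook in hooks:
        hook.after_create_session(session, None)  # 3. hooks receive the session
    return session


input_fn, hook = get_inputs([1, 2, 3])
session = fake_train(input_fn, [hook])
print(session.calls)
```

Because the data is passed through feed_dict at step 3 rather than captured at step 1, it never becomes part of the graph definition.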