1 python tensorflow tensorflow-datasets tensorflow-estimator
我想用tf.estimator.Estimator训练我的模式并通过Dataset API加载我的数据.因为我的数据,例如'mnist',是一个数组(张量),所以我尝试用'tf.data加载它. Dataset.from_tensor_slices'.But我不如何将"input_fn"内初始化"make_initializable_iterator".
如果我可以使用'make_one_shot_iterator'成功训练,但在训练前它会加载缓慢.而" TensorFlow中的高级API "是'input_fn'中'make_initializable_iterator'的一个很好的例子,但它需要从'input_fn'向其他函数返回'iterator_initializer_hook'.我想知道还有其他更好或更优雅的方式吗?
def input_fn():
mnist_data = input_data.read_data_sets('mnist_data', one_hot=False)
images = mnist_data.train.images.reshape([-1, 28, 28, 1])
labels = np.asarray(mnist_data.train.labels, dtype=np.int64)
# Build dataset iterator
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.repeat(None) # Infinite iterations
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(100)
iterator = dataset.make_one_shot_iterator()
next_example = iterator.get_next()
# Set runhook to initialize iterator
return next_example
Run Code Online (Sandbox Code Playgroud)
在TensorFlow版本1.5及更高版本中,tf.estimator.Estimator当您tf.data.Dataset从中返回a时,它将自动创建并初始化可初始化的迭代器input_fn.这使您可以编写以下代码,而无需担心初始化或挂钩:
def input_fn():
mnist_data = input_data.read_data_sets('mnist_data', one_hot=False)
images = mnist_data.train.images.reshape([-1, 28, 28, 1])
labels = np.asarray(mnist_data.train.labels, dtype=np.int64)
# Build dataset.
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.repeat(None) # Infinite iterations
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(100)
return dataset
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
2400 次 |
| 最近记录: |