我正在使用tensorflow数据集api作为我的训练数据,输入tf.data.Dataset.from_generator api的生成器和生成器
def generator():
......
yield { "x" : features }, label
def input_fn():
ds = tf.data.Dataset.from_generator(generator, ......)
......
feature, label = ds.make_one_shot_iterator().get_next()
return feature, label
Run Code Online (Sandbox Code Playgroud)
然后我为我的Estimator创建了一个自定义的model_fn,代码如下:
def model_fn(features, labels, mode, params):
print(features)
......
layer = network.create_full_connect(input_tensor=features["x"],
(or layer = tf.layers.dense(features["x"], 200, ......)
......
Run Code Online (Sandbox Code Playgroud)
训练时:
estimator.train(input_fn=input_fn)
Run Code Online (Sandbox Code Playgroud)
但是,代码不起作用,因为函数model_fn的features参数是:
Tensor("IteratorGetNext:0",dtype = float32,device =/device:CPU:0)
代码"features ["x"]"将失败告诉我:
......"site-packages\tensorflow\python\ops\array_ops.py",第504行,在_SliceHelper中end.append(s + 1)TypeError:必须是str,而不是int
如果我将input_fn更改为:
input_fn = tf.estimator.inputs.numpy_input_fn(
x={"x": np.array([[1,2,3,4,5,6]])},
y=np.array([1]),
Run Code Online (Sandbox Code Playgroud)
代码继续,因为现在的功能是一个字典.
我搜索了估算器的代码,发现它使用了一些函数,如
features, labels = self._get_features_and_labels_from_input_fn(
input_fn, model_fn_lib.ModeKeys.TRAIN)
Run Code Online (Sandbox Code Playgroud)
从input_fn检索功能和标签,但我不知道为什么它通过使用不同的数据集实现通过我(model_fn)两种不同的数据类型的功能,如果我想使用我的生成器模式,那么如何使用该类型(IteratorGetNext )功能?
谢谢你的帮助!
[更新]
我对代码做了一些改动,
def generator():
......
yield features, label
def input_fn():
ds = tf.data.Dataset.from_generator(generator, ......)
......
feature, label = ds.make_one_shot_iterator().get_next()
return {"x": feature}, label
Run Code Online (Sandbox Code Playgroud)
然而,现在它仍然在tf.layers.dense失败了
"图层dense_1的输入0与图层不兼容:其等级未定义,但图层需要定义的等级."
虽然功能是一个字典:
'x':tf.Tensor'IteratorGetNext:0'shape = unknown dtype = float64
在正确的情况下,它是:
'x':tf.Tensor'random_shuffle_queue_DequeueMany:1'shape =(128,6)dtype = float64
我从中学到了类似的用法
https://developers.googleblog.com/2017/09/introducing-tensorflow-datasets.html
def my_input_fn(file_path, perform_shuffle=False, repeat_count=1):
def decode_csv(line):
......
d = dict(zip(feature_names, features)), label
return d
dataset = (tf.data.TextLineDataset(file_path)
Run Code Online (Sandbox Code Playgroud)
但是没有关于生成器案例的官方示例,它将迭代器返回给自定义的model_fn.
根据如何使用from_generator的示例,生成器返回要放入数据集的值,而不是特征的字典.相反,你建立了dict input_fn.
如下改变代码应该使它工作:
def generator():
......
yield features, label
def input_fn():
ds = tf.data.Dataset.from_generator(generator, ......)
......
feature, label = ds.make_one_shot_iterator().get_next()
return {"x": feature}, label
Run Code Online (Sandbox Code Playgroud)
您的代码失败,因为由a的迭代器生成的张量Dataset.from_generator没有静态shape定义(因为生成器原则上可以返回具有不同形状的数据).假设您的数据确实总是具有相同的形状,您可以(请参阅编辑打击以获取正确的方法).feature.set_shape(<the_shape_of_your_data>)在return开始之前调用input_fn
当您在评论中指出,tf.data.Dataset.from_generator()具有内置的输出张量的形状的第三个参数,所以不是feature.set_shape()仅仅通过形状output_shapes中from_generator().
| 归档时间: |
|
| 查看次数: |
2676 次 |
| 最近记录: |