小编Y. *_*hev的帖子

Tensorflow:如何在Estimator中使用来自生成器的数据集

试图建立简单的模型只是为了弄清楚如何处理tf.data.Dataset.from_generator.我无法理解如何设置output_shapes参数.我尝试了几种组合,包括没有指定它,但由于张量的形状不匹配仍然会收到一些错误.这个想法只是产生两个numpy数组,SIZE = 10并用它们运行线性回归.这是代码:

SIZE = 10


def _generator():
    feats = np.random.normal(0, 1, SIZE)
    labels = np.random.normal(0, 1, SIZE)
    yield feats, labels


def input_func_gen():
    shapes = (SIZE, SIZE)
    dataset = tf.data.Dataset.from_generator(generator=_generator,
                                             output_types=(tf.float32, tf.float32),
                                             output_shapes=shapes)
    dataset = dataset.batch(10)
    dataset = dataset.repeat(20)
    iterator = dataset.make_one_shot_iterator()
    features_tensors, labels = iterator.get_next()
    features = {'x': features_tensors}
    return features, labels


def train():
    x_col = tf.feature_column.numeric_column(key='x', )
    es = tf.estimator.LinearRegressor(feature_columns=[x_col])
    es = es.train(input_fn=input_func_gen)
Run Code Online (Sandbox Code Playgroud)

另一个问题是,是否可以使用此功能为特征列提供数据tf.feature_column.crossed_column?总体目标是Dataset.from_generator在批处理培训中使用功能,在数据不适合内存的情况下,数据从数据库加载到数据块.所有意见和例子都非常感谢.

谢谢!

tensorflow tensorflow-datasets

7
推荐指数
1
解决办法
6562
查看次数

标签 统计

tensorflow ×1

tensorflow-datasets ×1