TensorFlow 数据集 API 解析错误

aht*_*aht 0 python tensorflow tensorflow-datasets

我正在使用 TensorFlow Dataset API 来解析 CSV 文件并运行逻辑回归。我下面从TF文件的例子在这里

以下代码片段显示了我如何设置模型:

def input_fn(path, num_epochs, batch_size):
    dataset = tf.data.TextLineDataset(path)
    dataset = dataset.map(parse_table, num_parallel_calls=12)
    dataset = dataset.repeat(num_epochs)
    dataset.batch(batch_size)

    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels

def parse_table(value):
    cols = tf.decode_csv(value, record_defaults=TAB_COLUMN_DEFAULTS)
    indep_vars = dict(zip(CSV_COLS, cols))
    y = indep_vars.pop('y')
    return indep_vars, y

def build_indep_vars():
    continuous_vars = [
        tf.feature_column.numeric_column(x, shape=1) for x in CONT_COLS]
    categorical_vars = [
        tf.feature_column.categorical_column_with_hash_bucket(
            x, hash_bucket_size=100) for x in CAT_COLS]
    return categorical_vars + continuous_vars
Run Code Online (Sandbox Code Playgroud)

调用时lr.train(input_fn = lambda: input_fn(data_path, 1, 100))(注意:批量大小为 100)我收到错误

ValueError: Feature (key: V1) cannot have rank 0. Give: Tensor("IteratorGetNext:0", shape=(), dtype=float32, device=/device:CPU:0)
Run Code Online (Sandbox Code Playgroud)

所以我假设这意味着其中一个tf.feature_column.numeric_column调用正在获得它不喜欢的标量值。但是,我无法弄清楚为什么会这样。我已经设置batch_size为一个正整数,根据文档,默认情况下产生的 NDarray 的形状tf.feature_column.numeric_column应该是1Xbatch_size。谁能解释为什么 TensorFlow 会返回这个错误?

我敢肯定这个问题有一个简单的答案,会让我觉得自己没有弄清楚它很愚蠢,但是在花了一些时间之后,我仍然被难住了。

mrr*_*rry 5

引发错误是因为这些tf.feature_column方法期望输入被批处理,我认为原因是一个简单的错字,即放弃了Dataset.batch()转换。将 替换为dataset.batch(batch_size)以下行:

dataset = dataset.batch(batch_size)
Run Code Online (Sandbox Code Playgroud)

调用任何tf.data.Dataset转换方法(例如Dataset.map(), Dataset.repeat(), Dataset.batch())不会修改您调用这些方法的对象。相反,这些方法返回一个 Dataset对象,您可以使用该对象进行进一步的转换,或者创建一个Iterator.