相关疑难解决方法(0)

在Keras中具有多个输入/输出的tf.data

对于应用程序,例如配对文本相似性,输入数据类似于:pair_1, pair_2.在这些问题中,我们通常有多个输入数据.以前,我成功实现了我的模型:

model.fit([pair_1, pair_2], labels, epochs=50)
Run Code Online (Sandbox Code Playgroud)

我决定用tf.data API 替换我的输入管道.为此,我创建了一个类似于的数据集:

dataset = tf.data.Dataset.from_tensor_slices((pair_1, pair2, labels))
Run Code Online (Sandbox Code Playgroud)

它成功编译但是当开始训练时会引发以下异常:

AttributeError: 'tuple' object has no attribute 'ndim'
Run Code Online (Sandbox Code Playgroud)

我的Keras和Tensorflow版本分别是2.1.61.11.0.我在Tensorflow存储库中发现了类似的问题: tf.keras多输入模型在使用tf.data.Dataset时不起作用.

有谁知道如何解决这个问题?

以下是代码的一些主要部分:

(q1_test, q2_test, label_test) = test
(q1_train, q2_train, label_train) = train

    def tfdata_generator(sent1, sent2, labels, is_training):
        '''Construct a data generator using tf.Dataset'''

        dataset = tf.data.Dataset.from_tensor_slices((sent1, sent2, labels))
        if is_training:
            dataset = dataset.shuffle(1000)  # depends on sample size

        dataset = dataset.repeat()
        dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)

        return …
Run Code Online (Sandbox Code Playgroud)

keras tensorflow tensorflow-datasets

17
推荐指数
2
解决办法
4908
查看次数

标签 统计

keras ×1

tensorflow ×1

tensorflow-datasets ×1