TensorFlow 使用来自具有多个输出的生成器的数据集进行拟合:无法正确定义形状?

Dav*_*zer 3 python machine-learning keras tensorflow

我正在尝试使用生成器将项目转换为具有多个输出的单个网络,但我似乎无法弄清楚如何在使用生成器时使多个输出正常运行。这是一段最低限度可验证的代码:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models

def generate_sample():
    x = list("123456789")
    y = list("2345")
    while 1:
        yield np.array(x).astype(np.float32),[np.array(y).astype(np.float32),np.array(y).astype(np.float32)]

dataset = tf.data.Dataset.from_generator(generate_sample,
            output_signature=(
                 tf.TensorSpec(shape=(9,), dtype=tf.float32),
                 tf.TensorSpec(shape=(2,4), dtype=tf.float32)

            ))

dataset = dataset.batch(batch_size=32)

inputs = keras.Input(shape=(next(generate_sample())[0].shape))
x = layers.Dense(512, activation = "relu")(inputs)
x_outputs = layers.Dense(4, activation="relu", name="output")(x)
y_outputs = layers.Dense(4, activation="relu", name="output2")(x)

model = keras.Model(inputs=inputs, outputs=[x_outputs,y_outputs])
model.compile(loss="mse", optimizer = "adam", metrics=['accuracy'])
history = model.fit(dataset, epochs=1, steps_per_epoch=10, validation_data=dataset, validation_steps=5)
Run Code Online (Sandbox Code Playgroud)

这会导致一个很长的错误,其最后部分是:

InvalidArgumentError:不兼容的形状:[32,2,4] vs. [32,4]
[[节点mean_squared_error/SquaredDifference(定义于:1)]] [Op:__inference_train_function_8957]

函数调用栈:train_function

我已经尝试过使用output_shapeoutput_signature等等,以我能想象到的各种方式重塑数据。不管怎样,我仍然会遇到体形问题。

fit我在这里遗漏了一些明显的东西,还是使用生成器作为数据集源有问题?例如,当我从内存加载数据时,这样做没有问题。

Les*_*rel 7

模型的输出不是一个形状张量(2,4),而是两个形状张量(4)

您应该更改生成器函数以反映这一点:

def generate_sample():
    x = list("123456789")
    y = list("2345")
    while 1:
        yield np.array(x).astype(np.float32),(np.array(y).astype(np.float32),np.array(y).astype(np.float32))
Run Code Online (Sandbox Code Playgroud)

以及您的输出签名:

dataset = tf.data.Dataset.from_generator(generate_sample,
            output_signature=(
                 tf.TensorSpec(shape=(9,), dtype=tf.float32),
                 (tf.TensorSpec(shape=(4,), dtype=tf.float32),
                 tf.TensorSpec(shape=(4,), dtype=tf.float32)),
            ))
Run Code Online (Sandbox Code Playgroud)

请注意,生成器的输出是一个嵌套元组。

  • 天哪...真的那么容易吗?谢谢。 (3认同)