我如何使用 keras 创建 3d 输入/3d 输出卷积模型?

zer*_*r03 0 3d conv-neural-network keras keras-layer

我有一个我无法解决的问题。

我想将具有完全连接的 MLP 的 CNN 模型实现到我的具有 2589 个蛋白质的蛋白质数据库。每个蛋白质有 1287 行 69 列作为输入和 1287 行和 8 列作为输出。实际上有 1287x1 的输出,但我使用了一种热编码作为类标签,以便在我的模型中使用交叉熵损失。

我也想要

如果我们认为图像我有一个 3d 矩阵 ** X_train = (2589, 1287, 69) for input** 和y_train =(2589, 1287, 8) output,我的意思是输出也是矩阵。

在我的 keras 代码下面:

model = Sequential()
model.add(Conv2D(64, kernel_size=3, activation="relu", input_shape=(X_train.shape[1],X_train.shape[2])))
model.add(Conv2D(32, kernel_size=3, activation="relu"))
model.add(Flatten())
model.add(Dense((8), activation="softmax"))
Run Code Online (Sandbox Code Playgroud)

但是我遇到了关于密集层的错误:

ValueError: Error when checking target: expected dense_1 to have 2 dimensions, but got array with shape (2589, 1287, 8)
Run Code Online (Sandbox Code Playgroud)

好的,我知道 Dense 应该采用正整数单位(Keras 文档中的解释。)。但是我如何实现矩阵输出到我的模型?

我试过了:

model.add(Dense((1287,8), activation="softmax"))
Run Code Online (Sandbox Code Playgroud)

和其他东西,但我找不到任何解决方案。

非常感谢。

Pri*_*usa 5

Conv2D层需要输入形状为(batch_size, height, width, channels)。这意味着每个样本都是一个 3D 数组。

您的实际输入(2589, 1287, 8)意味着每个样本都是形状(1289, 8)- 2D 形状。因此,您应该使用Conv1D而不是Conv2D.

其次,您想要输出(2589, 1287, 8). 由于每个样本都是 2D 形状,因此Flatten()输入没有任何意义-Flatten()会将每个样本的形状减少到 1D,并且您希望每个样本都是 2D。

最后,根据Conv图层的填充,形状可能会根据kernel_size. 由于您要保留 的中间尺寸1287,请使用padding='same'来保持大小相同。

from keras.models import Sequential
from keras.layers import Conv1D, Flatten, Dense
import numpy as np

X_train = np.random.rand(2589, 1287, 69)
y_train = np.random.rand(2589, 1287, 8)


model = Sequential()
model.add(Conv1D(64, 
                 kernel_size=3, 
                 activation="relu", 
                 padding='same',
                 input_shape=(X_train.shape[1],X_train.shape[2])))
model.add(Conv1D(32, 
                 kernel_size=3, 
                 activation="relu",
                 padding='same'))
model.add(Dense((8), activation="softmax"))

model.summary()
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.fit(X_train, y_train)
Run Code Online (Sandbox Code Playgroud)