为什么keras张量流中的layers.Embedding需要input_length?

fig*_*uts 4 reshape keras tensorflow tf.keras tensorflow2.0

Layers.embedding 有一个参数(input_length),文档描述为:

input_length :输入序列的长度(当它是常数时)。如果要连接 Flatten 然后是 Dense 层上游(没有它,则无法计算密集输出的形状),则需要此参数。

为什么密集输出的形状无法计算。对我来说,Flatten似乎很容易做到。它tf.rehshape(input,(-1,1))后面只是一个密集层,具有我们选择的任何输出形状。

你能帮我指出我对整个逻辑的理解上的失误吗?

Zab*_*azi 5

通过指定维度,您可以确保模型接收固定长度的输入。

从技术上讲,您可以输入None您想要的任何输入维度。形状将在运行时推断。

您只需要确保指定层参数(input_dim、output_dim)、kernel_size(对于转换层)、单位(对于 FC 层)。

Input如果您使用并指定将通过网络传递的张量的形状,则可以计算形状。

例如,以下模型是完全有效的:

from tensorflow.keras import layers
from tensorflow.keras import models

ip = layers.Input((10))
emb = layers.Embedding(10, 2)(ip)
flat = layers.Flatten()(emb)
out = layers.Dense(5)(flat)

model = models.Model(ip, out)

model.summary()
Run Code Online (Sandbox Code Playgroud)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 10)]              0         
_________________________________________________________________
embedding (Embedding)        (None, 10, 2)             20        
_________________________________________________________________
flatten (Flatten)            (None, 20)                0         
_________________________________________________________________
dense (Dense)                (None, 5)                 105       
=================================================================
Total params: 125
Trainable params: 125
Non-trainable params: 0
Run Code Online (Sandbox Code Playgroud)

在这里,我没有指定 input_length ,但它是从Input层推断出来的。

问题在于 Sequential API,如果您没有在输入层中指定输入形状,也没有在嵌入层中指定输入形状,则无法使用正确的参数集来构建模型。

例如,

from tensorflow.keras import layers
from tensorflow.keras import models

model = models.Sequential()
model.add(layers.Embedding(10, 2, input_length = 10)) # will be an error if I don't specify input_length here as there is no way to know the shape of the next layers without knowing the length

model.add(layers.Flatten())
model.add(layers.Dense(5))


model.summary()
Run Code Online (Sandbox Code Playgroud)

在此示例中,您必须指定 input_length,否则模型将抛出错误。