Tensorflow 2.2.0 错误:[预测必须 > 0] [条件 x >= y 没有按元素保持:] 使用双向 LSTM 层时

Cha*_*aos 7 python lstm tensorflow

在处理命名实体识别任务时,我收到以下错误消息:

tensorflow.python.framework.errors_impl.InvalidArgumentError:  assertion failed: [predictions must be >= 0] [Condition x >= y did not hold element-wise:] [x (bidirectional_lstm_model/time_distributed/Reshape_1:0) = ] [[[-0.100267865 -0.104010895 0.04090859...]]...] [y (Cast_2/x:0) = ] [0]
     [[{{node assert_greater_equal/Assert/AssertGuard/else/_1/Assert}}]] [Op:__inference_train_function_6216]
Function call stack:
train_function
Run Code Online (Sandbox Code Playgroud)

我该如何解决这个问题?我检查了我的输入train_xtrain_y张量,它们看起来很好(最后提供了一些例子)。

我最初使用的是条件随机场解码器。我用一个 Dense 层代替它,看看是否会改变错误消息。错误仍然相同,并且在某种程度上与模型的 RNN 组件有关。

一般来说,您使用什么策略从 TF 的内部深处解决此类错误?我尝试在 PyCharm 上设置调试会话并跳过一堆 TF 文件,但没有学习任何关于如何解决我的问题的有用信息。

以下是我的网络架构:

Model: "bidirectional_lstm_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
encoder_input (InputLayer)   [(None, None)]            0         
_________________________________________________________________
encoder_embedding (Embedding (None, None, 300)         2013300   
_________________________________________________________________
encoder_bidirectional_rnn (B (None, None, 32)          40576     
_________________________________________________________________
time_distributed (TimeDistri (None, None, 25)          825       
=================================================================
Total params: 2,054,701
Trainable params: 41,401
Non-trainable params: 2,013,300
_________________________________________________________________
Run Code Online (Sandbox Code Playgroud)

以上 + 更多细节(损失、优化器等):

# Create model
encoder_input = keras.Input(shape=(None,), name='encoder_input')
encoder_embedding = layers.Embedding(input_dim=input_vocabulary,
                                     output_dim=embedding_vector_len,
                                     embeddings_initializer=tf.keras.initializers.Constant(embedding_matrix),
                                     trainable=False, name='encoder_embedding')(encoder_input)
encoder_rnn = layers.LSTM(16, return_sequences=True, name='encoder_rnn')
encoder_bidirectional_rnn = layers.Bidirectional(encoder_rnn, name='encoder_bidirectional_rnn')(encoder_embedding)
decoder_dense = layers.TimeDistributed(layers.Dense(number_of_tags, name='decoder_dense'))(encoder_bidirectional_rnn)
model = keras.Model(inputs=encoder_input, outputs=decoder_dense, name='bidirectional_lstm_model')
model.summary()

metrics_precision = tf.keras.metrics.Precision()
metrics_recall = tf.keras.metrics.Recall()
model.compile(
    loss=tf.keras.losses.categorical_crossentropy,
    optimizer='adam',
    metrics=[metrics_precision, metrics_recall]
)
Run Code Online (Sandbox Code Playgroud)

这是 mytrain_xtrain_y数组的样子:

# Shapes
train_x.shape  # (9775, 47)  (np.ndarray type)
train_y.shape  # TensorShape([9775, 47, 25])  (Obtained from tf.one_hot)

# Sample (Zero-padded from the right)
train_x[0, :]

# array([4917, 2806, 6357, 2287, 6059,    0,    0,    0,    0,    0,    0,
#      0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
#      0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
#      0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
#      0,    0,    0])

train_y[0, :, :]

# array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],  # Non "O" tag
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],  # Non "O" tag
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
#   [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]], dtype=float32)
Run Code Online (Sandbox Code Playgroud)

Gui*_*lem 6

你错过了最后一层激活:

decoder_dense = layers.TimeDistributed(layers.Dense(number_of_tags, name='decoder_dense'))(encoder_bidirectional_rnn)
Run Code Online (Sandbox Code Playgroud)

你应该指定你想要一个 softmax,将激活保留为默认值实际上是一个线性激活,这意味着你可以有任何值,因此是负值。您应该按如下方式创建最后一个 Dense 层:

decoder_dense = layers.TimeDistributed(layers.Dense(number_of_tags, activation='softmax', name='decoder_dense'))(encoder_bidirectional_rnn)

Run Code Online (Sandbox Code Playgroud)