如何测试使用教师强制训练的模型

nig*_*ury 1 nlp lstm recurrent-neural-network seq2seq

我使用 keras 来训练 seq2seq 模型(keras.models.Model)。模型的 X 和 y 是 [X_encoder, X_decoder] , y 即编码器和解码器输入和标签的列表(请注意,解码器输入 X_decoder 是 \xe2\x80\x98y\xe2\x80\x99 ,前面有一个位置比实际的 y 基本上是老师强迫的)。

\n\n

所以我现在的问题是在训练之后,当涉及到实际预测时,我没有任何标签,如何为我的输入提供 \xe2\x80\x98X_decoder\xe2\x80\x99 ?还是我要训练别的东西?

\n\n

这是模型定义的片段(如果有帮助的话):)

\n\n
# Encoder\nencoder_inputs = Input(batch_shape=(batch_size, max_len,), dtype='int32')\nencoder_embedding = embedding_layer(encoder_inputs)\nencoder_LSTM = CuDNNLSTM(hidden_dim, return_state=True, stateful=True)\nencoder_outputs, state_h, state_c = encoder_LSTM(encoder_embedding)\n\n# Decoder\ndecoder_inputs = Input(shape=(max_len,), dtype='int32')\ndecoder_embedding = embedding_layer(decoder_inputs)\ndecoder_LSTM = CuDNNLSTM(hidden_dim, return_state=True, return_sequences=True)\ndecoder_outputs, _, _ = decoder_LSTM(decoder_embedding, initial_state=[state_h, state_c])\n\n# Output\noutputs = TimeDistributed(Dense(vocab_size, activation='softmax'))(decoder_outputs)\nmodel = Model([encoder_inputs, decoder_inputs], outputs)\n\n# model fitting:\nmodel.fit([X_encoder, X_decoder], y, steps_per_epoch=int(number_of_train_samples/batch_size),\nepochs=epochs)\n
Run Code Online (Sandbox Code Playgroud)\n

Dav*_*ale 5

通常,当你训练 seq2seq 模型时,decoder_inputs 的第一个 token 是一个特殊的<start>token。所以当你尝试生成一个句子时,你会这样做

first_token = decoder(encoder_state, [<start>])
second_token = decoder(encoder_state, [<start>, first_token])
third_token = decoder(encoder_state, [<start>, first_token, second_token])
...
Run Code Online (Sandbox Code Playgroud)

您执行此递归,直到您的解码器生成另一个特殊标记 - <end>;然后你停下来。

这是一个适合您的模型的非常粗糙的 pythonic 解码器。它效率低下,因为它一遍又一遍地读取输入,而不是记住 RNN 状态 - 但它有效。

input_seq = # some array of token indices
result = np.array([[START_TOKEN]])
next_token = -1
for i in range(100500):
    next_token = model.predict([input_seq, result])[0][-1].argmax()
    if next_token == END_TOKEN:
        break
    result = np.concatenate([result, [[next_token]]], axis=1)
output_seq = result[0][1:] # omit the first INPUT_TOKEN
Run Code Online (Sandbox Code Playgroud)

更有效的解决方案是将 RNN 状态与每个标记一起输出,并使用它来生成下一个标记。