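I am implementing a seq2seq model in TensorFlow, based on the provided example code. I want to get the top 5 decoder outputs to use for reinforcement learning.
However, the example implements the translation model with an attention decoder, so I need to implement beam search to get the top-k results.
Here is the part I have implemented so far (this code was added to translate.py).
Reference: https://github.com/tensorflow/tensorflow/issues/654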
with tf.Graph().as_default():
    beam_size = FLAGS.beam_size  # Number of hypotheses in beam
    num_symbols = FLAGS.tar_vocab_size  # Output vocabulary size
    embedding_size = 10
    num_steps = 5
    embedding = tf.zeros([num_symbols, embedding_size])
    output_projection = None
    log_beam_probs, beam_symbols, beam_path = [], [], []
    def beam_search(prev, i):
        if output_projection is not None:
            prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1])
        probs = tf.log(tf.nn.softmax(prev))
        if i > 1:
            probs = tf.reshape(probs + log_beam_probs[-1], [-1, beam_size …
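For reference, the bookkeeping that a beam search step needs can be sketched in plain Python, independent of TensorFlow. This is only an illustration, not the code from the linked issue and not a drop-in replacement for the snippet above. The function names `beam_step` and `backtrack` are hypothetical; the list names `log_beam_probs`, `beam_symbols` and `beam_path` are borrowed from the snippet. The sketch assumes each step supplies one log-probability distribution per live hypothesis. It flattens the scores to one value per (hypothesis, symbol) pair, keeps the best `beam_size` of them, and recovers the parent hypothesis and the symbol from the flat index by integer division and modulo.

```python
import heapq
import math


def beam_step(log_beam_probs, step_log_probs, beam_size):
    """Advance the beam by one decoding step.

    log_beam_probs: cumulative log-probability of each live hypothesis.
    step_log_probs: step_log_probs[b][s] is the log-probability of symbol s
        following hypothesis b.
    Returns (new_log_probs, parents, symbols), best hypothesis first.
    """
    num_symbols = len(step_log_probs[0])
    # One score per (hypothesis, symbol) pair, laid out row by row.
    flat = [
        log_beam_probs[b] + step_log_probs[b][s]
        for b in range(len(log_beam_probs))
        for s in range(num_symbols)
    ]
    # Indices of the beam_size highest scores, in descending order.
    best = heapq.nlargest(beam_size, range(len(flat)), key=flat.__getitem__)
    new_log_probs = [flat[i] for i in best]
    parents = [i // num_symbols for i in best]  # which hypothesis was extended
    symbols = [i % num_symbols for i in best]   # which symbol was appended
    return new_log_probs, parents, symbols


def backtrack(beam_symbols, beam_path, k):
    """Rebuild the symbol sequence of the k-th hypothesis in the final beam."""
    seq = []
    for symbols, parents in zip(reversed(beam_symbols), reversed(beam_path)):
        seq.append(symbols[k])
        k = parents[k]
    return seq[::-1]


if __name__ == "__main__":
    # Toy distributions (made up for illustration), vocabulary of 3 symbols.
    step1 = [[math.log(p) for p in (0.1, 0.6, 0.3)]]
    step2 = [[math.log(p) for p in (0.5, 0.4, 0.1)],
             [math.log(p) for p in (0.05, 0.05, 0.9)]]

    beam_symbols, beam_path = [], []
    log_probs = [0.0]  # a single empty hypothesis to start from
    for dist in (step1, step2):
        log_probs, parents, symbols = beam_step(log_probs, dist, beam_size=2)
        beam_symbols.append(symbols)
        beam_path.append(parents)

    print(backtrack(beam_symbols, beam_path, 0))  # [1, 0]
    print(backtrack(beam_symbols, beam_path, 1))  # [2, 2]
```

With `beam_size` set to 5, the same backtracking over `beam_symbols` and `beam_path` would yield the five best sequences. In a TensorFlow graph the top-k selection would be done with tensor ops rather than `heapq`, but the index arithmetic is the same.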