Visualizing attention activations in Tensorflow

rei*_*ste 10 deep-learning tensorflow attention-model sequence-to-sequence

[Figure: attention alignment heatmap from Bahdanau et al., 2014]

Is there a way to visualize the attention weights over some input in TensorFlow's seq2seq models, like the figure linked above (from Bahdanau et al., 2014)? I have found TensorFlow's GitHub issue on this, but I could not figure out how to fetch the attention mask during a session.

小智 6

I also wanted to visualize the attention weights of the Tensorflow seq2seq ops for my text summarization task. I think a temporary solution is to use session.run() to evaluate the attention mask tensor mentioned above. Interestingly, the original seq2seq.py ops are considered a legacy version and are not easy to find on github, so I just took the seq2seq.py file from the 0.12.0 wheel distribution and modified it. To draw the heatmap I used the Matplotlib package, which is very convenient.

The final output of the attention visualization on news headline text looks like this:

[Figure: attention heatmap for the news-headline summarization example]

I modified the code as follows: https://github.com/rockingdingo/deepnlp/tree/master/deepnlp/textsum#attention-visualization

seq2seq_attn.py

# Find the attention mask tensor inside attention_decoder() -> attention().
# Add the attention mask tensor to the 'return' statement of every function that calls attention_decoder(),
# all the way up to model_with_buckets(), which is the final function I use for bucket training.

def attention(query):
  """Put attention masks on hidden using hidden_features and query."""
  ds = []  # Results of attention reads will be stored here.

  # some code

  for a in xrange(num_heads):
    with variable_scope.variable_scope("Attention_%d" % a):
      # some code

      s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y),
                              [2, 3])
      # This is the attention mask tensor we want to extract
      a = nn_ops.softmax(s)

      # some code

  # add 'a' to the return statement of attention()
  return ds, a
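The plumbing described in the comments above amounts to forwarding one extra return value through each wrapper. Below is a minimal sketch of that pattern only; it is not the real legacy seq2seq code (the actual functions take many more arguments), and the simplified signatures and string placeholders are assumptions used just to show how 'a' travels up to model_with_buckets().

# Hypothetical sketch of the plumbing only: each level forwards the extra
# attention tensor returned by the level below it.

def attention_decoder(decoder_inputs):
    outputs = ["decoder_output_per_step"]   # stands in for the RNN outputs
    attn_mask = "attention_mask_tensor"     # the tensor 'a' from attention()
    return outputs, attn_mask

def embedding_attention_seq2seq(encoder_inputs, decoder_inputs):
    outputs, attn_mask = attention_decoder(decoder_inputs)
    return outputs, attn_mask               # forward it one level up

def model_with_buckets(encoder_inputs, decoder_inputs, buckets):
    outputs, losses, attn_masks = [], [], []
    for _ in buckets:
        bucket_out, attn_mask = embedding_attention_seq2seq(encoder_inputs,
                                                            decoder_inputs)
        outputs.append(bucket_out)
        losses.append(0.0)                  # loss computation omitted
        attn_masks.append(attn_mask)        # one attention tensor per bucket
    return outputs, losses, attn_masks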

seq2seq_model_attn.py

# Modify model.step() so that it also returns the attention mask tensors
self.outputs, self.losses, self.attn_masks = seq2seq_attn.model_with_buckets(…)

# use session.run() to evaluate attn masks
attn_out = session.run(self.attn_masks[bucket_id], input_feed)
attn_matrix = ...
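The shape of attn_out depends on how the ops above were modified, so turning it into a 2D (decoder steps x encoder steps) array is model-specific. A small sketch follows, assuming attn_out comes back as a list with one [batch_size, encoder_length] array per decoder step; the helper name is made up.

import numpy as np

def build_attn_matrix(attn_out, sample_index=0):
    # Assumes attn_out is a list with one [batch_size, encoder_length]
    # array per decoder time step; pick one example from the batch and
    # stack the rows into a (decoder_steps, encoder_length) matrix.
    rows = [step_attn[sample_index] for step_attn in attn_out]
    return np.stack(rows, axis=0)

# attn_matrix = build_attn_matrix(attn_out)  # then crop it to the real lengths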

predict_attn.py and eval.py

# Use the plot_attention function in eval.py to visualize the 2D ndarray during prediction.

eval.plot_attention(attn_matrix[0:ty_cut, 0:tx_cut], X_label = X_label, Y_label = Y_label)
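For completeness, here is a self-contained Matplotlib helper in the spirit of plot_attention; it is a sketch rather than the eval.py code from the repo, just one way to render the cropped matrix as a heatmap with token labels.

import matplotlib.pyplot as plt
import numpy as np

def plot_attention(attn_matrix, x_labels=None, y_labels=None):
    """Draw a (target_len, source_len) attention matrix as a heatmap."""
    fig, ax = plt.subplots()
    im = ax.imshow(attn_matrix, cmap="viridis", aspect="auto")
    fig.colorbar(im, ax=ax)
    if x_labels is not None:
        ax.set_xticks(np.arange(len(x_labels)))
        ax.set_xticklabels(x_labels, rotation=90)
    if y_labels is not None:
        ax.set_yticks(np.arange(len(y_labels)))
        ax.set_yticklabels(y_labels)
    ax.set_xlabel("source tokens")
    ax.set_ylabel("generated tokens")
    plt.tight_layout()
    plt.show()

# Example: plot_attention(np.random.rand(4, 7), x_labels=list("abcdefg"),
#                         y_labels=["w1", "w2", "w3", "w4"])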

And perhaps in the future tensorflow will have a better way to extract and visualize the attention weight maps. Any thoughts?