Posts by Vla*_*nyk

Injecting pre-trained word2vec vectors into TensorFlow seq2seq

I'm trying to inject pre-trained word2vec vectors into an existing TensorFlow seq2seq model.

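Based on this answer, I wrote the following code. However, it does not seem to improve performance, even though the values in the variables are updated.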

As far as I understand, the problem may be that EmbeddingWrapper or embedding_attention_decoder creates its embeddings independently of the vocabulary order?
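The ordering concern can be made concrete. The embedding variable is looked up by token ID, so row `i` has to hold the vector of the word whose ID is `i`; filling rows in any other order (for example, the order in which word2vec stores its words) would still change the variable's values while attaching each vector to the wrong token. The sketch below is a minimal pure-Python illustration, assuming a vocabulary file with one token per line where the line index is the ID (the convention used by the TensorFlow translate tutorial's `data_utils`), with a plain dict standing in for the loaded word2vec model. `read_rev_vocab` and `build_embedding_matrix` are hypothetical helpers, not part of TensorFlow or the `word2vec` package.

```python
import random


def read_rev_vocab(lines):
    """Hypothetical helper: token list where the list index is the token ID
    (assumes one token per line, ID = line index)."""
    return [line.strip() for line in lines]


def build_embedding_matrix(rev_vocab, w2v, dim, seed=0):
    """Hypothetical helper: return (matrix, found), where matrix[i] is the
    pre-trained vector for rev_vocab[i] when w2v has one of the right size,
    and a small random vector otherwise (e.g. for _PAD, _GO)."""
    rng = random.Random(seed)
    matrix, found = [], 0
    for word in rev_vocab:  # iterate in vocabulary-ID order, not in w2v order
        vec = w2v.get(word)
        if vec is not None and len(vec) == dim:
            matrix.append([float(x) for x in vec])
            found += 1
        else:
            matrix.append([rng.uniform(-0.1, 0.1) for _ in range(dim)])
    return matrix, found


# Usage: the dict's insertion order differs from the vocabulary order on purpose.
rev_vocab = read_rev_vocab(["_PAD\n", "_GO\n", "dog\n", "cat\n"])
w2v = {"cat": [0.5, 0.5], "dog": [1.0, 0.0]}
matrix, found = build_embedding_matrix(rev_vocab, w2v, dim=2)
print(matrix[2], found)  # [1.0, 0.0] 2  ("dog" has ID 2)
```

A matrix built this way, one row per vocabulary entry, is what would then be assigned to the embedding variable; the uniform(-0.1, 0.1) fallback for words without a pre-trained vector is an arbitrary choice for illustration.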

What is the best way to load pre-trained vectors into a TensorFlow model?

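```python
import tensorflow as tf  # TF 0.x-era API (legacy seq2seq variable names below)
import word2vec          # presumably the `word2vec` PyPI package, which provides word2vec.load()

# Name fragments of the embedding variables created by embedding_attention_seq2seq:
# the encoder-side (source) EmbeddingWrapper and the decoder-side (target) embedding.
SOURCE_EMBEDDING_KEY = "embedding_attention_seq2seq/RNN/EmbeddingWrapper/embedding"
TARGET_EMBEDDING_KEY = "embedding_attention_seq2seq/embedding_attention_decoder/embedding"


def inject_pretrained_word2vec(session, word2vec_path, input_size, dict_dir, source_vocab_size, target_vocab_size):
  word2vec_model = word2vec.load(word2vec_path, encoding="latin-1")
  print("w2v model created!")
  # Initializes every variable; later TF releases rename this
  # tf.global_variables_initializer().
  session.run(tf.initialize_all_variables())

  # NOTE: source_vocab_path and target_vocab_path are not defined in this
  # snippet (dict_dir is otherwise unused); they must be set before these calls.
  assign_w2v_pretrained_vectors(session, word2vec_model, SOURCE_EMBEDDING_KEY, source_vocab_path, source_vocab_size)
  assign_w2v_pretrained_vectors(session, word2vec_model, TARGET_EMBEDDING_KEY, target_vocab_path, target_vocab_size)


def assign_w2v_pretrained_vectors(session, word2vec_model, embedding_key, vocab_path, vocab_size):
  # Look for exactly one trainable variable whose name contains embedding_key.
  vectors_variable = [v for v in tf.trainable_variables() if embedding_key in v.name]
  if len(vectors_variable) != 1:
      print("Word vector variable not found or too many. key: " + embedding_key)
      print("Existing embedding trainable variables:")
      print([v.name for v in tf.trainable_variables() …
```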
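Since the variable's values do change but quality does not, two inexpensive checks are how many vocabulary entries actually received a pre-trained vector, and whether each row read back from the variable matches the vector intended for that ID. A minimal pure-Python sketch, assuming the rows have already been fetched from the session (for example `session.run(variable).tolist()`, where `variable` is the embedding variable found by `assign_w2v_pretrained_vectors`) and that a dict maps words to their word2vec vectors; `check_embedding_alignment` is a hypothetical name:

```python
import math


def check_embedding_alignment(rev_vocab, w2v, rows, tol=1e-6):
    """Hypothetical helper: return (coverage, mismatched_ids).

    coverage is the fraction of vocabulary IDs that have a pre-trained vector;
    mismatched_ids lists IDs whose row differs from that vector."""
    covered, mismatched = 0, []
    for word_id, word in enumerate(rev_vocab):
        expected = w2v.get(word)
        if expected is None:
            continue  # no pre-trained vector for this ID: nothing to compare
        covered += 1
        row = rows[word_id]
        if len(row) != len(expected) or any(
            not math.isclose(a, b, abs_tol=tol) for a, b in zip(row, expected)
        ):
            mismatched.append(word_id)
    coverage = covered / len(rev_vocab) if rev_vocab else 0.0
    return coverage, mismatched


# Usage: the rows for "dog" (ID 2) and "cat" (ID 3) are swapped.
rev_vocab = ["_PAD", "_GO", "dog", "cat"]
w2v = {"dog": [1.0, 0.0], "cat": [0.5, 0.5]}
rows = [[0.0, 0.0], [0.0, 0.0], [0.5, 0.5], [1.0, 0.0]]
print(check_embedding_alignment(rev_vocab, w2v, rows))  # (0.5, [2, 3])
```

Low coverage, or a non-empty mismatch list, would each be consistent with the symptom described in the question; a clean result would point the search elsewhere.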

Tags: python, word2vec, tensorflow
