如何对tf.nn.embedding_lookup进行反向操作?

Mih*_*kov 4 python tensorflow word-embedding

我有一个embedded_chars用以下代码创建的数组:

self.input_x = tf.placeholder(tf.int32, [None, sequence_length], name="input_x")

W = tf.Variable( 
    tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
    name="W"
    )
self.embedded_chars = tf.nn.embedding_lookup(W, self.input_x) 
Run Code Online (Sandbox Code Playgroud)

input_x如果我只有embedded_charsand ,我想要得到数组W

我怎么才能得到它?

谢谢!

Mat*_*rro 5

W您可以使用和中嵌入向量之间的余弦距离embedded_chars

# assume embedded_chars.shape == (batch_size, embedding_size)
emb_distances = tf.matmul( # shape == (vocab_size, batch_size)
    tf.nn.l2_normalize(W, dim=1),
    tf.nn.l2_normalize(embedded_chars, dim=1),
    transpose_b=True)
token_ids = tf.argmax(emb_distances, axis=0) # shape == (batch_size)
Run Code Online (Sandbox Code Playgroud)

emb_distances是 L2 归一化矩阵W和的点积,它与 中的所有向量和 中的所有向量transpose(embedded_chars)之间的余弦距离相同。argmax 只是为我们提供了 的每列中最高值的索引。Wembedded_charsemb_distances

@Yao Zhang:如果所有嵌入W都不同(它们应该是这样),那么这将为您提供正确的结果:余弦距离始终在 [-1, 1] 和 cos(vector_a, vector_a) == 1 之间。

请注意,通常您不需要进行这种从嵌入到标记索引的转换:通常您可以直接获取作为第二个参数传递给 的张量的值tf.nn.embedding_embedding_lookup,它是标记索引的张量。