Mih*_*kov 4 python tensorflow word-embedding
我有一个embedded_chars
用以下代码创建的数组:
self.input_x = tf.placeholder(tf.int32, [None, sequence_length], name="input_x")
W = tf.Variable(
tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
name="W"
)
self.embedded_chars = tf.nn.embedding_lookup(W, self.input_x)
Run Code Online (Sandbox Code Playgroud)
input_x
如果我只有embedded_chars
and ,我想要得到数组W
。
我怎么才能得到它?
谢谢!
W
您可以使用和中嵌入向量之间的余弦距离embedded_chars
:
# assume embedded_chars.shape == (batch_size, embedding_size)
emb_distances = tf.matmul( # shape == (vocab_size, batch_size)
tf.nn.l2_normalize(W, dim=1),
tf.nn.l2_normalize(embedded_chars, dim=1),
transpose_b=True)
token_ids = tf.argmax(emb_distances, axis=0) # shape == (batch_size)
Run Code Online (Sandbox Code Playgroud)
这emb_distances
是 L2 归一化矩阵W
和的点积,它与 中的所有向量和 中的所有向量transpose(embedded_chars)
之间的余弦距离相同。argmax 只是为我们提供了 的每列中最高值的索引。W
embedded_chars
emb_distances
@Yao Zhang:如果所有嵌入W
都不同(它们应该是这样),那么这将为您提供正确的结果:余弦距离始终在 [-1, 1] 和 cos(vector_a, vector_a) == 1 之间。
请注意,通常您不需要进行这种从嵌入到标记索引的转换:通常您可以直接获取作为第二个参数传递给 的张量的值tf.nn.embedding_embedding_lookup
,它是标记索引的张量。
归档时间: |
|
查看次数: |
4976 次 |
最近记录: |