使用不同形状的张量计算余弦距离

Joe*_*Joe 3 cosine-similarity pytorch

我有以下代表词向量的张量

A = (2, 500)
Run Code Online (Sandbox Code Playgroud)

其中第一个维度是BATCH维度(即A包含两个词向量,每个词向量有500个元素)

我还有以下张量

B = (10, 500)
Run Code Online (Sandbox Code Playgroud)

我想计算 A 和 B 之间的余弦距离,这样我得到

C = (2, 10, 1)
Run Code Online (Sandbox Code Playgroud)

即对于 A 中的每一行计算与 B 中每一行的余弦距离

我查看了使用torch.nn.functional.F.cosine_similarity,但这不起作用,因为尺寸必须相同。

在 pytorch 中实现这一目标的最有效方法是什么?

toz*_*CSS 8

接受的解决方案似乎效率低下——它在我的机器上花了很长时间,最终由于内存不足而崩溃了内核——而这个解决方案花了几毫秒:

import torch.nn.functional as F

# cosine similarity = normalize the vectors & multiply
C = F.normalize(A) @ F.normalize(B).t()
Run Code Online (Sandbox Code Playgroud)

这是句子转换器中的实现