pytorch中计算角度

Hes*_*sam 4 python pytorch

如果我们有一组点Rs,我们可以使用 torch.cdist 来获取所有点对的距离。

dists_ij = torch.cdist(Rs, Rs)
Run Code Online (Sandbox Code Playgroud)

是否有一个函数可以获取两组向量之间的角度,Vs如下所示:

angs_ij = torch.angs(Vs, Vs)
Run Code Online (Sandbox Code Playgroud)

Sha*_*hai 5

您可以使用两个向量的点积与它们之间的角度之间的关系手动执行此操作:

# normalize the vectors
nVs = Vs / torch.norm(Vs, p=2, dim=-1, keepdim=True)
# compute cosine of the angles using dot product
cos_ij = torch.einsum('bni,bmi->bnm', nVs, nVs)
Run Code Online (Sandbox Code Playgroud)