Numpy Broadcast执行欧几里德距离矢量化

use*_*351 8 python numpy machine-learning vectorization

我有2 x 4和3 x 4的矩阵.我想找到跨行的欧几里德距离,最后得到一个2 x 3矩阵.这是带有一个for循环的代码,它计算针对所有b行向量中每个行向量的欧氏距离.如何在不使用for循环的情况下执行相同的操作?

 import numpy as np
a = np.array([[1,1,1,1],[2,2,2,2]])
b = np.array([[1,2,3,4],[1,1,1,1],[1,2,1,9]])
dists = np.zeros((2, 3))
for i in range(2):
      dists[i] = np.sqrt(np.sum(np.square(a[i] - b), axis=1))
Run Code Online (Sandbox Code Playgroud)

sta*_*010 24

以下是原始输入变量:

A = np.array([[1,1,1,1],[2,2,2,2]])
B = np.array([[1,2,3,4],[1,1,1,1],[1,2,1,9]])
A
# array([[1, 1, 1, 1],
#        [2, 2, 2, 2]])
B
# array([[1, 2, 3, 4],
#        [1, 1, 1, 1],
#        [1, 2, 1, 9]])
Run Code Online (Sandbox Code Playgroud)

A是2x4阵列.B是3x4阵列.

我们想要在一个完全向量化的操作中计算欧几里德距离矩阵运算,其中dist[i,j]包含A中第i个实例和B中第j个实例之间的距离.因此dist在这个例子中是2x3.

距离

在此输入图像描述

表面上可能是用numpy写的

dist = np.sqrt(np.sum(np.square(A-B))) # DOES NOT WORK
# Traceback (most recent call last):
#   File "<stdin>", line 1, in <module>
# ValueError: operands could not be broadcast together with shapes (2,4) (3,4)
Run Code Online (Sandbox Code Playgroud)

然而,如上所示,问题在于逐元素减法操作A-B涉及不兼容的阵列大小,特别是第一维中的2和3.

A has dimensions 2 x 4
B has dimensions 3 x 4
Run Code Online (Sandbox Code Playgroud)

为了进行逐元素减法,我们必须填充A或B以满足numpy的广播规则.我将选择使用额外尺寸填充A,使其变为2 x 1 x 4,这样可以使阵列的尺寸与广播对齐.有关numpy广播的更多信息,请参阅scipy手册中教程本教程中的最后一个示例.

您可以使用np.newaxis值或使用np.reshape命令执行填充.我在下面显示:

# First approach is to add the extra dimension to A with np.newaxis
A[:,np.newaxis,:] has dimensions 2 x 1 x 4
B has dimensions                     3 x 4

# Second approach is to reshape A with np.reshape
np.reshape(A, (2,1,4)) has dimensions 2 x 1 x 4
B has dimensions                          3 x 4
Run Code Online (Sandbox Code Playgroud)

如您所见,使用任一种方法都可以使尺寸对齐.我将使用第一种方法np.newaxis.所以现在,这将创建AB,这是一个2x3x4数组:

diff = A[:,np.newaxis,:] - B
# Alternative approach:
# diff = np.reshape(A, (2,1,4)) - B
diff.shape
# (2, 3, 4)
Run Code Online (Sandbox Code Playgroud)

现在我们可以将差异表达式放入dist方程式语句中以获得最终结果:

dist = np.sqrt(np.sum(np.square(A[:,np.newaxis,:] - B), axis=2))
dist
# array([[ 3.74165739,  0.        ,  8.06225775],
#        [ 2.44948974,  2.        ,  7.14142843]])
Run Code Online (Sandbox Code Playgroud)

注意,sum结束axis=2,这意味着取2x3x4数组的第三个轴(轴id以0开头)的总和.

如果你的数组很小,那么上面的命令就可以了.但是,如果您有大型阵列,那么您可能会遇到内存问题.请注意,在上面的示例中,numpy在内部创建了一个2x3x4阵列来执行广播.如果我们将A概括为具有维度a x z而B具有维度b x z,则numpy将在内部创建a x b x z用于广播的数组.

我们可以通过做一些数学操作来避免创建这个中间数组.因为您将欧几里德距离计算为平方差之和,我们可以利用可以重写平方差之和的数学事实.

在此输入图像描述

请注意,中间项涉及元素乘法的总和.这种倍数的总和更好地称为点积.因为A和B都是矩阵,所以这个操作实际上是一个矩阵乘法.因此,我们可以将上述内容重写为:

在此输入图像描述

然后我们可以编写以下numpy代码:

threeSums = np.sum(np.square(A)[:,np.newaxis,:], axis=2) - 2 * A.dot(B.T) + np.sum(np.square(B), axis=1)
dist = np.sqrt(threeSums)
dist
# array([[ 3.74165739,  0.        ,  8.06225775],
#        [ 2.44948974,  2.        ,  7.14142843]])
Run Code Online (Sandbox Code Playgroud)

请注意,上面的答案与之前的实现完全相同.同样,这里的优点是我们不需要为广播创建中间2x3x4阵列.

为了完整起见,让我们仔细检查threeSums允许广播中每个加数的尺寸.

np.sum(np.square(A)[:,np.newaxis,:], axis=2) has dimensions 2 x 1
2 * A.dot(B.T) has dimensions                               2 x 3
np.sum(np.square(B), axis=1) has dimensions                 1 x 3
Run Code Online (Sandbox Code Playgroud)

因此,正如预期的那样,最终dist数组的尺寸为2x3.

本教程中还讨论了使用点积来代替元素乘法的总和.

  • 这个答案非常有用,尤其是克服广播问题的部分。谢谢@stackoverflowuser2010 (2认同)

Han*_*Qiu 21

最近我在深度学习中遇到了同样的问题(stanford cs231n,Assignment1),但是当我使用时

 np.sqrt((np.square(a[:,np.newaxis]-b).sum(axis=2)))
Run Code Online (Sandbox Code Playgroud)

有一个错误

MemoryError
Run Code Online (Sandbox Code Playgroud)

这意味着我的内存不足(实际上,它在中间生成了一个500*5000*1024的数组.它太大了!)

为了防止该错误,我们可以使用公式来简化:

码:

import numpy as np
aSumSquare = np.sum(np.square(a),axis=1);
bSumSquare = np.sum(np.square(b),axis=1);
mul = np.dot(a,b.T);
dists = np.sqrt(aSumSquare[:,np.newaxis]+bSumSquare-2*mul)
Run Code Online (Sandbox Code Playgroud)

  • 添加一些内容;引自[官方文档](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)`然而,在某些情况下广播是一个坏主意,因为它会导致然而,在某些情况下,广播是一个坏主意,因为它会导致内存使用效率低下,从而减慢计算速度。 (2认同)

gg3*_*349 12

只需np.newaxis在正确的地方使用:

 np.sqrt((np.square(a[:,np.newaxis]-b).sum(axis=2)))
Run Code Online (Sandbox Code Playgroud)

  • 你能解释一下`在正确的地方使用np.newaxis'是如何工作的吗?如果你可以从'a`是2x4并且`b`是3x4这一事实开始,这将是伟大的. (7认同)