寻找K-最近邻及其实现

You*_*yst 17 matlab classification machine-learning knn

我正在使用具有欧几里德距离的KNN对简单数据进行分类.我已经看到了一个关于我想用MATLAB knnsearch函数完成的例子,如下所示:

load fisheriris 
x = meas(:,3:4);
gscatter(x(:,1),x(:,2),species)
newpoint = [5 1.45];
[n,d] = knnsearch(x,newpoint,'k',10);
line(x(n,1),x(n,2),'color',[.5 .5 .5],'marker','o','linestyle','none','markersize',10)
Run Code Online (Sandbox Code Playgroud)

上面的代码采用了一个新点,即[5 1.45]找到与新点最接近的10个值.任何人都可以给我看一个MATLAB算法,详细解释该knnsearch函数的作用吗?有没有其他方法可以做到这一点?

ray*_*ica 39

K-最近邻(KNN)算法的基础是你有一个由N行和M列组成的数据矩阵,其中N是我们拥有的数据点的数量,而M每个数据点的维度.例如,如果我们将笛卡尔坐标放在数据矩阵中,这通常是一个N x 2或一个N x 3矩阵.使用此数据矩阵,您可以提供查询点,并搜索k此数据矩阵中距离此查询点最近的最近点.

我们通常使用查询与数据矩阵中其余点之间的欧几里德距离来计算距离.但是,也使用其他距离,如L1或City-Block/Manhattan距离.在此操作之后,您将具有N欧几里德或曼哈顿距离,其表示查询与数据集中的每个对应点之间的距离.找到这些后,只需k按升序对距离进行排序,然后检索k数据集与查询之间距离最小的点,即可搜索最近的查询点.

假设您的数据矩阵存储在其中x,并且newpoint是具有M列(即1 x M)的示例点,这是您将以点形式遵循的一般过程:

  1. 找到欧几里得或曼哈顿之间的距离newpoint和每个点x.
  2. 按升序对这些距离进行排序.
  3. 返回最接近的k数据点.xnewpoint

让我们慢慢地做每一步.


步骤1

有人可能会这样做的一种方式可能就是这样一个for循环:

N = size(x,1);
dists = zeros(N,1);
for idx = 1 : N
    dists(idx) = sqrt(sum((x(idx,:) - newpoint).^2));
end
Run Code Online (Sandbox Code Playgroud)

如果你想实现曼哈顿距离,那么这只是:

N = size(x,1);
dists = zeros(N,1);
for idx = 1 : N
    dists(idx) = sum(abs(x(idx,:) - newpoint));
end
Run Code Online (Sandbox Code Playgroud)

dists将是一个N元素向量,包含每个数据点之间的距离xnewpoint.我们在newpoint数据点和数据点之间进行逐元素减法x,将差异平方,然后将sum它们全部放在一起.然后这个总和是平方根,这完成了欧几里德距离.对于曼哈顿距离,您将逐个元素执行逐元素,取绝对值,然后将所有组件加在一起.这可能是最简单的实现,但它可能是效率最低的...尤其是对于更大的数据集和更大的数据维度.

另一种可能的解决方案是复制newpoint并使该矩阵的大小相同x,然后对该矩阵进行逐个元素的减法,然后对每一行的所有列求和并进行平方根.因此,我们可以这样做:

N = size(x, 1);
dists = sqrt(sum((x - repmat(newpoint, N, 1)).^2, 2));
Run Code Online (Sandbox Code Playgroud)

对于曼哈顿距离,您可以:

N = size(x, 1);
dists = sum(abs(x - repmat(newpoint, N, 1)), 2);
Run Code Online (Sandbox Code Playgroud)

repmat采用矩阵或向量,并在给定方向上重复它们一定次数.在我们的例子中,我们想要获取newpoint向量,并将这些N时间叠加在一起以创建一个N x M矩阵,其中每一行都是M元素长.我们将这两个矩阵一起减去,然后对每个分量进行平方.一旦我们这样做,我们sum遍历每一行的所有列,最后取所有结果的平方根.对于曼哈顿距离,我们进行减法,取绝对值然后求和.

但是,在我看来,最有效的方法是使用bsxfun.这实际上是通过单个函数调用在我们讨论的复制.因此,代码就是这样:

dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
Run Code Online (Sandbox Code Playgroud)

对我来说,这看起来更清洁,更重要.对于曼哈顿距离,您可以:

dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);
Run Code Online (Sandbox Code Playgroud)

第2步

现在我们有距离,我们只需对它们进行排序.我们可以sort用来对距离进行排序:

[d,ind] = sort(dists);
Run Code Online (Sandbox Code Playgroud)

d将包含按升序排序的距离,同时ind告诉您未排序数组中每个值在排序结果中出现的位置.我们需要使用ind,提取k该向量的第一个元素,然后使用ind索引到我们的x数据矩阵中以返回最接近的那些点newpoint.

第3步

最后一步是返回k最接近的数据点newpoint.我们可以通过以下方式简单地完成

ind_closest = ind(1:k);
x_closest = x(ind_closest,:);
Run Code Online (Sandbox Code Playgroud)

ind_closest应该包含原始数据矩阵x中最接近的索引newpoint.具体来说,ind_closest包含您需要从哪些进行采样x以获得最接近的点newpoint. x_closest将包含那些实际数据点.


为了您的复制和粘贴乐趣,这就是代码的样子:

dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
%// Or do this for Manhattan
% dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);
[d,ind] = sort(dists);
ind_closest = ind(1:k);
x_closest = x(ind_closest,:);
Run Code Online (Sandbox Code Playgroud)

通过您的示例,让我们看看我们的代码:

load fisheriris 
x = meas(:,3:4);
newpoint = [5 1.45];
k = 10;

%// Use Euclidean
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
[d,ind] = sort(dists);
ind_closest = ind(1:k);
x_closest = x(ind_closest,:);
Run Code Online (Sandbox Code Playgroud)

通过检查ind_closestx_closest,这就是我们得到:

>> ind_closest

ind_closest =

   120
    53
    73
   134
    84
    77
    78
    51
    64
    87

>> x_closest

x_closest =

    5.0000    1.5000
    4.9000    1.5000
    4.9000    1.5000
    5.1000    1.5000
    5.1000    1.6000
    4.8000    1.4000
    5.0000    1.7000
    4.7000    1.4000
    4.7000    1.4000
    4.7000    1.5000
Run Code Online (Sandbox Code Playgroud)

如果你跑了knnsearch,你会看到你的变量n与之匹配ind_closest.然而,该变量d返回的距离,newpoint每一个点x,而不是实际的数据点本身.如果你想要实际距离,只需在我写的代码后执行以下操作:

dist_sorted = d(1:k);
Run Code Online (Sandbox Code Playgroud)

请注意,上面的答案在一批N示例中仅使用一个查询点.KNN非常频繁地同时用于多个示例.假设我们有Q想要在KNN中测试的查询点.这将产生k x M x Q矩阵,其中对于每个示例或每个切片,我们返回k具有维度的最近点M.或者,我们可以返回最近点的ID,k从而得到Q x k矩阵.我们来计算两者.

一个天真的方法是将上面的代码应用于循环并循环遍历每个示例.

这样的东西可以在我们分配Q x k矩阵并应用bsxfun基于方法的情况下将输出矩阵的每一行设置k为数据集中最接近的点,我们将像之前一样使用Fisher Iris数据集.我们也将保持同样的维度就像我们在前面的例子中那样,我将用四个例子,所以Q = 4M = 2:

%// Load the data and create the query points
load fisheriris;
x = meas(:,3:4);
newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5];

%// Define k and the output matrices
Q = size(newpoints, 1);
M = size(x, 2);
k = 10;
x_closest = zeros(k, M, Q);
ind_closest = zeros(Q, k);

%// Loop through each point and do logic as seen above:
for ii = 1 : Q
    %// Get the point
    newpoint = newpoints(ii, :);

    %// Use Euclidean
    dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
    [d,ind] = sort(dists);

    %// New - Output the IDs of the match as well as the points themselves
    ind_closest(ii, :) = ind(1 : k).';
    x_closest(:, :, ii) = x(ind_closest(ii, :), :);
end
Run Code Online (Sandbox Code Playgroud)

虽然这非常好,但我们可以做得更好.有一种方法可以有效地计算两组矢量之间的平方欧几里德距离.如果你想和曼哈顿一起做这件事,我会把它留作练习.咨询该博客,假定A是一个Q1 x M矩阵,其中每一行是维数的一个点MQ1分和B是一个Q2 x M矩阵,其中每行是也维的点MQ2点,我们可以高效地计算距离矩阵D(i, j),其中在连续元件i和列j表示行之间的距离iA和行jB使用下面的基质制剂:

nA = sum(A.^2, 2); %// Sum of squares for each row of A
nB = sum(B.^2, 2); %// Sum of squares for each row of B
D = bsxfun(@plus, nA, nB.') - 2*A*B.'; %// Compute distance matrix
D = sqrt(D); %// Compute square root to complete calculation
Run Code Online (Sandbox Code Playgroud)

因此,如果我们A设为查询点矩阵并且B是由原始数据组成的数据集,我们可以k通过单独排序每一行并确定k每行最小的位置来确定最近的点.我们还可以另外使用它来检索实际的点本身.

因此:

%// Load the data and create the query points
load fisheriris;
x = meas(:,3:4);
newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5];

%// Define k and other variables
k = 10;
Q = size(newpoints, 1);
M = size(x, 2);

nA = sum(newpoints.^2, 2); %// Sum of squares for each row of A
nB = sum(x.^2, 2); %// Sum of squares for each row of B
D = bsxfun(@plus, nA, nB.') - 2*newpoints*x.'; %// Compute distance matrix
D = sqrt(D); %// Compute square root to complete calculation 

%// Sort the distances 
[d, ind] = sort(D, 2);

%// Get the indices of the closest distances
ind_closest = ind(:, 1:k);

%// Also get the nearest points
x_closest = permute(reshape(x(ind_closest(:), :).', M, k, []), [2 1 3]);
Run Code Online (Sandbox Code Playgroud)

我们看到我们使用逻辑来计算距离矩阵是相同的但是一些变量已经改变以适应这个例子.我们还使用两个输入版本独立地对每一行进行排序sort,因此ind将包含每行的ID d并将包含相应的距离.然后我们通过简单地将此矩阵截断为k列来确定哪些索引最接近每个查询点.然后我们使用permutereshape确定相关的最近点是什么.我们首先使用所有最接近的索引并创建一个点矩阵,将所有ID叠加在一起,这样我们得到一个Q * k x M矩阵.使用reshapepermute允许我们创建我们的3D矩阵,使其成为k x M x Q我们指定的矩阵.如果你想自己获得实际距离,我们可以索引d并抓住我们需要的东西.为此,您需要使用sub2ind获取线性索引,以便我们可以d一次性索引.值ind_closest已经为我们提供了我们需要访问的列.我们需要访问的行只有1 k次,2次,2 k次等等Q. k是我们想要返回的点数:

row_indices = repmat((1:Q).', 1, k);
linear_ind = sub2ind(size(d), row_indices, ind_closest);
dist_sorted = D(linear_ind);
Run Code Online (Sandbox Code Playgroud)

当我们为上述查询点运行上面的代码时,这些是我们获得的索引,点和距离:

>> ind_closest

ind_closest =

   120   134    53    73    84    77    78    51    64    87
   123   119   118   106   132   108   131   136   126   110
   107    62    86   122    71   127   139   115    60    52
    99    65    58    94    60    61    80    44    54    72

>> x_closest

x_closest(:,:,1) =

    5.0000    1.5000
    6.7000    2.0000
    4.5000    1.7000
    3.0000    1.1000
    5.1000    1.5000
    6.9000    2.3000
    4.2000    1.5000
    3.6000    1.3000
    4.9000    1.5000
    6.7000    2.2000


x_closest(:,:,2) =

    4.5000    1.6000
    3.3000    1.0000
    4.9000    1.5000
    6.6000    2.1000
    4.9000    2.0000
    3.3000    1.0000
    5.1000    1.6000
    6.4000    2.0000
    4.8000    1.8000
    3.9000    1.4000


x_closest(:,:,3) =

    4.8000    1.4000
    6.3000    1.8000
    4.8000    1.8000
    3.5000    1.0000
    5.0000    1.7000
    6.1000    1.9000
    4.8000    1.8000
    3.5000    1.0000
    4.7000    1.4000
    6.1000    2.3000


x_closest(:,:,4) =

    5.1000    2.4000
    1.6000    0.6000
    4.7000    1.4000
    6.0000    1.8000
    3.9000    1.4000
    4.0000    1.3000
    4.7000    1.5000
    6.1000    2.5000
    4.5000    1.5000
    4.0000    1.3000

>> dist_sorted

dist_sorted =

    0.0500    0.1118    0.1118    0.1118    0.1803    0.2062    0.2500    0.3041    0.3041    0.3041
    0.3000    0.3162    0.3606    0.4123    0.6000    0.7280    0.9055    0.9487    1.0198    1.0296
    0.9434    1.0198    1.0296    1.0296    1.0630    1.0630    1.0630    1.1045    1.1045    1.1180
    2.6000    2.7203    2.8178    2.8178    2.8320    2.9155    2.9155    2.9275    2.9732    2.9732
Run Code Online (Sandbox Code Playgroud)

要与之进行比较knnsearch,您需要为第二个参数指定一个点矩阵,其中每一行都是一个查询点,您将看到索引和已排序的距离在此实现与之间匹配knnsearch.


希望这对你有所帮助.祝好运!