Sum matrix elements grouped by indices in Python

Mor*_*fix 5 python numpy sum matrix indices

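I have two matrices with the same number of rows and columns: one holds float values, and the other holds the indices by which those values are grouped. As a result, I would like a dictionary or list containing the sum of the elements for each index. Indices always start at 0.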

A = np.array([[0.52,0.25,-0.45,0.13],[-0.14,-0.41,0.31,-0.41]])
B = np.array([[1,3,1,2],[3,0,2,2]])

RESULT = {0: -0.41, 1: 0.07, 2: 0.03, 3: 0.11}

I found this solution, but I am looking for a faster one. I am working with matrices of 784 x 300 cells, and this algorithm takes about 28 ms to complete.

import numpy as np

def matrix_sum_by_indices(indices,matrix):
    a = np.hstack(indices)
    b = np.hstack(matrix)
    sidx = a.argsort()
    split_idx = np.flatnonzero(np.diff(a[sidx])>0)+1
    out = np.split(b[sidx], split_idx)
    return [sum(x) for x in out]

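I would appreciate any help finding a better and simpler solution!

EDIT: I made a mistake: the run time is about 8 ms for a 300 x 10 matrix, but about 28 ms for 784 x 300.

EDIT2: The elements of my A are float64, so bincount gives me a ValueError.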

use*_*203 3

You can use bincount here:

import numpy as np

a = np.array([[0.52,0.25,-0.45,0.13],[-0.14,-0.41,0.31,-0.41]])
b = np.array([[1,3,1,2],[3,0,2,2]])

N = b.max() + 1
# Shift each row's indices into its own block of N bins, since bincount only accepts 1-D input
id = b + (N*np.arange(b.shape[0]))[:, None]
# minlength keeps N bins per row even if the last row does not contain the largest index
np.sum(np.bincount(id.ravel(), a.ravel(), minlength=N*b.shape[0]).reshape(a.shape[0], -1), axis=0)

Output:

array([-0.41,  0.07,  0.03,  0.11])

As a function:

def using_bincount(indices, matrx):
    N = indices.max() + 1
    # Shift each row's indices into its own block of N bins, since bincount only accepts 1-D input
    id = indices + (N*np.arange(indices.shape[0]))[:, None]
    # minlength keeps N bins per row even if the last row does not contain the largest index
    return np.sum(np.bincount(id.ravel(), matrx.ravel(), minlength=N*indices.shape[0]).reshape(matrx.shape[0], -1), axis=0)
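When only the overall total per index is wanted, the per-row offsets may not be needed, because np.bincount takes the values through its weights argument. Below is a minimal sketch, assuming the indices are non-negative integers. np.bincount expects integers as its first argument, so passing the float64 matrix there may be what triggers the error mentioned in EDIT2; that cause is an assumption. The name sum_by_index is hypothetical.

```python
import numpy as np

def sum_by_index(indices, values):
    """Sum `values` grouped by the integer label at the same position in `indices`."""
    # Integer labels go first; the float64 values are supplied as weights.
    return np.bincount(indices.ravel(), weights=values.ravel())

a = np.array([[0.52, 0.25, -0.45, 0.13], [-0.14, -0.41, 0.31, -0.41]])
b = np.array([[1, 3, 1, 2], [3, 0, 2, 2]])

totals = sum_by_index(b, a)
# The question asked for a dict or a list; either is a cheap conversion.
as_dict = dict(enumerate(totals.tolist()))
```

Position i of the result holds the total for index i, and an index that never occurs gets 0.0 instead of being skipped.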

Timings for this sample:

In [5]: %timeit using_bincount(b, a)
31.1 µs ± 1.74 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [6]: %timeit matrix_sum_by_indices(b, a)
61.3 µs ± 2.62 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [88]: %timeit scipy.ndimage.sum(a, b, index=[0,1,2,3])
54 µs ± 218 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

(scipy.ndimage.sum should be faster on larger samples.)
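The timings above cover only the 2 x 4 sample. To compare at the 784 x 300 size from the question, something like the following sketch could be used. The random data, the fixed seed, and the choice of 10 distinct indices are assumptions, since the question does not say how many groups there are. Both functions are restated so the block runs on its own, with minlength added to the bincount version as a safeguard.

```python
import timeit
import numpy as np

def using_bincount(indices, matrx):
    N = indices.max() + 1
    rows = indices.shape[0]
    # Shift each row's indices into its own block of N bins.
    ids = indices + (N * np.arange(rows))[:, None]
    sums = np.bincount(ids.ravel(), matrx.ravel(), minlength=N * rows)
    return sums.reshape(rows, -1).sum(axis=0)

def matrix_sum_by_indices(indices, matrix):
    # The sort-and-split approach from the question.
    a = np.hstack(indices)
    b = np.hstack(matrix)
    sidx = a.argsort()
    split_idx = np.flatnonzero(np.diff(a[sidx]) > 0) + 1
    out = np.split(b[sidx], split_idx)
    return [sum(x) for x in out]

rng = np.random.default_rng(0)                  # fixed seed, arbitrary choice
values = rng.standard_normal((784, 300))
labels = rng.integers(0, 10, size=(784, 300))   # assumed: 10 distinct indices

fast = using_bincount(labels, values)
slow = np.array(matrix_sum_by_indices(labels, values))

for fn in (using_bincount, matrix_sum_by_indices):
    seconds = timeit.timeit(lambda: fn(labels, values), number=10) / 10
    print(f"{fn.__name__}: {seconds * 1e3:.2f} ms per call")
```

Absolute numbers depend on the machine and on how many distinct indices occur, so the output is only meaningful as a relative comparison.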
