布尔矩阵计算的最快方法

Fra*_*ser 6 python performance numpy

我有一个带有1.5E6行和20E3列的布尔矩阵,类似于这个例子:

M = [[ True,  True, False,  True, ...],
     [False,  True,  True,  True, ...],
     [False, False, False, False, ...],
     [False,  True, False, False, ...],
     ...
     [ True,  True, False, False, ...]
     ]
Run Code Online (Sandbox Code Playgroud)

另外,我还有另一个矩阵N1.5E6行、1列):

 N = [[ True],
      [False],
      [ True],
      [ True],
      ...
      [ True]
      ]
Run Code Online (Sandbox Code Playgroud)

我需要做的是通过操作符M组合的矩阵(1&1, 1&2, 1&3, 1&N, 2&1, 2&2 etc) 中的每一列对AND,并计算 result 和 matrix 之间有多少重叠N

我的 Python/Numpy 代码如下所示:

for i in range(M.shape[1]):
  for j in range(M.shape[1]):
    result = M[:,i] & M[:,j] # Combine the columns with AND operator
    count = np.sum(result & N.ravel()) # Counts the True occurrences
    ... # Save the count for the i and j pair
Run Code Online (Sandbox Code Playgroud)

问题是,通过20E3 x 20E3两个 for 循环的组合在计算上是昂贵的(大约需要5-10 天来计算)。我试过的一个更好的选择是将每一列与整个矩阵 M 进行比较:

for i in range(M.shape[1]):
  result = M[:,i]*M.shape[1] & M # np.tile or np.repeat is used to horizontally repeat the column
  counts = np.sum(result & N*M.shape[1], axis=0)
  ... # Save the counts
Run Code Online (Sandbox Code Playgroud)

这将开销和计算时间减少到 10% 左右,但仍然需要1 天左右的时间来计算。

我的问题是
进行这些计算(基本上只是ANDSUM)的最快方法是什么(可能是非 Python?)?

我在考虑低级语言、GPU 处理、量子计算等。但我对这些都不太了解,所以对方向的任何建议表示赞赏!

其他想法: 目前正在考虑是否有一种使用点积(如 Davikar 提出的)来计算组合三元组的快速方法:

def compute(M, N):
    out = np.zeros((M.shape[1], M.shape[1], M.shape[1]), np.int32)
    for i in range(M.shape[1]):
        for j in range(M.shape[1]):
            for k in range(M.shape[1]):
                result = M[:, i] & M[:, j] & M[:, k]
                out[i, j, k] = np.sum(result & N.ravel())
    return out
Run Code Online (Sandbox Code Playgroud)

Div*_*kar 8

只需使用np.einsum即可获取所有计数 -

np.einsum('ij,ik,i->jk',M,M.astype(int),N.ravel())
Run Code Online (Sandbox Code Playgroud)

随意使用optimizeflag 和np.einsum. 此外,请随意尝试不同的 dtypes 转换。

为了利用 GPU,我们可以使用tensorflow也支持einsum.

更快的替代方案np.dot

(M&N).T.dot(M.astype(int))
(M&N).T.dot(M.astype(np.float32))
Run Code Online (Sandbox Code Playgroud)

时间——

In [110]: np.random.seed(0)
     ...: M = np.random.rand(500,300)>0.5
     ...: N = np.random.rand(500,1)>0.5

In [111]: %timeit np.einsum('ij,ik,i->jk',M,M.astype(int),N.ravel())
     ...: %timeit (M&N).T.dot(M.astype(int))
     ...: %timeit (M&N).T.dot(M.astype(np.float32))
227 ms ± 191 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
66.8 ms ± 198 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
3.26 ms ± 753 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Run Code Online (Sandbox Code Playgroud)

并进一步对两个布尔数组进行 float32 转换 -

In [122]: %%timeit
     ...: p1 = (M&N).astype(np.float32)
     ...: p2 = M.astype(np.float32)
     ...: out = p1.T.dot(p2)
2.7 ms ± 34.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Run Code Online (Sandbox Code Playgroud)


jde*_*esa 6

编辑:要修复下面的代码以适合更正的问题,只需要在compute以下几个小改动:

def compute(m, n):
    m = np.asarray(m)
    n = np.asarray(n)
    # Apply mask N in advance
    m2 = m & n
    # Pack booleans into uint8 for more efficient bitwise operations
    # Also transpose for better caching (maybe?)
    mb = np.packbits(m2.T, axis=1)
    # Table with number of ones in each uint8
    num_bits = (np.arange(256)[:, np.newaxis] & (1 << np.arange(8))).astype(bool).sum(1)
    # Allocate output array
    out = np.zeros((m2.shape[1], m2.shape[1]), np.int32)
    # Do the counting with Numba
    _compute_nb(mb, num_bits, out)
    # Make output symmetric
    out = out + out.T
    # Add values in diagonal
    out[np.diag_indices_from(out)] = m2.sum(0)
    # Scale by number of ones in n
    return out
Run Code Online (Sandbox Code Playgroud)

我会用 Numba 来做这件事,使用一些技巧。首先,您只能执行一半的列操作,因为另一半是重复的。其次,您可以将布尔值打包成字节,这样每个字节都&可以操作八个而不是一个值。第三,您可以使用多处理来并行化它。总的来说,你可以这样做:

def compute(m, n):
    m = np.asarray(m)
    n = np.asarray(n)
    # Apply mask N in advance
    m2 = m & n
    # Pack booleans into uint8 for more efficient bitwise operations
    # Also transpose for better caching (maybe?)
    mb = np.packbits(m2.T, axis=1)
    # Table with number of ones in each uint8
    num_bits = (np.arange(256)[:, np.newaxis] & (1 << np.arange(8))).astype(bool).sum(1)
    # Allocate output array
    out = np.zeros((m2.shape[1], m2.shape[1]), np.int32)
    # Do the counting with Numba
    _compute_nb(mb, num_bits, out)
    # Make output symmetric
    out = out + out.T
    # Add values in diagonal
    out[np.diag_indices_from(out)] = m2.sum(0)
    # Scale by number of ones in n
    return out
Run Code Online (Sandbox Code Playgroud)

作为快速比较,这里有一个针对原始循环和 NumPy-only 方法的小基准(我很确定 Divakar 的建议是您可以从 NumPy 中得到的最好的):

import numpy as np
import numba as nb

def compute(m, n):
    m = np.asarray(m)
    n = np.asarray(n)
    # Pack booleans into uint8 for more efficient bitwise operations
    # Also transpose for better caching (maybe?)
    mb = np.packbits(m.T, axis=1)
    # Table with number of ones in each uint8
    num_bits = (np.arange(256)[:, np.newaxis] & (1 << np.arange(8))).astype(bool).sum(1)
    # Allocate output array
    out = np.zeros((m.shape[1], m.shape[1]), np.int32)
    # Do the counting with Numba
    _compute_nb(mb, num_bits, out)
    # Make output symmetric
    out = out + out.T
    # Add values in diagonal
    out[np.diag_indices_from(out)] = m.sum(0)
    # Scale by number of ones in n
    out *= n.sum()
    return out

@nb.njit(parallel=True)
def _compute_nb(mb, num_bits, out):
    # Go through each pair of columns without repetitions
    for i in nb.prange(mb.shape[0] - 1):
        for j in nb.prange(1, mb.shape[0]):
            # Count common bits
            v = 0
            for k in range(mb.shape[1]):
                v += num_bits[mb[i, k] & mb[j, k]]
            out[i, j] = v

# Test
m = np.array([[ True,  True, False,  True],
              [False,  True,  True,  True],
              [False, False, False, False],
              [False,  True, False, False],
              [ True,  True, False, False]])
n = np.array([[ True],
              [False],
              [ True],
              [ True],
              [ True]])
out = compute(m, n)
print(out)
# [[ 8  8  0  4]
#  [ 8 16  4  8]
#  [ 0  4  4  4]
#  [ 4  8  4  8]]
Run Code Online (Sandbox Code Playgroud)