优化/删除循环

mar*_*ako 5 python numpy networkx

我有以下代码,我想使用numpy进行优化,最好是删除循环.我看不出如何接近它,所以任何建议都会有所帮助.

indices是一个(N,2)numpy整数数组,N可以是几百万.代码的作用是在第一列中找到重复的索引.对于这些索引,我在第二列中进行了两个相应索引的所有组合.然后我将它们与第一列中的索引一起收集.

index_sets = []
uniques, counts = np.unique(indices[:,0], return_counts=True)
potentials = uniques[counts > 1]
for p in potentials:
    correspondents = indices[(indices[:,0] == p),1]
    combs = np.vstack(list(combinations(correspondents, 2)))
    combs = np.hstack((np.tile(p, (combs.shape[0], 1)), combs))
    index_sets.append(combs)
Run Code Online (Sandbox Code Playgroud)

Eel*_*orn 1

这是一个在 N 上进行向量化的解决方案。请注意,它仍然包含一个 for 循环,但它是每个“关键多重性组”上的循环,保证其数量要小得多(通常是几十个)最多)。

对于 N=1.000.000,运行时间在我的电脑上是一秒的数量级。

import numpy_indexed as npi
N = 1000000
indices = np.random.randint(0, N/10, size=(N, 2))

def combinations(x):
    """vectorized computation of combinations for an array of sequences of equal length

    Parameters
    ----------
    x : ndarray, [..., n_items]

    Returns
    -------
    ndarray, [..., n_items * (n_items - 1) / 2, 2]
    """
    return np.rollaxis(x[..., np.triu_indices(x.shape[-1], 1)], -2, x.ndim+1)

def process(indices):
    """process a subgroup of indices, all having equal multiplicity

    Parameters
    ----------
    indices : ndarray, [n, 2]

    Returns
    -------
    ndarray, [m, 3]
    """
    keys, vals = npi.group_by(indices[:, 0], indices[:, 1])
    combs = combinations(vals)
    keys = np.repeat(keys, combs.shape[1])
    return np.concatenate([keys[:, None], combs.reshape(-1, 2)], axis=1)

index_groups = npi.group_by(npi.multiplicity(indices[:, 0])).split(indices)
result = np.concatenate([process(ind) for ind in index_groups])
Run Code Online (Sandbox Code Playgroud)

免责声明:我是numpy_indexed包的作者。