什么是更快捷的方式来获取numpy中唯一行的位置

b10*_*ard 7 python numpy scipy

我有一个唯一行列表和另一个更大的数据数组(在示例中称为test_rows).我想知道是否有更快的方法来获取数据中每个唯一行的位置.我能想到的最快的方法是......

import numpy


uniq_rows = numpy.array([[0, 1, 0],
                         [1, 1, 0],
                         [1, 1, 1],
                         [0, 1, 1]])

test_rows = numpy.array([[0, 1, 1],
                         [0, 1, 0],
                         [0, 0, 0],
                         [1, 1, 0],
                         [0, 1, 0],
                         [0, 1, 1],
                         [0, 1, 1],
                         [1, 1, 1],
                         [1, 1, 0],
                         [1, 1, 1],
                         [0, 1, 0],
                         [0, 0, 0],
                         [1, 1, 0]])

# this gives me the indexes of each group of unique rows
for row in uniq_rows.tolist():
    print row, numpy.where((test_rows == row).all(axis=1))[0]
Run Code Online (Sandbox Code Playgroud)

这打印......

[0, 1, 0] [ 1  4 10]
[1, 1, 0] [ 3  8 12]
[1, 1, 1] [7 9]
[0, 1, 1] [0 5 6]
Run Code Online (Sandbox Code Playgroud)

是否有更好或更多的numpythonic(不确定该词是否存在)这样做的方法?我正在寻找一个numpy组功能,但找不到它.基本上对于任何传入的数据集,我需要以最快的方式获取该数据集中每个唯一行的位置.传入数据集并不总是具有每个唯一行或相同的数字.

编辑:这只是一个简单的例子.在我的应用程序中,数字不仅仅是0和32,它们可以是0到32000之间.unityq行的大小可以在4到128行之间,test_rows的大小可以是数十万.

Div*_*kar 0

方法#1

这是一种方法,尽管对于这样一个棘手的问题,不确定“NumPythonic-ness”的水平 -

def get1Ds(a, b): # Get 1D views of each row from the two inputs
    # check that casting to void will create equal size elements
    assert a.shape[1:] == b.shape[1:]
    assert a.dtype == b.dtype

    # compute dtypes
    void_dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))

    # convert to 1d void arrays
    a = np.ascontiguousarray(a)
    b = np.ascontiguousarray(b)
    a_void = a.reshape(a.shape[0], -1).view(void_dt).ravel()
    b_void = b.reshape(b.shape[0], -1).view(void_dt).ravel()
    return a_void, b_void

def matching_row_indices(uniq_rows, test_rows):
    A, B = get1Ds(uniq_rows, test_rows)
    validA_mask = np.in1d(A,B)

    sidx_A = A.argsort()
    validA_mask = validA_mask[sidx_A]    

    sidx = B.argsort()
    sortedB = B[sidx]
    split_idx = np.flatnonzero(sortedB[1:] != sortedB[:-1])+1
    all_split_indx = np.split(sidx, split_idx)

    match_mask = np.in1d(B,A)[sidx]
    valid_mask = np.logical_or.reduceat(match_mask, np.r_[0, split_idx])    
    locations = [e for i,e in enumerate(all_split_indx) if valid_mask[i]]

    return uniq_rows[sidx_A[validA_mask]], locations    
Run Code Online (Sandbox Code Playgroud)

改进范围(性能):

  1. np.split可以替换为 for 循环以使用 进行分割slicing
  2. np.r_可以替换为np.concatenate.

样本运行 -

In [331]: unq_rows, idx = matching_row_indices(uniq_rows, test_rows)

In [332]: unq_rows
Out[332]: 
array([[0, 1, 0],
       [0, 1, 1],
       [1, 1, 0],
       [1, 1, 1]])

In [333]: idx
Out[333]: [array([ 1,  4, 10]),array([0, 5, 6]),array([ 3,  8, 12]),array([7, 9])]
Run Code Online (Sandbox Code Playgroud)

方法#2

另一种克服前一种方法的设置开销并利用get1Ds它的方法是 -

A, B = get1Ds(uniq_rows, test_rows)
idx_group = []
for row in A:
    idx_group.append(np.flatnonzero(B == row))
Run Code Online (Sandbox Code Playgroud)