获取numpy数组中重复元素的所有索引的列表

mor*_*ens 9 python arrays django numpy

我试图在numpy数组中得到所有重复元素的索引,但我现在发现的解决方案对于大型(> 20000个元素)输入数组来说是非常低效的(它需要大约9秒钟).这个想法很简单:

  1. records_array是一个numpy时间戳数组(timedate),我们要从中提取重复时间戳的索引

  2. time_array 是一个numpy数组,包含重复的所有时间戳 records_array

  3. records是一个包含一些Record对象的django QuerySet(可以很容易地转换为列表).我们想要创建一个由Record的tagId属性的所有可能组合形成的对的列表,对应于从中找到的重复时间戳records_array.

这是我目前的工作(但效率低下)代码:

tag_couples = [];
for t in time_array:
    users_inter = np.nonzero(records_array == t)[0] # Get all repeated timestamps in records_array for time t
    l = [str(records[i].tagId) for i in users_inter] # Create a temporary list containing all tagIds recorded at time t
    if l.count(l[0]) != len(l): #remove tuples formed by the first tag repeated
        tag_couples +=[x for x in itertools.combinations(list(set(l)),2)] # Remove duplicates with list(set(l)) and append all possible couple combinations to tag_couples
Run Code Online (Sandbox Code Playgroud)

我很确定这可以通过使用Numpy进行优化,但是我找不到一种方法来比较records_array每个元素time_array而不使用for循环(这不能通过使用来比较==,因为它们都是数组).

gg3*_*349 22

像往常一样的解决方案与numpy的魔力unique(),没有循环或列表理解:

records_array = array([1, 2, 3, 1, 1, 3, 4, 3, 2])
idx_sort = argsort(records_array)
sorted_records_array = records_array[idx_sort]
vals, idx_start, count = unique(sorted_records_array, return_counts=True,
                                return_index=True)

# sets of indices
res = split(idx_sort, idx_start[1:])
#filter them with respect to their size, keeping only items occurring more than once

vals = vals[count > 1]
res = filter(lambda x: x.size > 1, res)
Run Code Online (Sandbox Code Playgroud)

编辑:以下是我以前的答案,需要更多的内存,使用numpy广播和调用unique两次:

records_array = array([1, 2, 3, 1, 1, 3, 4, 3, 2])
vals, inverse, count = unique(records_array, return_inverse=True,
                              return_counts=True)

idx_vals_repeated = where(count > 1)[0]
vals_repeated = vals[idx_vals_repeated]

rows, cols = where(inverse == idx_vals_repeated[:, newaxis])
_, inverse_rows = unique(rows, return_index=True)
res = split(cols, inverse_rows[1:])
Run Code Online (Sandbox Code Playgroud)

如预期的那样 res = [array([0, 3, 4]), array([1, 8]), array([2, 5, 7])]


Tre*_*ney 15

  • 答案很复杂,取决于数组中唯一元素的大小和数量。
  • 下列:
    • 测试具有 2M 个元素和最多 20k 个唯一元素的数组。
    • 测试最多 80k 个元素的数组,最多 20k 个唯一元素
      • 对于少于 40k 个元素的数组,测试最多只有数组大小的一半(例如,10k 个元素将有多达 5k 个唯一元素)。

具有 200 万个元素的数组

  • np.wheredefaultdict最多约 200 个唯一元素的速度快,但比pandas.core.groupby.GroupBy.indices, 和慢np.unique
  • 使用pandas, 的解决方案是大型阵列的最快解决方案。

最多 80k 个元素的数组

  • 这是更多情况,取决于数组的大小和唯一元素的数量。
  • defaultdict 对于大约 2400 个元素的数组,尤其是具有大量唯一元素的数组,是一个快速选项。
  • 对于大于 40k 个元素和 20k 个唯一元素的数组,pandas 是最快的选择。

%timeit

import random
import numpy
import pandas as pd
from collections import defaultdict

def dd(l):
    # default_dict test
    indices = defaultdict(list)
    for i, v in enumerate(l):
        indices[v].append(i)
    return indices


def npw(l):
    # np_where test
    return {v: np.where(l == v)[0] for v in np.unique(l)}


def uni(records_array):
    # np_unique test
    idx_sort = np.argsort(records_array)
    sorted_records_array = records_array[idx_sort]
    vals, idx_start, count = np.unique(sorted_records_array, return_counts=True, return_index=True)
    res = np.split(idx_sort, idx_start[1:])
    return dict(zip(vals, res))


def daf(l):
    # pandas test
    return pd.DataFrame(l).groupby([0]).indices


data = defaultdict(list)

for x in range(4, 20000, 100):  # number of unique elements
    # create 2M element list
    random.seed(365)
    a = np.array([random.choice(range(x)) for _ in range(2000000)])
    
    res1 = %timeit -r2 -n1 -q -o dd(a)
    res2 = %timeit -r2 -n1 -q -o npw(a)
    res3 = %timeit -r2 -n1 -q -o uni(a)
    res4 = %timeit -r2 -n1 -q -o daf(a)
    
    data['defaut_dict'].append(res1.average)
    data['np_where'].append(res2.average)
    data['np_unique'].append(res3.average)
    data['pandas'].append(res4.average)
    data['idx'].append(x)

df = pd.DataFrame(data)
df.set_index('idx', inplace=True)

df.plot(figsize=(12, 5), xlabel='unique samples', ylabel='average time (s)', title='%timeit test: 2 run 1 loop each')
plt.legend(bbox_to_anchor=(1.0, 1), loc='upper left')
plt.show()
Run Code Online (Sandbox Code Playgroud)

使用 200 万个元素进行测试

在此处输入图片说明

在此处输入图片说明

在此处输入图片说明

在此处输入图片说明

测试多达 80k 个元素

在此处输入图片说明

在此处输入图片说明

在此处输入图片说明

在此处输入图片说明

在此处输入图片说明

在此处输入图片说明

在此处输入图片说明

在此处输入图片说明

在此处输入图片说明

在此处输入图片说明

在此处输入图片说明

在此处输入图片说明

在此处输入图片说明

  • 请原谅我的法语但是daaaaaamn!谢谢你的努力! (2认同)