ajw*_*ood 8 python optimization numpy
我有一个包含标签的numpy数组.我想根据每个标签的大小和边界框计算一个数字.如何更有效地编写这个,以便在大型数组(约15000个标签)上使用它是否真实?
A = array([[ 1, 1, 0, 3, 3],
[ 1, 1, 0, 0, 0],
[ 1, 0, 0, 2, 2],
[ 1, 0, 2, 2, 2]] )
B = zeros( 4 )
for label in range(1, 4):
# get the bounding box of the label
label_points = argwhere( A == label )
(y0, x0), (y1, x1) = label_points.min(0), label_points.max(0) + 1
# assume I've computed the size of each label in a numpy array size_A
B[ label ] = myfunc(y0, x0, y1, x1, size_A[label])
Run Code Online (Sandbox Code Playgroud)
我真的无法使用一些NumPy向量化函数有效地实现这一点,所以也许一个聪明的Python实现会更快.
def first_row(a, labels):
d = {}
d_setdefault = d.setdefault
len_ = len
num_labels = len_(labels)
for i, row in enumerate(a):
for label in row:
d_setdefault(label, i)
if len_(d) == num_labels:
break
return d
Run Code Online (Sandbox Code Playgroud)
这个函数返回一个字典映射每个标签出现在第一行的索引.应用的功能A,A.T,A[::-1]并且A.T[::-1]也给你的第一列和最后一排和列.
如果您更喜欢列表而不是字典,可以使用将字典转换为列表map(d.get, labels).或者,您可以从一开始就使用NumPy数组而不是字典,但是一旦找到所有标签,您将无法提前离开循环.
我对是否(以及多少)实际加速您的代码感兴趣,但我相信它比原始解决方案更快.
算法:
对于大型阵列如(7000,9000),可以在30s内完成计算.
这是代码:
import numpy as np
A = np.array([[ 1, 1, 0, 3, 3],
[ 1, 1, 0, 0, 0],
[ 1, 0, 0, 2, 2],
[ 1, 0, 2, 2, 2]] )
def label_range(A):
from itertools import izip_longest
h, w = A.shape
tmp = A.reshape(-1)
index = np.argsort(tmp)
sorted_A = tmp[index]
pos = np.where(np.diff(sorted_A))[0]+1
for p1,p2 in izip_longest(pos,pos[1:]):
label_index = index[p1:p2]
y = label_index // w
x = label_index % w
x0 = np.min(x)
x1 = np.max(x)+1
y0 = np.min(y)
y1 = np.max(y)+1
label = tmp[label_index[0]]
yield label,x0,y0,x1,y1
for label,x0,y0,x1,y1 in label_range(A):
print "%d:(%d,%d)-(%d,%d)" % (label, x0,y0,x1,y1)
#B = np.random.randint(0, 100, (7000, 9000))
#list(label_range(B))
Run Code Online (Sandbox Code Playgroud)
另一种方法:
使用bincount()获取每行和每列中的标签计数,并将信息保存在rows和cols数组中.
对于每个标签,您只需要在行和列中搜索范围.它比排序更快,在我的电脑上,它可以在几秒钟内完成计算.
def label_range2(A):
maxlabel = np.max(A)+1
h, w = A.shape
rows = np.zeros((h, maxlabel), np.bool)
for row in xrange(h):
rows[row,:] = np.bincount(A[row,:], minlength=maxlabel) > 0
cols = np.zeros((w, maxlabel), np.bool)
for col in xrange(w):
cols[col,:] =np.bincount(A[:,col], minlength=maxlabel) > 0
for label in xrange(1, maxlabel):
row = rows[:, label]
col = cols[:, label]
y = np.where(row)[0]
x = np.where(col)[0]
x0 = np.min(x)
x1 = np.max(x)+1
y0 = np.min(y)
y1 = np.max(y)+1
yield label, x0,y0,x1,y1
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
392 次 |
| 最近记录: |