使用Numba优化int组元组的字典?

Are*_*ski 5 python cython numba

我正在学习如何使用Numba(虽然我已经非常熟悉Cython).我该怎么做才能加快这段代码的速度?请注意,该函数返回一组两元组的dict.我正在使用IPython笔记本.我更喜欢Numba而不是Cython.

@autojit
def generateadj(width,height):
    adj = {}
    for y in range(height):
        for x in range(width):
            s = set()
            if x>0:
                s.add((x-1,y))
            if x<width-1:
                s.add((x+1,y))
            if y>0:
                s.add((x,y-1))
            if y<height-1:
                s.add((x,y+1))
            adj[x,y] = s
    return adj
Run Code Online (Sandbox Code Playgroud)

我设法在Cython中写这个,但我不得不放弃数据结构的方式.我不喜欢这个.我在Numba文档中的某处读到它可以处理列表,元组等基本内容.

%%cython
import numpy as np

def generateadj(int width, int height):
    cdef int[:,:,:,:] adj = np.zeros((width,height,4,2), np.int32)
    cdef int count

    for y in range(height):
        for x in range(width):
            count = 0
            if x>0:
                adj[x,y,count,0] = x-1
                adj[x,y,count,1] = y
                count += 1
            if x<width-1:
                adj[x,y,count,0] = x+1
                adj[x,y,count,1] = y
                count += 1
            if y>0:
                adj[x,y,count,0] = x
                adj[x,y,count,1] = y-1
                count += 1
            if y<height-1:
                adj[x,y,count,0] = x
                adj[x,y,count,1] = y+1
                count += 1
            for i in range(count,4):
                adj[x,y,i] = adj[x,y,0]
    return adj
Run Code Online (Sandbox Code Playgroud)

jme*_*jme 5

虽然numba支持像dicts和sets 这样的Python数据结构,但它在对象模式下也是如此.从numba词汇表中,对象模式定义为:

Numba编译模式,生成代码,将所有值作为Python对象处理,并使用Python C API对这些对象执行所有操作.除非Numba编译器可以利用循环匹配,否则在对象模式下编译的代码通常不会比Python解释代码快.

因此,在编写numba代码时,您需要坚持使用内置数据类型,例如数组.这里有一些代码可以做到这一点:

@jit
def gen_adj_loop(width, height, adj):
    i = 0
    for x in range(width):
        for y in range(height):
            if x > 0:
                adj[i,0] = x
                adj[i,1] = y
                adj[i,2] = x - 1
                adj[i,3] = y
                i += 1

            if x < width - 1:
                adj[i,0] = x
                adj[i,1] = y
                adj[i,2] = x + 1
                adj[i,3] = y
                i += 1

            if y > 0:
                adj[i,0] = x
                adj[i,1] = y
                adj[i,2] = x
                adj[i,3] = y - 1
                i += 1

            if y < height - 1:
                adj[i,0] = x
                adj[i,1] = y
                adj[i,2] = x
                adj[i,3] = y + 1
                i += 1
    return
Run Code Online (Sandbox Code Playgroud)

这需要一个数组adj.每行都有表格x y adj_x adj_y.所以对于像素(3,4),我们有四行:

3 4 2 4
3 4 4 4
3 4 3 3
3 4 3 5
Run Code Online (Sandbox Code Playgroud)

我们可以将上面的函数包装在另一个中:

@jit
def gen_adj(width, height):
    # each pixel has four neighbors, but some of these neighbors are
    # off the grid -- 2*width + 2*height of them to be exact
    n_entries = width*height*4 - 2*width - 2*height
    adj = np.zeros((n_entries, 4), dtype=int)
    gen_adj_loop(width, height, adj)
Run Code Online (Sandbox Code Playgroud)

此功能非常快,但不完整.我们必须adj在您的问题中转换为表格的字典.问题是这是一个非常缓慢的过程.我们必须迭代adj数组并将每个条目添加到Python字典中.这不能被追究numba.

所以最重要的是:结果是元组字典的要求确实限制了你可以优化这些代码的程度.