如何在cython中声明列表列表

fel*_*ipa 2 python cython

我有以下.pyx代码:

import cython
@cython.boundscheck(False)
@cython.cdivision(True)
@cython.wraparound(False)
def f(m):
  cdef int n = len(m)/2
  cdef int j, k
  z = [[0]*(n+1) for _ in range(n*(2*n-1))]
  for j in range(1, 2*n):
    for k in range(j):
      z[j*(j-1)/2+k][0] = m[j][k]
  return solve(z, 2*n, 1, [1] + [0]*n, n)


cdef solve(b, int s, int w, g, int n):
  cdef complex h
  cdef int u,v,j,k
  if s == 0:
    return w*g[n]
  c = [b[(j+1)*(j+2)/2+k+2][:] for j in range(1, s-2) for k in range(j)]
  h = solve(c, s-2, -w, g, n)
  e = g[:]
  for u in range(n):
    for v in range(n-u):
      e[u+v+1] += g[u]*b[0][v]
  for j in range(1, s-2):
    for k in range(j):
      for u in range(n):
        for v in range(n-u):
          c[j*(j-1)/2+k][u+v+1] += b[(j+1)*(j+2)/2][u]*b[(k+1)*(k+2)/2+1][v] + b[(k+1)*(k+2)/2][u]*b[(j+1)*(j+2)/2+1][v]
  return h + solve(c, s-2, w, e, n)
Run Code Online (Sandbox Code Playgroud)

我不知道如何在cython中声明列表列表,以加快代码速度。

例如,变量m是一个矩阵,表示为浮点数列表的列表。该变量z也是浮点数列表的列表。这条线应该是什么def f(m)样的?


遵循@DavidW答案中的建议,这是我的最新版本。

import cython
import numpy as np
def f(complex[:,:] m):
  cdef int n = len(m)/2
  cdef int j, k
  cdef complex[:,:] z = np.zeros((n*(2*n-1), n+1), dtype = complex)
  for j in range(1, 2*n):
    for k in range(j):
      z[j*(j-1)/2+k, 0] = m[j, k]
  return solve(z, 2*n, 1, [1] + [0]*n, n)


cdef solve(complex[:,:] b, int s, int w, g, int n):
  cdef complex h
  cdef int u,v,j,k
  cdef complex[:,:] c
  if s == 0:
    return w*g[n]
  c = [b[(j+1)*(j+2)/2+k+2][:] for j in range(1, s-2) for k in range(j)]
  print("c stats:", len(c), [len(c[i]) for i in len(c)]) 
  h = solve(c, s-2, -w, g, n)
  e = g[:]
  for u in range(n):
    for v in range(n-u):
      e[u+v+1] = e[u+v+1] + g[u]*b[0][v]
  for j in range(1, s-2):
    for k in range(j):
      for u in range(n):
        for v in range(n-u):
          c[j*(j-1)/2+k][u+v+1] = c[j*(j-1)/2+k][u+v+1] + b[(j+1)*(j+2)/2][u]*b[(k+1)*(k+2)/2+1][v] + b[(k+1)*(k+2)/2][u]*b[(j+1)*(j+2)/2+1][v]
  return h + solve(c, s-2, w, e, n)
Run Code Online (Sandbox Code Playgroud)

现在的主要问题是如何将c声明为当前的c列表列表。

Dav*_*idW 5

列表列表并不是一种可以从Cython加快速度的结构。您应该使用的结构是2D 类型的memoryview

def f(double[:,:] m):
    # ...
Run Code Online (Sandbox Code Playgroud)

这些被索引为m[j,k]而不是m[j][k]。您可以将暴露Python缓冲区协议的任何形状合适的对象传递给它们。大多数情况下,这将是一个Numpy数组。


除非您了解装饰器的功能,并考虑过它们是否适合您的功能@cython.boundscheck(False)@cython.wraparound(False)否则您还应避免使用装饰器,例如和。对于当前版本(使用lists),它们实际上什么也不做,建议您在不理解的情况下复制它们。它们确实加快了内存视图的索引编制(以一些安全为代价)。


编辑:在初始化方面,c您有两个选择。

  1. 用列表列表初始化一个numpy数组。这可能不会很快(但如果其他步骤较慢,则可能无关紧要):

    c = np.array([b[(j+1)*(j+2)/2+k+2,:] for j in range(1, s-2) for k in range(j)], dtype=complex)
    # note that I've changed the indexing of b slightly
    
    Run Code Online (Sandbox Code Playgroud)
  2. 设置c适当大小的np.zeros数组。将列表推导交换两个循环。对于我来说,这到底是什么意思,还不是100%显而易见,但这有点像

    c = np.zeros("some size you'll have to work out",dtype=complex)
    for k in range(j):
        for j in range(1,s-2):
            c["some function of j and k",:] = b["some function of j and k",:] 
    
    Run Code Online (Sandbox Code Playgroud)

您还需要替换len(c)c.shape[0],等等。