我想使用cython加速以下代码:
class A(object):
cdef fun(self):
return 3
class B(object):
cdef fun(self):
return 2
def test():
cdef int x, y, i, s = 0
a = [ [A(), B()], [B(), A()]]
for i in xrange(1000):
for x in xrange(2):
for y in xrange(2):
s += a[x][y].fun()
return s
Run Code Online (Sandbox Code Playgroud)
想到的唯一事情是这样的:
def test():
cdef int x, y, i, s = 0
types = [ [0, 1], [1, 0]]
data = [[...], [...]]
for i in xrange(1000):
for x in xrange(2):
for y in xrange(2):
if types[x,y] == 0:
s+= A(data[x,y]).fun()
else:
s+= B(data[x,y]).fun()
return s
Run Code Online (Sandbox Code Playgroud)
基本上,C++中的解决方案是使用虚方法获得指向某个基类的指针数组fun(),然后您可以非常快速地迭代它.有没有办法使用python/cython?
BTW:使用numpy的2D数组与dtype = object_,而不是python列表会更快吗?
看起来像这样的代码提供了大约20倍的加速:
import numpy as np
cimport numpy as np
cdef class Base(object):
cdef int fun(self):
return -1
cdef class A(Base):
cdef int fun(self):
return 3
cdef class B(Base):
cdef int fun(self):
return 2
def test():
bbb = np.array([[A(), B()], [B(), A()]], dtype=np.object_)
cdef np.ndarray[dtype=object, ndim=2] a = bbb
cdef int i, x, y
cdef int s = 0
cdef Base u
for i in xrange(1000):
for x in xrange(2):
for y in xrange(2):
u = a[x,y]
s += u.fun()
return s
Run Code Online (Sandbox Code Playgroud)
它甚至检查A和B是否继承自Base,可能有一种方法可以在发布版本中禁用它并获得额外的加速
编辑:可以使用删除检查
u = <Base>a[x,y]
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
2920 次 |
| 最近记录: |