Posts by use*_*496

NumPy optimization with Numba

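I have two sets of points on a sphere, labelled 'obj' and 'ps' in the code sample below. I want to find all 'obj' points that lie within a certain angular distance of some 'ps' point.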

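My approach is to represent each point as a 3D unit vector and compare their dot products with cos(maxsep). This is easy with NumPy broadcasting, but in my application n_obj ~ 500,000 and n_ps ~ 50,000, so broadcasting needs far too much memory. Below is my current attempt using Numba. Can it be optimized further?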
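The threshold test rests on the identity below. For unit vectors the dot product is the cosine of the angle between them, and because cos is strictly decreasing on [0, π], the angle comparison can be swapped for a dot-product comparison. The second line is rough arithmetic for the full broadcast, assuming float64 entries (NumPy's default).

$$
\hat{u}\cdot\hat{v} = \cos\theta_{uv}, \qquad \theta_{uv} < \theta_{\max} \iff \hat{u}\cdot\hat{v} > \cos\theta_{\max} \quad (0 \le \theta_{uv}, \theta_{\max} \le \pi)
$$

$$
n_{\mathrm{obj}}\, n_{\mathrm{ps}} \times 8\,\mathrm{B} \approx (5\times10^{5})(5\times10^{4})(8\,\mathrm{B}) = 2\times10^{11}\,\mathrm{B} \approx 200\,\mathrm{GB}
$$

Even a boolean result of that shape would take about 25 GB, which is why the full broadcast is impractical here.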

from numba import jit
import numpy as np
from sklearn.preprocessing import normalize

def gen_points(n):
    """
    generate random 3D unit vectors (not uniform, but irrelevant here)
    """
    vec = 2*np.random.rand(n,3)-1.
    vec_norm = normalize(vec)
    return vec_norm

#@jit(nopython=True)
@jit
def angdist_threshold_numba(vec_obj,vec_ps,cos_maxsep):
    """
    finds obj that are closer than maxsep to a ps
    """    
    nps = len(vec_ps)
    nobj = len(vec_obj)     

    #closeobj_all = []
    closeobj_all = np.empty(0)
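    dotprod = np.empty(nobj)  # buffer reused for the dot products with one ps
    a = np.arange(nobj)       # obj indices
    for ps in range(nps):
        # cos(angle) between every obj and this ps; vec_obj*vec_ps[ps] still
        # allocates an (nobj, 3) temporary on every iteration
        np.sum(vec_obj*vec_ps[ps],axis=1,out=dotprod)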
        #closeobj_all.extend(a[dotprod > cos_maxsep])
        closeobj_all = …
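One way to avoid both the per-iteration temporaries and the growing result array is to write the 3-term dot product as explicit scalar code, keep an O(n_obj) boolean mask, and stop scanning ps as soon as one is close enough. The sketch below assumes the goal is the set of obj within maxsep of any ps. close_obj_mask is a hypothetical name, and the try/except falls back to plain Python (correct but slow) if Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # Numba missing: run the same loops as plain Python (slow)
    def njit(func):
        return func


@njit
def close_obj_mask(vec_obj, vec_ps, cos_maxsep):
    """Return a boolean mask, True where obj i is within maxsep of at least one ps.

    Memory use is O(n_obj): no (n_obj, n_ps) array is ever built.
    """
    nobj = vec_obj.shape[0]
    nps = vec_ps.shape[0]
    mask = np.zeros(nobj, dtype=np.bool_)
    for i in range(nobj):
        xi, yi, zi = vec_obj[i, 0], vec_obj[i, 1], vec_obj[i, 2]
        for k in range(nps):
            # explicit 3-term dot product, so no temporary arrays are allocated
            d = xi * vec_ps[k, 0] + yi * vec_ps[k, 1] + zi * vec_ps[k, 2]
            if d > cos_maxsep:
                mask[i] = True
                break  # one close ps is enough; skip the remaining ps
    return mask
```

np.flatnonzero(mask) then gives the obj indices. Unlike the extend-based version in the question, each index appears at most once. Because every outer iteration writes only mask[i], the outer loop could also be switched to numba.prange under @njit(parallel=True); that variant is left out here to keep the fallback decorator simple.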

python numpy numba

Votes: 7 | Answers: 1 | Views: 3049
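If Numba is not an option, a pure-NumPy alternative (a different technique from the Numba loop in the question) keeps the broadcasting but processes ps in chunks, so only an (n_obj, chunk) block exists at any time. The function name and the default chunk of 64 are illustrative choices. With n_obj = 500,000 and float64, that block is about 500,000 × 64 × 8 B ≈ 256 MB.

```python
import numpy as np


def close_obj_mask_chunked(vec_obj, vec_ps, cos_maxsep, chunk=64):
    """Boolean mask of obj within maxsep of any ps, using chunked broadcasting.

    Peak temporary memory is roughly n_obj * chunk * 8 bytes for the dot products.
    """
    mask = np.zeros(len(vec_obj), dtype=bool)
    for start in range(0, len(vec_ps), chunk):
        # cosines between all obj and this slice of ps, shape (n_obj, <= chunk)
        block = vec_obj @ vec_ps[start:start + chunk].T
        mask |= (block > cos_maxsep).any(axis=1)
    return mask
```

Larger chunks mean fewer Python-level iterations and bigger matrix products, at the cost of more memory per block.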

Cython: making prange parallelization thread-safe

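Cython beginner here. I am trying to speed up the computation of a pairwise statistic (in several bins) by using multiple threads. In particular, I am using prange from cython.parallel, which uses OpenMP internally.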

The following minimal example illustrates the problem (compiled with the Cython magic in a Jupyter notebook).

Notebook setup:

%load_ext Cython
import numpy as np

Cython code:

%%cython --compile-args=-fopenmp --link-args=-fopenmp -a

from cython cimport boundscheck
import numpy as np
from cython.parallel cimport prange, parallel

@boundscheck(False)
def my_parallel_statistic(double[:] X, double[:,::1] bins, int num_threads):

    cdef: 
        int N = X.shape[0]
        int nbins = bins.shape[0]
        double Xij,Yij
        double[:] Z = np.zeros(nbins,dtype=np.float64)
        int i,j,b

    with nogil, parallel(num_threads=num_threads):
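        for i in prange(N,schedule='static',chunksize=1):
            # j, b, Xij and Yij are assigned inside the prange body,
            # so Cython makes each of them private to its thread
            for j in range(i):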
                #some pairwise quantities
                Xij = X[i]-X[j]
                Yij = 0.5*(X[i]+X[j])
                #check if in bin
                for b in range(nbins): …

multithreading openmp thread-safety cython

Votes: 4 | Answers: 1 | Views: 995
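The truncated bin loop presumably accumulates into Z[b]. Cython turns in-place updates of scalar variables inside prange into OpenMP reductions, but an array element such as Z[b] is shared by all threads, so concurrent increments can be lost. A common remedy is to give each thread its own row of a (num_threads, nbins) buffer, indexed in Cython by cython.parallel.threadid(), and sum the rows after the parallel block. The sketch below shows that privatize-then-reduce pattern with plain Python threads (correct, but no real speedup under the GIL). The bin test is a hypothetical stand-in for the elided code: a pair is counted in bin b when bins[b, 0] <= Xij < bins[b, 1], and Yij is omitted because this assumed test does not use it.

```python
import threading

import numpy as np


def pair_counts_threaded(X, bins, num_threads=4):
    """Hypothetical binned pair count using per-thread private histograms.

    ASSUMPTION: pair (i, j) with j < i goes into bin b when
    bins[b, 0] <= X[i] - X[j] < bins[b, 1]; the original bin test is elided.
    """
    N, nbins = X.shape[0], bins.shape[0]
    Z_local = np.zeros((num_threads, nbins))  # one private row per thread

    def worker(tid):
        # mimics schedule='static', chunksize=1: thread tid gets i = tid, tid + T, ...
        for i in range(tid, N, num_threads):
            for j in range(i):
                Xij = X[i] - X[j]
                for b in range(nbins):
                    if bins[b, 0] <= Xij < bins[b, 1]:
                        Z_local[tid, b] += 1.0  # no other thread writes this row

    threads = [threading.Thread(target=worker, args=(t,)) for t in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return Z_local.sum(axis=0)  # reduce the private rows after all threads finish
```

In Cython the same idea would be a double[:, ::1] Z_local buffer of shape (num_threads, nbins), updated as Z_local[threadid(), b] += 1 inside the prange loop (with threadid cimported from cython.parallel), followed by a serial sum over the thread axis.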

Tag statistics

cython ×1

multithreading ×1

numba ×1

numpy ×1

openmp ×1

python ×1

thread-safety ×1