In-place shuffling of multidimensional arrays

use*_*893 5 python arrays numpy cython multidimensional-array

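I am trying to implement a NaN-safe shuffling routine in Cython that can shuffle along several axes of a multidimensional matrix of arbitrary dimension.

In the simple case of a 1D matrix, one can shuffle all indices that hold non-NaN values using the Fisher-Yates algorithm: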

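cimport numpy as np
import numpy as np

def shuffle1D(np.ndarray[double, ndim=1] x):
    # Indices of the non-NaN entries; only these positions are shuffled.
    cdef np.ndarray[np.intp_t, ndim=1] idx = np.where(~np.isnan(x))[0]
    cdef Py_ssize_t i, j, n, m

    randint = np.random.randint
    # Fisher-Yates: swap slot i with a random slot j in [0, i].
    for i in range(len(idx)-1, 0, -1):
        j = randint(i+1)
        n, m = idx[i], idx[j]
        x[n], x[m] = x[m], x[n]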
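For checking the Cython routine against something slower but easier to read, the same idea can be sketched in plain Python. This is only an illustration: the name shuffle_1d_nan_safe and the optional rng argument are assumptions introduced here, not part of the question.

```python
import math
import random


def shuffle_1d_nan_safe(values, rng=random):
    """Shuffle the non-NaN entries of `values` in place; NaNs keep their positions."""
    # Indices of the entries that are allowed to move.
    idx = [k for k, v in enumerate(values) if not math.isnan(v)]
    # Fisher-Yates over the index list: swap slot i with a random slot j <= i.
    for i in range(len(idx) - 1, 0, -1):
        j = rng.randrange(i + 1)
        n, m = idx[i], idx[j]
        values[n], values[m] = values[m], values[n]
    return values


data = [1.0, float("nan"), 2.0, 3.0, float("nan"), 4.0]
shuffle_1d_nan_safe(data, random.Random(0))
# The NaNs are still at positions 1 and 4; the other four values are permuted.
```

Passing a seeded random.Random instance makes a run reproducible, which is convenient when comparing against the compiled version.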

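I would like to extend this algorithm to handle large multidimensional arrays without a reshape (which triggers a copy in the more complicated cases that are not considered here). For that I need to get rid of the fixed input dimension, which seems possible neither with NumPy arrays nor with memoryviews in Cython. Is there a workaround?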
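The remark about reshape triggering a copy can be checked directly with np.shares_memory. The snippet below is a small sketch; the array shape and the particular axis permutation are arbitrary choices made for illustration.

```python
import numpy as np

x = np.arange(24, dtype=float).reshape(2, 3, 4)

# C-contiguous input: flattening can reuse the same buffer, so this is a view.
flat = x.reshape(-1)
print(np.shares_memory(x, flat))                  # True

# After permuting the axes the data is no longer laid out in C order,
# so flattening has to allocate a new buffer.
swapped = x.transpose(2, 0, 1)
flat_swapped = swapped.reshape(-1)
print(np.shares_memory(swapped, flat_swapped))    # False
```

In the second case, writes to flat_swapped would not reach x, which is why an in-place shuffle cannot rely on reshape alone.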

Many thanks in advance!

Sau*_*tro 4

Following the comments from @Veedrac, this answer now uses more of Cython's capabilities.

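  • A pointer array stores the memory addresses of the values along axis
  • Your algorithm is modified to check for NaN, so that NaN values are never moved
  • No copy is made for C-ordered arrays. For Fortran-ordered arrays, the ravel() command returns a copy. This could be improved by creating another array of double pointers to carry the values of x, probably at the cost of some cache penalty...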
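The address arithmetic behind the first bullet, and the copy mentioned in the third, can be sketched in NumPy alone. The array, the chosen axis and the lane x[1, :, 2] are illustrative assumptions; only x.strides, x.itemsize, ravel() and np.shares_memory are used.

```python
import numpy as np

x = np.arange(24, dtype=float).reshape(2, 3, 4)
axis = 1

# Strides in elements rather than bytes: [12, 4, 1] for this C-ordered array.
elem_strides = [s // x.itemsize for s in x.strides]
stride = elem_strides[axis]

v = x.ravel()                       # a view, because x is C-contiguous
print(np.shares_memory(x, v))       # True

# Flat offset of the first element of the lane x[1, :, 2] ...
pos = 1 * elem_strides[0] + 2 * elem_strides[2]
# ... and the elements of that lane, reached by stepping with `stride`.
lane = [float(v[pos + i * stride]) for i in range(x.shape[axis])]
print(lane)                         # [14.0, 18.0, 22.0]

# For a Fortran-ordered array, ravel() (default C order) has to copy.
f = np.asfortranarray(x)
print(np.shares_memory(f, f.ravel()))   # False
```

The expression pos + i * stride is the same offset that the Cython version turns into a pointer for each element along the axis.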

The shuffleND code below is at least an order of magnitude faster than the other, slice-based solutions.

from libc.stdlib cimport malloc, free

cimport numpy as np
import numpy as np
from numpy.random import randint

cdef extern from "numpy/npy_math.h":
    bint npy_isnan(double x)

def shuffleND(x, int axis=-1):
    cdef np.ndarray[double, ndim=1] v # flat view of x
    cdef np.ndarray[int, ndim=1] strides
    cdef int i, j
    cdef int num_axis, pos, stride
    cdef double tmp
    cdef double **v_axis

    if axis < 0:
        axis += x.ndim  # normalize negative axes

    shape = list(x.shape)
    num_axis = shape.pop(axis)

    # One pointer per element along `axis`. The pointers are re-aimed at the
    # data for every lane, so no per-element allocation is needed.
    v_axis = <double **>malloc(num_axis*sizeof(double *))
    if v_axis == NULL and num_axis > 0:
        raise MemoryError()

    try:
        # Strides in elements rather than bytes.
        tmp_strides = [s//x.itemsize for s in x.strides]
        stride = tmp_strides.pop(axis)
        strides = np.array(tmp_strides, dtype=np.int32)
        # A view only when x is C-contiguous; otherwise ravel() copies.
        v = x.ravel()
        for indices in np.ndindex(*shape):
            # Flat offset of the first element of this lane.
            pos = (strides*indices).sum()
            for i in range(num_axis):
                v_axis[i] = &v[pos + i*stride]
            # Fisher-Yates along the lane, skipping any swap that involves a NaN.
            for i in range(num_axis-1, 0, -1):
                j = randint(i+1)
                if npy_isnan(v_axis[i][0]) or npy_isnan(v_axis[j][0]):
                    continue
                tmp = v_axis[i][0]
                v_axis[i][0] = v_axis[j][0]
                v_axis[j][0] = tmp
    finally:
        free(v_axis)

    return x
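To sanity-check the result of a compiled routine like this, a slow NumPy-only reference can help. The sketch below is not the method used in the answer: it is a different technique that relies on np.moveaxis (which returns a view) and permutes the non-NaN entries of each lane with Generator.permutation. The name shuffle_nd_reference and the rng parameter are hypothetical.

```python
import numpy as np


def shuffle_nd_reference(x, axis=-1, rng=None):
    """Shuffle the non-NaN values of x along `axis` in place (slow reference)."""
    rng = np.random.default_rng() if rng is None else rng
    # moveaxis returns a view, so writes through `lanes` land in x.
    lanes = np.moveaxis(x, axis, -1)
    for index in np.ndindex(*lanes.shape[:-1]):
        lane = lanes[index]                       # 1D view of one lane
        idx = np.flatnonzero(~np.isnan(lane))     # positions allowed to move
        lane[idx] = lane[rng.permutation(idx)]    # NaN positions are untouched
    return x


x = np.arange(24, dtype=float).reshape(2, 3, 4)
x[0, 1, 2] = np.nan
before = x.copy()
shuffle_nd_reference(x, axis=1, rng=np.random.default_rng(0))
# NaN positions are unchanged and each lane holds the same values as before.
print(np.array_equal(np.isnan(x), np.isnan(before)))   # True
```

Because it loops in Python over every lane, this is only meant for testing on small arrays, for example to confirm that NaN positions are preserved and that each lane keeps the same multiset of values.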