使用自定义 dtype 按行对对象数组进行排序

Mad*_*ist 6 python sorting numpy

我试图按行按字典顺序对一些数组进行排序。整数情况完美地工作:

>>> arr = np.random.choice(10, size=(5, 3))
>>> arr
array([[1, 0, 2],
       [8, 0, 8],
       [1, 8, 4],
       [1, 3, 9],
       [6, 1, 8]])
>>> np.ndarray(arr.shape[0], dtype=[('', arr.dtype, arr.shape[1])], buffer=arr).sort()
>>> arr
array([[1, 0, 2],
       [1, 3, 9],
       [1, 8, 4],
       [6, 1, 8],
       [8, 0, 8]])
Run Code Online (Sandbox Code Playgroud)

我也可以用

np.ndarray(arr.shape[0], dtype=[('', arr.dtype)] * arr.shape[1], buffer=arr).sort()
Run Code Online (Sandbox Code Playgroud)

在这两种情况下,结果是相同的。但是,对象数组不是这种情况:

>>> selection = np.array(list(string.ascii_lowercase), dtype=object)
>>> arr = np.random.choice(selection, size=(5, 3))
>>> arr
array([['t', 'p', 'g'],
       ['n', 's', 'd'],
       ['g', 'g', 'n'],
       ['g', 'h', 'o'],
       ['f', 'j', 'x']], dtype=object)
>>> np.ndarray(arr.shape[0], dtype=[('', arr.dtype, arr.shape[1])], buffer=arr).sort()
>>> arr
array([['t', 'p', 'g'],
       ['n', 's', 'd'],
       ['g', 'h', 'o'],
       ['g', 'g', 'n'],
       ['f', 'j', 'x']], dtype=object)
>>> np.ndarray(arr.shape[0], dtype=[('', arr.dtype)] * arr.shape[1], buffer=arr).sort()
>>> arr
array([['f', 'j', 'x'],
       ['g', 'g', 'n'],
       ['g', 'h', 'o'],
       ['n', 's', 'd'],
       ['t', 'p', 'g']], dtype=object)
Run Code Online (Sandbox Code Playgroud)

显然,只有 的情况下才能dtype=[('', arr.dtype)] * arr.shape[1]正常工作。这是为什么?有什么不同dtype=[('', arr.dtype, arr.shape[1])]?排序显然是在做某事,但乍一看这个顺序似乎是荒谬的。它是否使用指针作为排序键?

就其价值而言,正如预期的那样,np.searchsorted似乎正在进行与 相同的比较np.sort

Mad*_*ist 0

事实上,对整数进行排序恰好是一个巧合,这可以通过查看浮点运算的结果来验证:

>>> arr = np.array([[0.5, 1.0, 10.2],
                    [0.4, 2.0, 11.0],
                    [1.0, 2.0, 4.0]])
>>> np.sort(np.ndarray(arr.shape[0], dtype=[('', arr.dtype, arr.shape[1])], buffer=arr))
array([([ 0.5,  1. , 10.2],),
       ([ 1. ,  2. ,  4. ],),
       ([ 0.4,  2. , 11. ],)], dtype=[('f0', '<f8', (3,))])
>>> np.sort(np.ndarray(arr.shape[0], dtype=[('', arr.dtype)] * arr.shape[1], buffer=arr))
array([(0.4, 2., 11. ),
       (0.5, 1., 10.2),
       (1. , 2.,  4. )],
      dtype=[('f0', '<f8'), ('f1', '<f8'), ('f2', '<f8')])
Run Code Online (Sandbox Code Playgroud)

另一个提示来自于查看数字0.50.4和 的位1.0

0.5 = 0x3FE0000000000000
0.4 = 0x3FD999999999999A
1.0 = 0x3FF6666666666666
Run Code Online (Sandbox Code Playgroud)

在小端机器上,我们有这个0x00 < 0x66 < 0x9A(上面显示的最后一个字节在前面)。

确切的答案可以通过查看源代码中的排序函数来验证。例如,在 中quicksort.c.src,我们看到所有非显式数字的类型(包括非标量的结构字段)均由泛型函数处理npy_quicksort。它使用函数cmp作为比较器,GENERIC_SWAP并使用宏GENERIC_COPY和分别进行交换和复制。

该函数cmp定义为PyArray_DESCR(arr)->f->compare。宏在 中定义为逐元素操作npysort_common.h

所以最终的结果是,对于任何非标量类型,包括打包数组结构字段,比较都是逐字节完成的。对于对象,这当然是指针的数值。对于浮点数,这将是 IEEE-754 表示形式。正整数似乎可以正常工作的事实是由于我的平台使用小端编码。以二进制补码形式存储的负整数可能不会产生正确的结果。