Mask a 2D array while preserving its shape

use*_*040 6 python numpy slice

I have a 2D numpy array like this:

arr = np.array([[1,2,4],
                [2,1,1],
                [1,2,3]])

And a boolean array:

boolarr = np.array([[True, True, False],
                    [False, False, True],
                    [True, True,True]])

Now, when I try to slice arr using boolarr as a mask:

arr[boolarr]

Output:

array([1, 2, 1, 1, 2, 3])
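
Boolean indexing with a mask of the same shape as the array always returns a 1D array: NumPy collects the True positions in row-major (C) order, so the row structure is lost. A small sketch that makes this visible, using `np.nonzero` to recover which row each selected element came from:

```python
import numpy as np

arr = np.array([[1, 2, 4],
                [2, 1, 1],
                [1, 2, 3]])
boolarr = np.array([[True, True, False],
                    [False, False, True],
                    [True, True, True]])

# Boolean indexing flattens: selected elements come out in row-major order
flat = arr[boolarr]

# np.nonzero returns (row_indices, col_indices) of the True entries,
# in the same row-major order that boolean indexing uses
rows, cols = np.nonzero(boolarr)

print(flat)  # [1 2 1 1 2 3]
print(rows)  # [0 0 1 2 2 2]
```

The `rows` array shows that the information needed to regroup the result per row is still available; it just is not part of the flat output.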

But I want a 2D array as output. The desired output is:

[[1, 2],
 [1],
 [1, 2, 3]]
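
Note that the desired result has rows of different lengths (2, 1 and 3), so it cannot be stored as a regular 2D `ndarray`; a list of 1D arrays is the natural container. As a simple, non-vectorized baseline, each row can be masked separately:

```python
import numpy as np

arr = np.array([[1, 2, 4],
                [2, 1, 1],
                [1, 2, 3]])
boolarr = np.array([[True, True, False],
                    [False, False, True],
                    [True, True, True]])

# Apply each row's mask to the matching row of arr; rows may differ in length,
# so the result is a Python list of 1D arrays rather than a 2D ndarray
result = [row[mask] for row, mask in zip(arr, boolarr)]
print(result)  # [array([1, 2]), array([1]), array([1, 2, 3])]
```

This loops in Python over the rows, which is fine for small arrays but may be slower than a fully NumPy-based approach when there are many rows.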

yat*_*atu 5

One option using numpy is to first sum the mask along each row:

take = boolarr.sum(axis=1)
#array([2, 1, 3])
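
Summing a boolean array treats `True` as 1 and `False` as 0, so `take` holds the number of selected elements in each row. `np.count_nonzero` with an `axis` argument computes the same counts and may read more explicitly:

```python
import numpy as np

boolarr = np.array([[True, True, False],
                    [False, False, True],
                    [True, True, True]])

# True counts as 1, so the row sums are the number of selected elements per row
take = boolarr.sum(axis=1)

# Equivalent, more explicit count of True values per row
counts = np.count_nonzero(boolarr, axis=1)

print(take, counts)  # [2 1 3] [2 1 3]
```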

Then mask the array as you did:

x = arr[boolarr]
#array([1, 2, 1, 1, 2, 3])

And split the flat array with np.split at the np.cumsum of take (np.split expects the indices at which to split the array, not the chunk sizes):

np.split(x, np.cumsum(take)[:-1])
# [array([1, 2]), array([1]), array([1, 2, 3])]
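
Working through the numbers: with `take = [2, 1, 3]`, `np.cumsum(take)` is `[2, 3, 6]`, the end offset of each row's chunk in the flat array. `np.split` needs the positions at which to cut, and the last end offset (6, the total length) is not a cut point, hence the `[:-1]`:

```python
import numpy as np

x = np.array([1, 2, 1, 1, 2, 3])   # flat result of arr[boolarr]
take = np.array([2, 1, 3])         # selected elements per row

ends = np.cumsum(take)             # end offset of each row's chunk: [2 3 6]
cuts = ends[:-1]                   # cut points for np.split: [2 3]

# The chunks are x[:2], x[2:3] and x[3:]
pieces = np.split(x, cuts)
print(pieces)  # [array([1, 2]), array([1]), array([1, 2, 3])]
```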

General solution

import numpy as np

def mask_nd(x, m):
    '''
    Mask a 2D array and preserve the row
    dimension in the result.

    Parameters
    ----------
    x: np.array
        2D array on which to apply the mask
    m: np.array
        2D boolean mask with the same shape as x

    Returns
    -------
    List of arrays, one per row of x. Each array
    contains the elements of that row selected by m.
    If no elements in a row are selected, the
    corresponding array is empty.
    '''
    # Number of selected elements in each row
    take = m.sum(axis=1)
    # Mask (which flattens), then split the flat result back into rows
    return np.split(x[m], np.cumsum(take)[:-1])
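
If a fixed-shape 2D array is needed after all, a different technique is to keep the original shape and mark the unselected positions instead of dropping them. A minimal sketch with two standard NumPy options; the fill value `-1` is an arbitrary assumption and should be chosen so it cannot collide with real data:

```python
import numpy as np

arr = np.array([[1, 2, 4],
                [2, 1, 1],
                [1, 2, 3]])
boolarr = np.array([[True, True, False],
                    [False, False, True],
                    [True, True, True]])

# Option 1: replace unselected entries with a sentinel (-1 is an assumed fill value)
filled = np.where(boolarr, arr, -1)

# Option 2: a masked array; in np.ma, mask=True means "hide this value",
# so the selection mask has to be inverted
masked = np.ma.masked_array(arr, mask=~boolarr)

print(filled)
print(masked)
```

Both keep the 3x3 shape; the trade-off is that every row still has the full width, so downstream code has to skip the sentinel or masked entries itself.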

Examples

Let's look at some examples:

arr = np.array([[1,2,4],
                [2,1,1],
                [1,2,3]])

boolarr = np.array([[True, True, False],
                    [False, False, False],
                    [True, True,True]])

mask_nd(arr, boolarr)
# [array([1, 2]), array([], dtype=int32), array([1, 2, 3])]
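
The `dtype` shown for the empty row comes from `arr`: the pieces returned by `np.split` are sub-arrays of `x[m]`, so an empty row keeps the input's integer dtype, which is platform dependent (for example `int32` on Windows with older NumPy versions, `int64` on most Linux systems). A quick check of the all-False row case:

```python
import numpy as np

def mask_nd(x, m):
    # Same approach as in the answer: count per row, mask, split at cumulative counts
    take = m.sum(axis=1)
    return np.split(x[m], np.cumsum(take)[:-1])

arr = np.array([[1, 2, 4],
                [2, 1, 1],
                [1, 2, 3]])
boolarr = np.array([[True, True, False],
                    [False, False, False],
                    [True, True, True]])

res = mask_nd(arr, boolarr)
# take = [2, 0, 3], so the cut points are [2, 2] and the middle piece is empty;
# the empty piece still has the same dtype as arr
print(len(res), res[1].size, res[1].dtype == arr.dtype)  # 3 0 True
```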

Or, for the following arrays:

arr = np.array([[1,2],
                [2,1]])

boolarr = np.array([[True, True],
                    [True, False]])

mask_nd(arr, boolarr)
# [array([1, 2]), array([2])]
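
As a sanity check, `mask_nd` can be compared with a plain per-row loop on random inputs. The helper name `mask_loop`, the seed and the array size below are arbitrary choices for illustration:

```python
import numpy as np

def mask_nd(x, m):
    take = m.sum(axis=1)
    return np.split(x[m], np.cumsum(take)[:-1])

def mask_loop(x, m):
    # Hypothetical reference implementation: mask each row separately
    return [row[mask] for row, mask in zip(x, m)]

rng = np.random.default_rng(0)          # arbitrary seed
x = rng.integers(0, 10, size=(5, 4))    # arbitrary 5x4 integer array
m = rng.random((5, 4)) < 0.5            # random boolean mask of the same shape

fast = mask_nd(x, m)
slow = mask_loop(x, m)
same = len(fast) == len(slow) and all(np.array_equal(a, b) for a, b in zip(fast, slow))
print(same)  # True
```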