Numpy trim_zeros in 2D or 3D

ale*_*lex 6 python arrays numpy trim multidimensional-array

How can I remove leading/trailing zeros from a NumPy array? trim_zeros only works for 1D.

Bil*_*ill 6

Here is some code that handles 2D arrays.

import numpy as np

# Arbitrary array
arr = np.array([
    [0, 0, 0, 0, 0],
    [0, 0, 0, 1, 0],
    [0, 1, 1, 1, 0],
    [0, 1, 0, 1, 0],
    [1, 1, 0, 1, 0],
    [1, 0, 0, 1, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0]
])

nz = np.nonzero(arr)  # Indices of all nonzero elements
arr_trimmed = arr[nz[0].min():nz[0].max()+1,
                  nz[1].min():nz[1].max()+1]

assert np.array_equal(arr_trimmed, [
         [0, 0, 0, 1],
         [0, 1, 1, 1],
         [0, 1, 0, 1],
         [1, 1, 0, 1],
         [1, 0, 0, 1],
    ])
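As a possible variation on the 2D snippet above: np.nonzero builds an index array for every nonzero element, which can be large for dense inputs. The sketch below instead reduces along each axis with any() and takes the first and last occupied row and column. The name trim_zeros_2d is hypothetical, and the function assumes the array is 2D and contains at least one nonzero element.

```python
import numpy as np

def trim_zeros_2d(arr):
    """Trim all-zero border rows and columns from a 2D array.

    Hypothetical helper; assumes arr is 2D and has at least one
    nonzero element (otherwise rows[0] raises IndexError).
    """
    arr = np.asarray(arr)
    rows = np.flatnonzero(arr.any(axis=1))  # indices of rows holding a nonzero
    cols = np.flatnonzero(arr.any(axis=0))  # indices of columns holding a nonzero
    # flatnonzero returns sorted indices, so first/last give the bounds.
    return arr[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]

arr = np.array([
    [0, 0, 0, 0, 0],
    [0, 0, 0, 1, 0],
    [0, 1, 1, 1, 0],
    [0, 1, 0, 1, 0],
    [1, 1, 0, 1, 0],
    [1, 0, 0, 1, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
])
trimmed = trim_zeros_2d(arr)
```

Because only basic slicing is used, the result is a view into the original array rather than a copy, just like the np.nonzero version.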

The np.nonzero approach can be generalized to N dimensions as follows:

def trim_zeros(arr):
    """Returns a trimmed view of an n-D array excluding any outer
    regions which contain only zeros.
    """
    slices = tuple(slice(idx.min(), idx.max() + 1) for idx in np.nonzero(arr))
    return arr[slices]

test = np.zeros((5,5,5,5))
test[1:3,1:3,1:3,1:3] = 1
trimmed_array = trim_zeros(test)
assert trimmed_array.shape == (2, 2, 2, 2)
assert trimmed_array.sum() == 2**4
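One caveat with the N-dimensional function above: if the array contains no nonzero elements, np.nonzero returns empty index arrays and idx.min() raises a ValueError. A minimal sketch of a guarded version follows. The name trim_zeros_safe is hypothetical, and returning an empty view with zero length along every axis is just one reasonable convention; it also assumes the input has at least one dimension.

```python
import numpy as np

def trim_zeros_safe(arr):
    """Like trim_zeros above, but tolerates arrays with no nonzero entries.

    Hypothetical helper; assumes arr.ndim >= 1. For an all-zero (or empty)
    input it returns an empty view of shape (0, ..., 0). This is a chosen
    convention, not something the original answer specifies.
    """
    arr = np.asarray(arr)
    nz = np.nonzero(arr)  # one index array per dimension
    if nz[0].size == 0:
        # Nothing to keep: slice every axis down to length zero.
        return arr[(slice(0, 0),) * arr.ndim]
    slices = tuple(slice(idx.min(), idx.max() + 1) for idx in nz)
    return arr[slices]

all_zero = trim_zeros_safe(np.zeros((3, 4)))

test = np.zeros((5, 5, 5, 5))
test[1:3, 1:3, 1:3, 1:3] = 1
trimmed_array = trim_zeros_safe(test)
```

Raising a clearer exception or returning the input unchanged would be equally defensible choices for the all-zero case, depending on what the caller expects.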