带边界检查的 Numpy 切片

Bur*_*000 7 python numpy

numpy 是否提供了一种在对数组进行切片时进行边界检查的方法?例如,如果我这样做:

arr = np.ones([2,2])
sliced_arr = arr[0:5,:]
Run Code Online (Sandbox Code Playgroud)

这个切片没问题,即使我要求不存在的索引,它也会返回整个 arr 。如果我尝试切出数组的边界,是否有另一种切入 numpy 的方法会引发错误?

jde*_*esa 5

这最终比预期的要长一些,但是您可以编写自己的包装器来检查 get 操作,以确保切片不会超出限制(不是切片的索引参数已经由 NumPy 检查)。我想我在这里涵盖了所有情况(省略号、np.newaxis负面步骤...),尽管可能仍然存在一些失败的极端情况。

import numpy as np

# Wrapping function
def bounds_checked_slice(arr):
    return SliceBoundsChecker(arr)

# Wrapper that checks that indexing slices are within bounds of the array
class SliceBoundsChecker:

    def __init__(self, arr):
        self._arr = np.asarray(arr)

    def __getitem__(self, args):
        # Slice bounds checking
        self._check_slice_bounds(args)
        return self._arr.__getitem__(args)

    def __setitem__(self, args, value):
        # Slice bounds checking
        self._check_slice_bounds(args)
        return self._arr.__setitem__(args, value)

    # Check slices in the arguments are within bounds
    def _check_slice_bounds(self, args):
        if not isinstance(args, tuple):
            args = (args,)
        # Iterate through indexing arguments
        arr_dim = 0
        i_arg = 0
        for i_arg, arg in enumerate(args):
            if isinstance(arg, slice):
                self._check_slice(arg, arr_dim)
                arr_dim += 1
            elif arg is Ellipsis:
                break
            elif arg is np.newaxis:
                pass
            else:
                arr_dim += 1
        # Go backwards from end after ellipsis if necessary
        arr_dim = -1
        for arg in args[:i_arg:-1]:
            if isinstance(arg, slice):
                self._check_slice(arg, arr_dim)
                arr_dim -= 1
            elif arg is Ellipsis:
                raise IndexError("an index can only have a single ellipsis ('...')")
            elif arg is np.newaxis:
                pass
            else:
                arr_dim -= 1

    # Check a single slice
    def _check_slice(self, slice, axis):
        size = self._arr.shape[axis]
        start = slice.start
        stop = slice.stop
        step = slice.step if slice.step is not None else 1
        if step == 0:
            raise ValueError("slice step cannot be zero")
        bad_slice = False
        if start is not None:
            start = start if start >= 0 else start + size
            bad_slice |= start < 0 or start >= size
        if stop is not None:
            stop = stop if stop >= 0 else stop + size
            bad_slice |= (stop < 0 or stop > size) if step > 0 else (stop < 0 or stop >= size)
        if bad_slice:
            raise IndexError("slice {}:{}:{} is out of bounds for axis {} with size {}".format(
                slice.start if slice.start is not None else '',
                slice.stop if slice.stop is not None else '',
                slice.step if slice.step is not None else '',
                axis % self._arr.ndim, size))
Run Code Online (Sandbox Code Playgroud)

一个小演示:

import numpy as np

a = np.arange(24).reshape(4, 6)
print(bounds_checked_slice(a)[:2, 1:5])
# [[ 1  2  3  4]
#  [ 7  8  9 10]]
bounds_checked_slice(a)[:2, 4:10]
# IndexError: slice 4:10: is out of bounds for axis 1 with size 6
Run Code Online (Sandbox Code Playgroud)

如果您愿意,您甚至可以将其设为ndarray 的子类,这样默认情况下您就会获得此行为,而不必每次都包装数组。

另请注意,您认为的“越界”可能会有所不同。上面的代码认为,即使一个索引超出了大小也是越界的,这意味着您不能使用类似arr[len(arr):]. 如果您正在考虑稍微不同的行为,原则上您可以编辑代码。


yat*_*atu 3

如果您使用range而不是常见的切片符号,您可以获得预期的行为。例如对于有效的切片:

arr[range(2),:]

array([[1., 1.],
       [1., 1.]])
Run Code Online (Sandbox Code Playgroud)

例如,如果我们尝试切片:

arr[range(5),:]
Run Code Online (Sandbox Code Playgroud)

它会抛出以下错误:

IndexError:索引 2 超出大小 2 的范围

我对为什么会引发错误的猜测是,使用常见切片表示法进行切片是numpy数组和列表中的基本属性,因此当我们尝试使用错误的索引进行切片时,它不会抛出索引超出范围错误,而是已经考虑到了这一点并切至最接近的有效索引。然而,当使用 a 进行切片时,显然没有考虑到这一点range,因为 a 是一个不可变的对象。

  • 这确实有效,但对于大数组,人们应该记住,它在计算上比切片更昂贵,并且它将生成一个新数组而不是视图。 (4认同)