测试:比较 numpy 数组,同时允许一定的不匹配

flo*_*lla 5 python arrays testing numpy

我有两个 numpy 数组,其中包含我正在与numpy.testing.assert_array_equal. 数组“足够相等”,即一些元素不同,但考虑到我的数组的大小,没关系(在这种特定情况下)。但当然测试失败了:

AssertionError:
Arrays are not equal

(mismatch 0.0010541406645359075%)
 x: array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],...
 y: array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],...

----------------------------------------------------------------------
Ran 1 test in 0.658s

FAILED (failures=1)
Run Code Online (Sandbox Code Playgroud)

当然,有人可能会争辩说,(长期)干净的解决方案是调整参考解决方案或诸如此类的东西,但我更喜欢简单地允许一些不匹配而不会导致测试失败。我本来希望 assert_array_equal 有一个选项,但事实并非如此。

我编写了一个函数,它允许我完全按照自己的意愿行事,因此问题可能被认为已解决,但我只是想知道是否有更好、更优雅的方法来做到这一点。此外,解析错误字符串的方法感觉很hacky,但我还没有找到更好的方法来获取不匹配百分比值。

def assert_array_equal_tolerant(arr1,arr2,threshold):
    """Compare equality of two arrays while allowing a certain mismatch.

    Arguments:
     - arr1, arr2: Arrays to compare.
     - threshold: Mismatch (in percent) above which the test fails.
    """
    try:
        np.testing.assert_array_equal(arr1,arr2)
    except AssertionError as e:
        for arg in e.args[0].split("\n"):
            match = re.search(r'mismatch ([0-9.]+)%',arg)
            if match:
                mismatch = float(match.group(1))
                break
        else:
            raise
        if mismatch > threshold:
            raise
Run Code Online (Sandbox Code Playgroud)

需要说明的是:我不是在谈论assert_array_almost_equal,使用它也是不可行的,因为错误并不小,对于单个元素来说它们可能很大,但仅限于非常少的元素。

MSe*_*ert 3

您可以尝试(如果它们是整数)在没有正则表达式的情况下检查不相等的元素数量

unequal_pos = np.where(arr1 != arr2)
len(unequal_pos[0]) # gives you the number of elements that are not equal.
Run Code Online (Sandbox Code Playgroud)

不知道你是否觉得这样更优雅。

由于 的结果np.where可以用作索引,因此您可以获得与 不匹配的元素

arr1[unequal_pos]
Run Code Online (Sandbox Code Playgroud)

因此,您可以根据该结果进行几乎所有您喜欢的测试。取决于您想要如何通过不同元素的数量或元素之间的差异或什至更奇特的东西来定义不匹配。

  • 如果您只关心不匹配的数量,那么 `np.count_nonzero(arr1 != arr2)` 是更好的选择。 (2认同)