Lev*_*sky 20 python unit-testing numpy
我使用Python的unittest
模块,并想检查两个复杂的数据结构是否相等.对象可以是具有各种值的dicts列表:数字,字符串,Python容器(列表/元组/ dicts)和numpy
数组.后者是提出问题的原因,因为我不能这样做
self.assertEqual(big_struct1, big_struct2)
Run Code Online (Sandbox Code Playgroud)
因为它会产生一个
ValueError: The truth value of an array with more than one element is ambiguous.
Use a.any() or a.all()
Run Code Online (Sandbox Code Playgroud)
我想我需要为此编写自己的相等测试.它应该适用于任意结构.我目前的想法是递归函数:
arg1
到相应节点arg2
;ValueError
被抓住,会更深入,直到找到numpy.array
;看起来有点问题的是跟踪两个结构的"相应"节点,但也许这zip
就是我所需要的.
问题是:这种方法有更好(更简单)的替代方案吗?也许numpy
为此提供了一些工具?如果没有建议的替代方案,我将实施这个想法(除非我有一个更好的想法)并发布作为答案.
PS我有一种模糊的感觉,我可能已经看到了解决这个问题的问题,但我现在找不到它.
PPS另一种方法是遍历结构并将所有numpy.array
s 转换为列表的函数,但是这更容易实现吗?对我来说似乎一样.
编辑:子类化numpy.ndarray
听起来很有希望,但显然我没有将比较的两面硬编码到测试中.但其中一个确实是硬编码的,所以我可以:
numpy.array
;isinstance(other, SaneEqualityArray)
以isinstance(other, np.ndarray)
在jterrace的答案 ;我在这方面的问题是:
numpy
数组的结构).编辑2:我试了一下,(看似)工作实现在这个答案中显示.
seb*_*erg 12
会有评论,但它太长了......
有趣的是,你不能==
用来测试阵列是否与我建议你使用的相同np.testing.assert_array_equal
.
(float('nan') == float('nan')) == False
(Python的序列==
具有忽略此的一个更有趣的方式有时,因为它使用PyObject_RichCompareBool
它做了(NaN的不正确的)is
快速检查(当然,这是完美的)的测试.. .assert_allclose
原因是,如果进行实际计算并且通常需要几乎相同的值,浮点相等性会变得非常棘手,因为这些值可能会依赖于硬件,也可能是随机的,这取决于您对它们的处理方式.我几乎建议尝试使用pickle进行序列化,如果你想要这种疯狂嵌套的东西,但这是非常严格的(当然,第3点当然完全破坏了),例如你的数组的内存布局并不重要,但对它来说很重要序列化.
该assertEqual
函数将调用__eq__
对象的方法,该方法应该递归复杂的数据类型.例外是numpy,它没有一个理智的__eq__
方法.使用此问题的numpy子类,您可以恢复相等行为的健全性:
import copy
import numpy
import unittest
class SaneEqualityArray(numpy.ndarray):
def __eq__(self, other):
return (isinstance(other, SaneEqualityArray) and
self.shape == other.shape and
numpy.ndarray.__eq__(self, other).all())
class TestAsserts(unittest.TestCase):
def testAssert(self):
tests = [
[1, 2],
{'foo': 2},
[2, 'foo', {'d': 4}],
SaneEqualityArray([1, 2]),
{'foo': {'hey': SaneEqualityArray([2, 3])}},
[{'foo': SaneEqualityArray([3, 4]), 'd': {'doo': 3}},
SaneEqualityArray([5, 6]), 34]
]
for t in tests:
self.assertEqual(t, copy.deepcopy(t))
if __name__ == '__main__':
unittest.main()
Run Code Online (Sandbox Code Playgroud)
这个测试通过.
class SaneEqualityArray(np.ndarray):
def __eq__(self, other):
return (isinstance(other, np.ndarray) and self.shape == other.shape and
np.allclose(self, other))
Run Code Online (Sandbox Code Playgroud)
就像我说的那样,带有这些对象的容器应该在等式检查的左侧.我SaneEqualityArray
从现有的numpy.ndarray
s 创建对象,如下所示:
SaneEqualityArray(my_array.shape, my_array.dtype, my_array)
Run Code Online (Sandbox Code Playgroud)
按照ndarray
构造函数签名:
ndarray(shape, dtype=float, buffer=None, offset=0,
strides=None, order=None)
Run Code Online (Sandbox Code Playgroud)
此类在测试套件中定义,仅用于测试目的.等式检查的RHS是被测试函数返回的实际对象,包含真实numpy.ndarray
对象.
PS感谢到目前为止发布的两个答案的作者,他们都非常有帮助.如果有人发现这种方法有任何问题,我将非常感谢您的反馈.
归档时间: |
|
查看次数: |
17618 次 |
最近记录: |