比较两个字典与numpy矩阵作为值

phy*_*ion 10 python dictionary numpy equality

我想断言两个Python字典是相等的(这意味着:等量的密钥,每个从键到值的映射是相等的;顺序并不重要).assert A==B然而,一种简单的方法是,如果字典的值是,则这不起作用numpy arrays.如果两个词典相同,我怎样才能编写一个函数来检查?

>>> import numpy as np
>>> A = {1: np.identity(5)}
>>> B = {1: np.identity(5) + np.ones([5,5])}
>>> A == B
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
Run Code Online (Sandbox Code Playgroud)

编辑我知道应该检查numpy矩阵是否相等.all().我正在寻找的是检查这一点的一般方法,而无需检查isinstance(np.ndarray).这可能吗?

没有numpy数组的相关主题:

vit*_*ral 16

您可以使用 numpy.testing.assert_equal

http://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_equal.html

  • 这是正确的答案。谢谢你! (4认同)
  • @GuilhermedeLazari原始问题是:“我想断言两个Python字典是相等的” (3认同)

And*_*eak 8

我将回答隐藏在您的问题标题和前半部分中的一半问题,因为坦率地说,这是一个需要解决的更常见的问题,现有的答案并没有很好地解决它。这个问题是“如何比较两个 numpy 数组的字典是否相等”?

问题的第一部分是“从远处”检查字典:查看它们的键是否相同。如果所有键都相同,则第二部分比较每个对应的值。

现在微妙的问题是许多 numpy 数组不是整数值,并且双精度是不精确的。因此,除非您有整数值(或其他非浮点型)数组,否则您可能需要检查这些值是否几乎相同,即在机器精度内。因此,在这种情况下,您不会使用np.array_equal(它检查精确的数值相等性),而是np.allclose(它对两个数组之间的相对和绝对误差使用有限容差)。

问题的前一个半部分很简单:检查字典的键是否一致,并使用生成器推导式比较每个值(并在推导式all之外使用以验证每个项目是否相同):

import numpy as np

# some dummy data

# these are equal exactly
dct1 = {'a': np.array([2, 3, 4])}
dct2 = {'a': np.array([2, 3, 4])}

# these are equal _roughly_
dct3 = {'b': np.array([42.0, 0.2])}
dct4 = {'b': np.array([42.0, 3*0.1 - 0.1])}  # still 0.2, right?

def compare_exact(first, second):
    """Return whether two dicts of arrays are exactly equal"""
    if first.keys() != second.keys():
        return False
    return all(np.array_equal(first[key], second[key]) for key in first)

def compare_approximate(first, second):
    """Return whether two dicts of arrays are roughly equal"""
    if first.keys() != second.keys():
        return False
    return all(np.allclose(first[key], second[key]) for key in first)

# let's try them:
print(compare_exact(dct1, dct2))  # True
print(compare_exact(dct3, dct4))  # False
print(compare_approximate(dct3, dct4))  # True
Run Code Online (Sandbox Code Playgroud)

正如你在上面的例子中看到的,整数数组完全可以比较,并且取决于你在做什么(或者如果你很幸运)它甚至可以用于浮点数。但是,如果您的浮点数是任何算术运算的结果(例如线性变换?),您绝对应该使用近似检查。有关后一个选项的完整说明,请参阅(及其元素朋友)的文档numpy.allclosenumpy.isclose,特别注意rtolatol关键字参数。


sir*_*ark -3

考虑这段代码

>>> import numpy as np
>>> np.identity(5)
array([[ 1.,  0.,  0.,  0.,  0.],
       [ 0.,  1.,  0.,  0.,  0.],
       [ 0.,  0.,  1.,  0.,  0.],
       [ 0.,  0.,  0.,  1.,  0.],
       [ 0.,  0.,  0.,  0.,  1.]])
>>> np.identity(5)+np.ones([5,5])
array([[ 2.,  1.,  1.,  1.,  1.],
       [ 1.,  2.,  1.,  1.,  1.],
       [ 1.,  1.,  2.,  1.,  1.],
       [ 1.,  1.,  1.,  2.,  1.],
       [ 1.,  1.,  1.,  1.,  2.]])
>>> np.identity(5) == np.identity(5)+np.ones([5,5])
array([[False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False]], dtype=bool)
>>> 
Run Code Online (Sandbox Code Playgroud)

请注意,比较的结果是一个矩阵,而不是布尔值。字典比较将使用值cmp方法来比较值,这意味着在比较矩阵值时,字典比较将得到复合结果。您想要做的是使用 numpy.all将复合数组结果折叠为标量布尔结果

>>> np.all(np.identity(5) == np.identity(5)+np.ones([5,5]))
False
>>> np.all(np.identity(5) == np.identity(5))
True
>>> 
Run Code Online (Sandbox Code Playgroud)

您需要编写自己的函数来比较这些字典,测试值类型以查看它们是否是矩阵,然后使用 进行比较numpy.all,否则使用==。当然,如果您愿意,您也可以随时开始对 dict 进行子类化并重载cmp