为包含深度嵌套 numpy 数组的 Python 对象实现 __eq__

Mic*_*erz 5 python oop numpy

我遇到了 numpy 数组无法与对象属性上下文中的==(使用 的语义)进行比较的问题。np.array_equal

考虑以下示例:

>>> import numpy as np
>>> class A:
...     def __init__(self, a):
...         self.a = a
...     def __eq__(self, other):
...         return self.__dict__ == other.__dict__
...
>>> x = A(a=[1, np.array([1, 2])])
>>> y = A(a=[1, np.array([1, 2])])
>>> x == y
Traceback (most recent call last):
  File "<ipython-input-33-9cfbd892cdaa>", line 1, in <module>
    x == y
  File "<ipython-input-30-790950997d4f>", line 5, in __eq__
    return self.__dict__ == other.__dict__
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
Run Code Online (Sandbox Code Playgroud)

(忽略这__eq__并不完美,它至少应该检查 的类型other,但这是为了简洁起见)

我将如何实现一个处理__eq__嵌套在对象属性深处的 numpy 数组的函数(假设其他所有内容,例如本例中的列表,与 相比都很好==)?numpy 数组可能出现在列表、元组或字典内任意深度的嵌套级别。

我尝试提出一个eq适用==于所有属性的递归函数的“手动”实现,并np.array_equal在遇到 numpy 数组时使用,但这比预期更棘手。

有人有合适的功能,或者简单的解决方法吗?

Sam*_*ufi 0

如果可以选择更改对象xy,您可以根据自己的喜好覆盖__eq__的方法。np.ndarray

class eqarr(np.ndarray):
    def __eq__(self, other):
        return np.array_equal(self, other)

class A:
     def __init__(self, a):
         self.a = a
     def __eq__(self, other):
         return self.__dict__ == other.__dict__

x = A(a=[1, eqarr([1, 2])])
y = A(a=[1, eqarr([1, 2])])
x == y
Run Code Online (Sandbox Code Playgroud)

这导致True.

如果这是不可能的,我目前能想到的唯一解决方案就是实际实现递归相等检查函数。我的尝试如下:

def eq(a, b):
    if not (hasattr(a, '__iter__') or type(a) == str):
        return a == b

    try:
        if not len(a) == len(b):
            return False

        if type(a) == np.ndarray:
            return np.array_equal(a, b)
        if isinstance(a, dict):
            return all(eq(v, b[k]) for k, v in a.items())
        else:
            return all(eq(aa, bb) for aa, bb in zip(a, b))
    except (TypeError, KeyError):
        return False


class A:
     def __init__(self, a):
         self.a = a
     def __eq__(self, other):
         return eq(self.__dict__, other.__dict__)
Run Code Online (Sandbox Code Playgroud)

有了你的例子和我想出的所有例子,它就起作用了。只要嵌套对象具有 an__iter__和 a__len__属性,该解决方案就应该适用。

我希望我考虑到了所有可能的错误,但您可能需要稍微调整代码以使其绝对安全。

如果您找到反例,请以评论的形式提供。我确信代码可以进行相应的调整。

性能eq可能不是很好,但我不知道这是否是您主要关心的问题。

如果 numpy 数组在您的层次结构中相当罕见(并且通常接近顶部),您始终可以先尝试正常比较。这可能如下所示:

def eq(a, b):
    try:
        return np.all(a == b)
    except ValueError:
        pass

    try:
        if not len(a) == len(b):
            return False

        if type(a) == np.ndarray:
            return np.array_equal(a, b)
        if isinstance(a, dict):
            return all(eq(v, b[k]) for k, v in a.items())
        else:
            return all(eq(aa, bb) for aa, bb in zip(a, b))
    except (TypeError, KeyError):
        return False
Run Code Online (Sandbox Code Playgroud)