我遇到了 numpy 数组无法与对象属性上下文中的==(使用 的语义)进行比较的问题。np.array_equal
考虑以下示例:
>>> import numpy as np
>>> class A:
... def __init__(self, a):
... self.a = a
... def __eq__(self, other):
... return self.__dict__ == other.__dict__
...
>>> x = A(a=[1, np.array([1, 2])])
>>> y = A(a=[1, np.array([1, 2])])
>>> x == y
Traceback (most recent call last):
File "<ipython-input-33-9cfbd892cdaa>", line 1, in <module>
x == y
File "<ipython-input-30-790950997d4f>", line 5, in __eq__
return self.__dict__ == other.__dict__
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
Run Code Online (Sandbox Code Playgroud)
(忽略这__eq__并不完美,它至少应该检查 的类型other,但这是为了简洁起见)
我将如何实现一个处理__eq__嵌套在对象属性深处的 numpy 数组的函数(假设其他所有内容,例如本例中的列表,与 相比都很好==)?numpy 数组可能出现在列表、元组或字典内任意深度的嵌套级别。
我尝试提出一个eq适用==于所有属性的递归函数的“手动”实现,并np.array_equal在遇到 numpy 数组时使用,但这比预期更棘手。
有人有合适的功能,或者简单的解决方法吗?
如果可以选择更改对象x和y,您可以根据自己的喜好覆盖__eq__的方法。np.ndarray
class eqarr(np.ndarray):
def __eq__(self, other):
return np.array_equal(self, other)
class A:
def __init__(self, a):
self.a = a
def __eq__(self, other):
return self.__dict__ == other.__dict__
x = A(a=[1, eqarr([1, 2])])
y = A(a=[1, eqarr([1, 2])])
x == y
Run Code Online (Sandbox Code Playgroud)
这导致True.
如果这是不可能的,我目前能想到的唯一解决方案就是实际实现递归相等检查函数。我的尝试如下:
def eq(a, b):
if not (hasattr(a, '__iter__') or type(a) == str):
return a == b
try:
if not len(a) == len(b):
return False
if type(a) == np.ndarray:
return np.array_equal(a, b)
if isinstance(a, dict):
return all(eq(v, b[k]) for k, v in a.items())
else:
return all(eq(aa, bb) for aa, bb in zip(a, b))
except (TypeError, KeyError):
return False
class A:
def __init__(self, a):
self.a = a
def __eq__(self, other):
return eq(self.__dict__, other.__dict__)
Run Code Online (Sandbox Code Playgroud)
有了你的例子和我想出的所有例子,它就起作用了。只要嵌套对象具有 an__iter__和 a__len__属性,该解决方案就应该适用。
我希望我考虑到了所有可能的错误,但您可能需要稍微调整代码以使其绝对安全。
如果您找到反例,请以评论的形式提供。我确信代码可以进行相应的调整。
性能eq可能不是很好,但我不知道这是否是您主要关心的问题。
如果 numpy 数组在您的层次结构中相当罕见(并且通常接近顶部),您始终可以先尝试正常比较。这可能如下所示:
def eq(a, b):
try:
return np.all(a == b)
except ValueError:
pass
try:
if not len(a) == len(b):
return False
if type(a) == np.ndarray:
return np.array_equal(a, b)
if isinstance(a, dict):
return all(eq(v, b[k]) for k, v in a.items())
else:
return all(eq(aa, bb) for aa, bb in zip(a, b))
except (TypeError, KeyError):
return False
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
855 次 |
| 最近记录: |