如何检查 numpy 数组列表是否包含给定的测试数组?

Gul*_*wak 4 python numpy

我有一个numpy数组列表,比如说,

a = [np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)]
Run Code Online (Sandbox Code Playgroud)

我有一个测试数组,比如说

b = np.random.rand(3, 3)
Run Code Online (Sandbox Code Playgroud)

我想检查是否a包含b。然而

b in a 
Run Code Online (Sandbox Code Playgroud)

引发以下错误:

ValueError:包含多个元素的数组的真值不明确。使用 a.any() 或 a.all()

我想要的正确方法是什么?

Nil*_*ner 5

你可以只让形状的一个阵列(3, 3, 3)出来的a

a = np.asarray(a)
Run Code Online (Sandbox Code Playgroud)

然后将它与b(我们在这里比较浮点数,所以我们应该使用isclose()

np.all(np.isclose(a, b), axis=(1, 2))
Run Code Online (Sandbox Code Playgroud)

例如:

a = [np.random.rand(3,3),np.random.rand(3,3),np.random.rand(3,3)]
a = np.asarray(a)
b = a[1, ...]       # set b to some value we know will yield True

np.all(np.isclose(a, b), axis=(1, 2))
# array([False,  True, False])
Run Code Online (Sandbox Code Playgroud)