Numpy where 可广播条件

PKl*_*mpp 6 python numpy where-clause

numpy.where()现在已经使用过很多次了,我总是想知道文档中的以下声明:

x、y 和条件需要可广播为某种形状。

我明白为什么这对于xy都是必要的。我们想要将这两个数组组合起来,因此它们应该可以广播为相同的形状。然而,我不明白为什么这对于这种情况也如此重要。这只是决策规则。假设我有以下三种形状:

condition = (100,)
x         = (100, 5)
y         = (100, 5)
result    = np.where(condition, x, y)
Run Code Online (Sandbox Code Playgroud)

这会导致 ValueError,因为“操作数无法一起广播”。据我了解,这个表达式应该可以正常工作,因为我编写了可广播的 x 和 y 的结果。

您能帮我理解为什么条件与 x 和 y 一起广播如此重要吗?

sen*_*rle 2

该条件本质上是一个布尔数组,而不是通用条件。您可以将其视为最终广播形状x和 的遮罩y

如果您这样想,就应该清楚蒙版必须具有与最终输出相同的形状,或者可以广播为相同的形状。

为了说明这一点,这里有一个简单的例子。首先,考虑一个场景,其中我们手动定义了一个 3x3 掩码数组作为condition,并传入两个 3 项数组作为xy,形状适合适当广播:

condition = numpy.array([[0, 1, 1],
                         [1, 0, 1],
                         [0, 0, 1]])
ones = numpy.ones(3)
numpy.where(condition, ones[:, None], ones[None, :] + 1)
Run Code Online (Sandbox Code Playgroud)

结果如下:

>>> numpy.where(condition, ones[:, None], ones[None, :] + 1)
array([[2., 1., 1.],
       [1., 2., 1.],
       [2., 2., 1.]])
Run Code Online (Sandbox Code Playgroud)

由于广播步骤,x并且y它们的行为就像这样定义的:

>>> x
array([[1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.]])
>>> y
array([[2., 2., 2.],
       [2., 2., 2.],
       [2., 2., 2.]])
>>> numpy.where(condition, ones[:, None], ones[None, :] + 1)
array([[2., 1., 1.],
       [1., 2., 1.],
       [2., 2., 1.]])
Run Code Online (Sandbox Code Playgroud)

这是 的基本行为where。事实上,您可以传递类似的条件,(x > 5)但不会改变上述内容;(x > 5)成为布尔数组,并且它必须具有与输出相同的形状,否则它必须可广播为该形状。否则, 的行为where将是不明确的。

(顺便说一句,我假设您的问题不是关于为什么形状(100,)(100, 5)(100, 5)不可广播;这似乎是一个不同的问题。)