Python Numpy中的Array和__rmul__运算符

Exo*_*odd 8 python arrays numpy

在一个项目中,我创建了一个类,我需要在这个新类和一个真实矩阵之间进行操作,所以我重载了这个__rmul__函数

class foo(object):

    aarg = 0

    def __init__(self):
        self.aarg = 1


    def __rmul__(self,A):
        print(A)
        return 0

    def __mul__(self,A):
        print(A)
        return 0
Run Code Online (Sandbox Code Playgroud)

但是当我打电话给它时,结果并不是我的预期

A = [[i*j for i in np.arange(2)  ] for j in np.arange(3)]
A = np.array(A)
R = foo()
C =  A * R
Run Code Online (Sandbox Code Playgroud)

输出:

0
0
0
1
0
2
Run Code Online (Sandbox Code Playgroud)

似乎该函数被调用了6次,每个元素一次.

相反,该__mul__功能非常有效

C = R * A
Run Code Online (Sandbox Code Playgroud)

输出:

[[0 0]
 [0 1]
 [0 2]]
Run Code Online (Sandbox Code Playgroud)

如果A不是数组,而只是列表列表,则两者都可以正常工作

A = [[i*j for i in np.arange(2)  ] for j in np.arange(3)]
R = foo()
C =  A * R
C = R * A
Run Code Online (Sandbox Code Playgroud)

产量

[[0, 0], [0, 1], [0, 2]]
[[0, 0], [0, 1], [0, 2]]
Run Code Online (Sandbox Code Playgroud)

我真的希望我的__rmul__函数也可以在数组上工作(我的原始乘法函数不是可交换的).我该如何解决?

Bak*_*riu 6

行为是预期的.

首先,您必须了解如何x*y实际执行操作.python解释器将首先尝试计算x.__mul__(y).如果这个调用返回NotImplemented它会尝试计算y.__rmul__(x). 除了y是类型的正确子类x,在这种情况下,解释器将首先考虑y.__rmul__(x)x.__mul__(y).

现在发生的事情是numpy根据他是否认为参数是标量或数组来区别对待参数.

处理数组时*,逐元素乘法,而标量乘法将数组的所有条目乘以给定的标量.

在你的情况下foo()被numpy视为标量,因此numpy乘以数组的所有元素foo.此外,由于numpy不知道foo它返回数组的类型dtype=object,因此返回的对象是:

array([[0, 0],
       [0, 0],
       [0, 0]], dtype=object)
Run Code Online (Sandbox Code Playgroud)

注:numpy的数组也不会返回NotImplemented当您尝试计算产品,所以解释称numpy的的阵列__mul__的方法,正如我们所说的执行标量乘法.此时numpy将尝试将数组的每个条目乘以"标量" foo(),这里是__rmul__调用方法的地方,因为数组中的数字在使用参数调用NotImplemented时返回.__mul__foo

显然,如果您将参数的顺序更改为初始乘法__mul__,则会立即调用您的方法,并且您没有遇到任何麻烦.

因此,要回答您的问题,处理此问题的一种方法是foo继承ndarray,以便第二条规则适用:

class foo(np.ndarray):
    def __new__(cls):
       # you must implement __new__
    # code as before
Run Code Online (Sandbox Code Playgroud)

但是请注意,子类化ndarray并不简单.此外,你可能有其他副作用,因为现在你的班级是ndarray.