用你班级的 __mul__ 覆盖其他 __rmul__

Nic*_*ger 6 python numpy

在Python中,您的类的方法是否可以__rmul__覆盖另一个类的__mul__方法,而不对另一个类进行更改?

出现这个问题是因为我正在为某种类型的线性运算符编写一个类,并且我希望它能够使用乘法语法来乘以 numpy 数组。这是说明该问题的最小示例:

import numpy as np    

class AbstractMatrix(object):
    def __init__(self):
        self.data = np.array([[1, 2],[3, 4]])

    def __mul__(self, other):
        return np.dot(self.data, other)

    def __rmul__(self, other):
        return np.dot(other, self.data)
Run Code Online (Sandbox Code Playgroud)

左乘法效果很好:

In[11]: A = AbstractMatrix()
In[12]: B = np.array([[4, 5],[6, 7]])
In[13]: A*B
Out[13]: 
array([[16, 19],
       [36, 43]])
Run Code Online (Sandbox Code Playgroud)

但右乘默认为np.ndarrays 版本,它将数组分割并逐个元素执行乘法(这不是我们想要的):

In[14]: B*A
Out[14]: 
array([[array([[ 4,  8],
       [12, 16]]),
        array([[ 5, 10],
       [15, 20]])],
       [array([[ 6, 12],
       [18, 24]]),
        array([[ 7, 14],
       [21, 28]])]], dtype=object)
Run Code Online (Sandbox Code Playgroud)

在这种情况下,我怎样才能让它__rmul__在原始(未分割)数组上调用我自己的类?

欢迎针对 numpy 数组特定情况的答案,但我也对覆盖另一个无法修改的第三方类的方法的一般概念感兴趣。

MSe*_*ert 5

NumPy尊重您的方法的最简单方法__rmul__是设置__array_priority__

class AbstractMatrix(object):
    def __init__(self):
        self.data = np.array([[1, 2],[3, 4]])

    def __mul__(self, other):
        return np.dot(self.data, other)

    def __rmul__(self, other):
        return np.dot(other, self.data)

    __array_priority__ = 10000

A = AbstractMatrix()
B = np.array([[4, 5],[6, 7]])
Run Code Online (Sandbox Code Playgroud)

这就像预期的那样工作。

>>> B*A
array([[19, 28],
       [27, 40]])
Run Code Online (Sandbox Code Playgroud)

问题是NumPy不尊重Python的“数字”数据模型。如果 numpy 数组是第一个参数并且numpy.ndarray.__mul__不可能,那么它会尝试如下操作:

result = np.empty(B.shape, dtype=object)
for idx, item in np.ndenumerate(B):
    result[idx] = A.__rmul__(item)
Run Code Online (Sandbox Code Playgroud)

但是,如果第二个参数有一个__array_priority__并且它高于第一个参数,那么它才真正使用:

A.__rmul__(B)
Run Code Online (Sandbox Code Playgroud)

然而,从 Python 3.5 ( PEP-465 ) 开始,就有了可以利用矩阵乘法的@( ) 运算符:__matmul__

>>> A = np.array([[1, 2],[3, 4]])
>>> B = np.array([[4, 5],[6, 7]])
>>> B @ A
array([[19, 28],
       [27, 40]])
Run Code Online (Sandbox Code Playgroud)