numpy ndarray的子​​类不能按预期工作

Ken*_*n T 6 python numpy

"大家好.

我发现子类化ndarray时有一种奇怪的行为.

import numpy as np

class fooarray(np.ndarray):
    def __new__(cls, input_array, *args, **kwargs):
        obj = np.asarray(input_array).view(cls)
        return obj

    def __init__(self, *args, **kwargs):
        return

    def __array_finalize__(self, obj):
        return

a=fooarray(np.random.randn(3,5))
b=np.random.randn(3,5)

a_sum=np.sum(a,axis=0,keepdims=True)
b_sum=np.sum(b,axis=0, keepdims=True)

print a_sum.ndim #1
print b_sum.ndim #2
Run Code Online (Sandbox Code Playgroud)

如您所见,该keepdims参数对我的子类不起作用fooarray.它失去了一个轴.我怎能不避免这个问题?或者更一般地说,我怎样才能正确地将numpy ndarray子类化?

unu*_*tbu 5

np.sum可以接受各种对象作为输入:不仅是 ndarray,还可以是列表、生成器、np.matrixs 等。该keepdims参数显然对于列表或生成器没有意义。它也不适合np.matrix实例,因为np.matrixs 始终具有二维。如果您查看调用签名,np.matrix.sum您会发现它的sum方法没有keepdims参数:

Definition: np.matrix.sum(self, axis=None, dtype=None, out=None)
Run Code Online (Sandbox Code Playgroud)

因此 的某些子类ndarray可能具有sum没有keepdims参数的方法。不幸的是,这违反了里氏替换原则,也是您遇到的陷阱的根源。

现在,如果您查看 的源代码np.sum,您会发现它是一个委托函数,它尝试根据第一个参数的类型确定要执行的操作。

如果第一个参数的类型不是ndarray,则会删除该keepdims参数。这样做是因为传递 keepdims 参数np.matrix.sum会引发异常。

因此,因为np.sum尝试以最通用的方式进行委托,而不是对 ndarray 的子类可能采用哪些参数做出任何假设,所以它keepdims在传递fooarray.

解决方法是不使用np.sum,而是调用a.sum。无论如何,这更直接,因为np.sum这只是一个委托功能。

import numpy as np


class fooarray(np.ndarray):
    def __new__(cls, input_array, *args, **kwargs):
        obj = np.asarray(input_array, *args, **kwargs).view(cls)
        return obj

a = fooarray(np.random.randn(3, 5))
b = np.random.randn(3, 5)

a_sum = a.sum(axis=0, keepdims=True)
b_sum = np.sum(b, axis=0, keepdims=True)

print(a_sum.ndim)  # 2
print(b_sum.ndim)  # 2
Run Code Online (Sandbox Code Playgroud)