使用 ArrayLike 时出现 Mypy 错误

Obl*_*mov 3 python numpy mypy python-typing

我不明白应该如何在代码中使用 ArrayLike。如果检查 mypy,当我尝试在不调用强制转换的情况下使用变量进行任何操作时,我会不断收到错误。我正在尝试定义与 ndarray 以及常规列表一起使用的函数签名。

例如下面的代码

import numpy.typing as npt
import numpy as np

from typing import Any

def f(a: npt.ArrayLike) -> int:
    return len(a)

def g(a: npt.ArrayLike) -> Any:
    return a[0]

print(f(np.array([0, 1])), g(np.array([0, 1])))
print(f([0, 1]), g([0, 1]))
Run Code Online (Sandbox Code Playgroud)

给我 f() 和 g() 的这些错误:

Argument 1 to "len" has incompatible type "Union[_SupportsArray[dtype[Any]], _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float, complex, str, bytes, _NestedSequence[Union[bool, int, float, complex, str, bytes]]]"; expected "Sized"  [arg-type]

Value of type "Union[_SupportsArray[dtype[Any]], _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float, complex, str, bytes, _NestedSequence[Union[bool, int, float, complex, str, bytes]]]" is not indexable  [index]
Run Code Online (Sandbox Code Playgroud)

Dan*_*erg 5

的目的numpy.typing.ArrayLike是能够注释

可以强制转换为ndarray.

考虑到这个目的,他们将类型定义为以下联合:

Union[
    _SupportsArray[dtype[Any]],
    _NestedSequence[_SupportsArray[dtype[Any]]],
    bool,
    int,
    float,
    complex,
    str,
    bytes,
    _NestedSequence[Union[bool, int, float, complex, str, bytes]]
]
Run Code Online (Sandbox Code Playgroud)

_SupportsArray只是一个带有__array__方法的协议。它既不需要实现__len__(与函数一起使用len)也不__getitem__需要实现(用于索引)。

_NestedSequence是一个限制性更强的协议,实际上需要__len____getitem__

但这段代码的问题在于参数注释是union

import numpy.typing as npt

...

def f(a: npt.ArrayLike) -> int:
    return len(a)
Run Code Online (Sandbox Code Playgroud)

所以a 可能是一个支持 的类似序列的对象__len__,但它也可能只是一个支持的对象__array__而没有其他。它甚至可能只是一个int例子(再次参见工会)。因此,该调用len(a)是不安全的。

同样,这里的项目访问不是类型安全的,因为a可能无法实现__getitem__

...

def g(a: npt.ArrayLike) -> Any:
    return a[0]
Run Code Online (Sandbox Code Playgroud)

所以它不适合你的原因是它不适合用作 numpy 数组或其他序列的注释;它旨在用于可以转换为 numpy 数组的东西。


如果你想注释你的函数fg获取列表和 numpy 数组,你可以只使用listNDArraylike的并集list[Any] | npt.NDArray[Any]

如果您想要更广泛的注释来容纳任何具有__len__和 的类型__getitem__,您需要定义自己的协议

from typing import Any, Protocol, TypeVar

import numpy as np

T = TypeVar("T", covariant=True)


class SequenceLike(Protocol[T]):
    def __len__(self) -> int: ...
    def __getitem__(self, item: int) -> T: ...


def f(a: SequenceLike[Any]) -> int:
    return len(a)


def g(a: SequenceLike[T]) -> T:
    return a[0]


print(f(np.array([0, 1])), g(np.array([0, 1])))
print(f([0, 1]), g([0, 1]))
Run Code Online (Sandbox Code Playgroud)

更准确地说,__getitem__可能还应该获取slice对象,但重载对你来说可能有点过分了。