使用 mypy 对 NumPy ndarray 进行特定类型注释

Joh*_*ohn 8 numpy python-3.x mypy numpy-ndarray

NumPy 1.20 中添加了对类型注释的支持。我试图弄清楚如何告诉 mypy 数组填充了特定类型的元素,注释np.ndarray[np.dcomplex]给出了 mypy error "ndarray" expects no type arguments, but 1 given

编辑:这个问题与numpy.ndarray 的类型提示/注释(PEP 484)不同,因为这个问题是在 4 年前提出的,当时没有任何官方支持类型提示。我问的是什么是官方的 方法,现在numpy 1.20实际上支持类型提示。https://numpy.org/doc/stable/reference/typing.html#module-numpy.typing上的文档指出,那里的最佳答案似乎只说你不应该用类型提示做的事情,而不是解释什么你应该做的。

tgp*_*fer 32

您正在寻找的是该类numpy.typing.NDArrayhttps://numpy.org/doc/stable/reference/typing.html#numpy.typing.NDArray

numpy.typing.NDArray[A]是 的别名numpy.ndarray[Any, numpy.dtype[A]]

import numpy as np
import numpy.typing as npt

a: npt.NDArray[np.complex64] = np.zeros((3, 3), dtype=np.complex64)
# reveal_type(a)  # -> numpy.ndarray[Any, numpy.dtype[numpy.complexfloating[numpy.typing._32Bit, numpy.typing._32Bit]]]
print(a)
Run Code Online (Sandbox Code Playgroud)

印刷

[[0.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 0.+0.j]]
Run Code Online (Sandbox Code Playgroud)

请注意,即使您注释anpt.NDArray[np.complex64],您仍然需要确保将匹配传递dtype给右侧的工厂。

[[0.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 0.+0.j]]
Run Code Online (Sandbox Code Playgroud)

也通过了 mypy 检查。