Jor*_*mez 1 python numpy type-hinting mypy python-typing
有什么方法可以将 Numpy 数组键入为通用数组吗?
我目前正在使用 Numpy 1.23.5 和 Python 3.10,并且无法输入以下示例的提示。
import numpy as np
import numpy.typing as npt
E = TypeVar("E") # Should be bounded to a numpy type
def double_arr(arr: npt.NDArray[E]) -> npt.NDArray[E]:
return arr * 2
Run Code Online (Sandbox Code Playgroud)
我的期望是什么
arr = np.array([1, 2, 3], dtype=np.int8)
double_arr(arr) # npt.NDAarray[np.int8]
arr = np.array([1, 2.3, 3], dtype=np.float32)
double_arr(arr) # npt.NDAarray[np.float32]
Run Code Online (Sandbox Code Playgroud)
但我最终遇到以下错误
arr: npt.NDArray[E]
^^^
Could not specialize type "NDArray[ScalarType@NDArray]"
Type "E@double_arr" cannot be assigned to type "generic"
"object*" is incompatible with "generic"
Run Code Online (Sandbox Code Playgroud)
如果我将 E 绑定到 numpy 数据类型 ( np.int8, np.uint8, ...),由于存在多种数据类型,类型检查器无法评估乘法。
查看源代码,似乎用于参数化的泛型类型变量numpy.dtype受(并声明协变)numpy.typing.NDArray限制numpy.generic。因此,任何类型参数都NDArray必须是 的子类型numpy.generic,而您的类型变量是无界的。这应该有效:
from typing import TypeVar
import numpy as np
from numpy.typing import NDArray
E = TypeVar("E", bound=np.generic, covariant=True)
def double_arr(arr: NDArray[E]) -> NDArray[E]:
return arr * 2
Run Code Online (Sandbox Code Playgroud)
但还有另一个问题,我认为这个问题在于 numpy 存根不足。本期展示了一个例子。重载的操作数(魔术)方法会以__mul__某种方式破坏类型。我现在只是粗略地看了一下代码,所以我不知道缺少什么。但mypy仍然会抱怨该代码中的最后一行:
错误:从声明为返回“ndarray[Any, dtype[E]]”的函数返回 Any [no-any-return]
错误: * ("ndarray[Any, dtype[E]]" 和 "int") [运算符] 不支持的操作数类型
现在的解决方法是使用函数而不是操作数(通过 dunder 方法)。在这种情况下,使用numpy.multiply而不是*解决问题:
error: Returning Any from function declared to return "ndarray[Any, dtype[E]]" [no-any-return]
error: Unsupported operand types for * ("ndarray[Any, dtype[E]]" and "int") [operator]
不再mypy抱怨,类型揭示如下:
from typing import TypeVar
import numpy as np
from numpy.typing import NDArray
E = TypeVar("E", bound=np.generic, covariant=True)
def double_arr(arr: NDArray[E]) -> NDArray[E]:
return np.multiply(arr, 2)
a = np.array([1, 2, 3], dtype=np.int8)
reveal_type(double_arr(a))
Run Code Online (Sandbox Code Playgroud)
值得关注该操作数问题,甚至可能Unsupported operand types for *单独报告特定错误。我还没有在问题跟踪器中找到它。
PS:或者,您可以使用*运算符并添加特定的 type: ignore. 这样你就会注意到,如果/一旦注释错误最终被 numpy 修复,因为mypy抱怨严格模式下未使用的忽略指令。
numpy.ndarray[Any, numpy.dtype[numpy.signedinteger[numpy._typing._8Bit]]]
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
3331 次 |
| 最近记录: |