有什么方法可以将 Numpy 数组键入为通用数组吗?
我目前正在使用 Numpy 1.23.5 和 Python 3.10,并且无法输入以下示例的提示。
import numpy as np
import numpy.typing as npt
E = TypeVar("E") # Should be bounded to a numpy type
def double_arr(arr: npt.NDArray[E]) -> npt.NDArray[E]:
return arr * 2
Run Code Online (Sandbox Code Playgroud)
我的期望是什么
arr = np.array([1, 2, 3], dtype=np.int8)
double_arr(arr) # npt.NDAarray[np.int8]
arr = np.array([1, 2.3, 3], dtype=np.float32)
double_arr(arr) # npt.NDAarray[np.float32]
Run Code Online (Sandbox Code Playgroud)
但我最终遇到以下错误
arr: npt.NDArray[E]
^^^
Could not specialize type "NDArray[ScalarType@NDArray]"
Type "E@double_arr" cannot be assigned to type "generic"
"object*" is incompatible with "generic"
Run Code Online (Sandbox Code Playgroud)
如果我将 E …