具有特定形状和数据类型的 Numpy 类型

Bod*_*rop 29 numpy typing python-typing

目前,我正在尝试更多地使用 numpy 类型来使我的代码更清晰,但是我已经达到了目前无法覆盖的限制。

是否可以指定特定的形状以及相应的数据类型?例子:

Shape=(4,)
datatype= np.int32
Run Code Online (Sandbox Code Playgroud)

到目前为止,我的尝试如下所示(但都只是抛出错误):

第一次尝试:

import numpy as np

def foo(x: np.ndarray[(4,), np.dtype[np.int32]]):
...
result -> 'numpy._DTypeMeta' object is not subscriptable
Run Code Online (Sandbox Code Playgroud)

第二次尝试:

import numpy as np
import numpy.typing as npt

def foo(x: npt.NDArray[(4,), np.int32]):
...
result -> Too many arguments for numpy.ndarray[typing.Any, numpy.dtype[+ScalarType]]
Run Code Online (Sandbox Code Playgroud)

另外,不幸的是,我在文档中找不到有关它的任何信息,或者只有当我按照文档记录的方式实现它时才会出现错误。

R H*_*R H 39

目前,numpy.typing.NDArray仅接受 dtype,如下所示:numpy.typing.NDArray[numpy.int32]。不过你有一些选择。

使用typing.Annotated

typing.Annotated允许您为类型创建别名并将一些额外信息与其捆绑在一起。

在某些情况下,my_types.py您可以写出您想要暗示的形状的所有变体:

from typing import Annotated, Literal, TypeVar
import numpy as np
import numpy.typing as npt


DType = TypeVar("DType", bound=np.generic)

Array4 = Annotated[npt.NDArray[DType], Literal[4]]
Array3x3 = Annotated[npt.NDArray[DType], Literal[3, 3]]
ArrayNxNx3 = Annotated[npt.NDArray[DType], Literal["N", "N", 3]]
Run Code Online (Sandbox Code Playgroud)

然后在 中foo.py,您可以提供 numpy dtype 并将它们用作类型提示:

import numpy as np
from my_types import Array4


def foo(arr: Array4[np.int32]):
    assert arr.shape == (4,)
Run Code Online (Sandbox Code Playgroud)

MyPy 将识别arr为 anp.ndarray并对其进行检查。形状检查只能在运行时完成,就像本例中的assert.

如果您不喜欢这个断言,您可以发挥您的创造力来定义一个函数来为您进行检查。

def assert_match(arr, array_type):
    hinted_shape = array_type.__metadata__[0].__args__
    hinted_dtype_type = array_type.__args__[0].__args__[1]
    hinted_dtype = hinted_dtype_type.__args__[0]
    assert np.issubdtype(arr.dtype, hinted_dtype), "DType does not match"
    assert arr.shape == hinted_shape, "Shape does not match"


assert_match(some_array, Array4[np.int32])
Run Code Online (Sandbox Code Playgroud)

使用nptyping

另一种选择是使用第 3 方库nptyping(是的,我是作者)。

你会放弃,my_types.py因为它不再有任何用处。

foo.py会变成这样:

from nptyping import NDArray, Shape, Int32


def foo(arr: NDArray[Shape["4"], Int32]):
    assert isinstance(arr, NDArray[Shape["4"], Int32])
Run Code Online (Sandbox Code Playgroud)

使用beartype+typing.Annotated

还有另一个名为beartype您可以使用的第三方库。它可以采用该方法的变体typing.Annotated,并为您进行运行时检查。

您可以my_types.py使用类似以下内容恢复您的内容:

from beartype import beartype
from beartype.vale import Is
from typing import Annotated
import numpy as np


Int32Array4 = Annotated[np.ndarray, Is[lambda array:
    array.shape == (4,) and np.issubdtype(array.dtype, np.int32)]]
Int32Array3x3 = Annotated[np.ndarray, Is[lambda array:
    array.shape == (3,3) and np.issubdtype(array.dtype, np.int32)]]
Run Code Online (Sandbox Code Playgroud)

foo.py会变成:

import numpy as np
from beartype import beartype
from my_types import Int32Array4 


@beartype
def foo(arr: Int32Array4):
    ...  # Runtime type checked by beartype.
Run Code Online (Sandbox Code Playgroud)

使用beartype+nptyping

您还可以堆叠这两个库。

my_types.py可以再次删除,您的内容foo.py将变成类似以下内容:

from nptyping import NDArray, Shape, Int32
from beartype import beartype


@beartype
def foo(arr: NDArray[Shape["4"], Int32]):
    ...  # Runtime type checked by beartype.
Run Code Online (Sandbox Code Playgroud)

  • [文档](https://numpy.org/devdocs/reference/typing.html#numpy.typing.NDArray) 说 numpy.typing.NDArray = numpy.ndarray[typing.Any, numpy.dtype[+_ScalarType_co] ]`。如果“NDArray”只接受数据类型,第一个参数“typing.Any”是什么? (4认同)
  • 它说:*“可以在运行时用于输入具有给定数据类型和未指定形状的数组。”*,这意味着无法指定形状。如果您查看[代码](https://github.com/numpy/numpy/blob/094416f7433a0bc077e472e801fe36613318c01f/numpy/__init__.pyi#L1477),您会注意到第一个参数称为“_ShapeType”。对我来说,这表明 numpy 为未来支持提示形状敞开了大门。 (2认同)