有没有办法输入提示 pandas 对象的索引?

Eri*_*rin 6 protocols typing dataframe pandas

我想输入提示:pandas 数据框必须有一个日期时间索引。我希望可能有某种方法可以通过协议来做到这一点,但看起来没有。本着这样的精神:

class TSFrame(Protocol):
    index: pd.DatetimeIndex

def test(df: TSFrame):
    # Do stuff with df.index.methods_supported_by_dtidx_only
    pass

nontsdf = pd.DataFrame()
tsdf = pd.DataFrame(index=pd.DatetimeIndex(pd.date_range("2022-01-01", "2022-01-02")))
test(nontsdf)  # goal is for my interpreter to complain here
test(tsdf)  # and not complain here
Run Code Online (Sandbox Code Playgroud)

相反,我的口译员在这两种情况下都抱怨。令人困惑的是,如果我在泛型类上创建类似的测试,但类型提示为 int,则两种情况都不会抱怨。

class IntWanted(Protocol):
    var: int

class TestClass:
    def __init__(self, var: Any) -> None:
        self.var = var

def foo(a: IntWanted) -> int:
    return a.var

good = TestClass(1)
bad = TestClass("x")
foo(good)
foo(bad) 
Run Code Online (Sandbox Code Playgroud)

我能想到的处理这些时间序列数据帧的其他方法:

  1. 子类化数据帧并添加索引是否为日期时间索引的验证。将我拥有的每个 df 转换为此类的实例,并在各处键入提示该类。这是否可以解决 mypy 知道其索引具有 dtidx 属性的问题?我想不是。例如
class TSFrame(Protocol):
    index: pd.DatetimeIndex

def test(df: TSFrame):
    # Do stuff with df.index.methods_supported_by_dtidx_only
    pass

nontsdf = pd.DataFrame()
tsdf = pd.DataFrame(index=pd.DatetimeIndex(pd.date_range("2022-01-01", "2022-01-02")))
test(nontsdf)  # goal is for my interpreter to complain here
test(tsdf)  # and not complain here
Run Code Online (Sandbox Code Playgroud)
  1. 创建一个全新的对象,将 df 作为一个属性,定义一个索引,该索引是从输入 df 的索引创建的 datetimeindex,以便 mypy 知道该索引的类型。这感觉很沉重。
class IntWanted(Protocol):
    var: int

class TestClass:
    def __init__(self, var: Any) -> None:
        self.var = var

def foo(a: IntWanted) -> int:
    return a.var

good = TestClass(1)
bad = TestClass("x")
foo(good)
foo(bad) 
Run Code Online (Sandbox Code Playgroud)

想法表示赞赏。

Cor*_*ien 2

您可以使用pandera( 和pandas-stub) 来做几乎任何您想做的事情。

  1. pip install pandera[mypy]
  2. 创建一个mypy.ini文件:
[mypy]
plugins = pandera.mypy
Run Code Online (Sandbox Code Playgroud)

demo.py

import pandera as pa
import pandas as pd
import numpy as np
from pandera.typing import Index, DataFrame, Series

class TSFrame(pa.DataFrameModel):
    idx: Index[pa.Timestamp] = pa.Field(check_name=False)

@pa.check_types  # at runtime
def test(df: DataFrame[TSFrame]):  # at compile time
    pass

nontsdf = pd.DataFrame()
tsdf = DataFrame[TSFrame](index=pd.DatetimeIndex(pd.date_range("2022-01-01", "2022-01-02")))
test(nontsdf)
test(tsdf)
Run Code Online (Sandbox Code Playgroud)

用法:

[...]$ mypy demo.py
demo1.py:14: error: Argument 1 to "test" has incompatible type "pandas.core.frame.DataFrame"; expected "pandera.typing.pandas.DataFrame[TSFrame]"  [arg-type]
Found 1 error in 1 file (checked 1 source file)

[...]$ python demo.py
...
pandera.errors.SchemaError: error in check_types decorator of function 'test': expected series 'None' to have type datetime64[ns], got int64
Run Code Online (Sandbox Code Playgroud)

更多信息: