如何在子类中键入注释覆盖的方法?

Mar*_*nen 5 python types python-3.x

假设我已经有一个带有类型注释的方法:

class Shape:
    def area(self) -> float:
        raise NotImplementedError
Run Code Online (Sandbox Code Playgroud)

然后我将多次子类化:

class Circle:
    def area(self) -> float:
        return math.pi * self.radius ** 2

class Rectangle:
    def area(self) -> float:
        return self.height * self.width
Run Code Online (Sandbox Code Playgroud)

正如你所看到的,我复制-> float了很多。假设我有 10 个不同的形状,有多种方法,其中一些也包含参数。有没有办法从父类中“复制”注释,类似于functools.wraps()文档字符串?

Ilj*_*ilä 4

这可能会起作用,尽管我肯定会错过边缘情况,例如附加参数:

from functools import partial, update_wrapper


def annotate_from(f):
    return partial(update_wrapper,
                   wrapped=f,
                   assigned=('__annotations__',),
                   updated=())
Run Code Online (Sandbox Code Playgroud)

它将分配“包装器”函数的__annotations__属性f.__annotations__(请记住,它不是副本)。

根据文档,该update_wrapper函数的默认值已经包含,但我可以理解为什么您不希望从wrapped__annotations__分配所有其他属性。

有了这个,您就可以将您的Circle和定义Rectangle

class Circle:
    @annotate_from(Shape.area)
    def area(self):
        return math.pi * self.radius ** 2

class Rectangle:
    @annotate_from(Shape.area)
    def area(self):
        return self.height * self.width
Run Code Online (Sandbox Code Playgroud)

和结果

In [82]: Circle.area.__annotations__
Out[82]: {'return': builtins.float}

In [86]: Rectangle.area.__annotations__
Out[86]: {'return': builtins.float}
Run Code Online (Sandbox Code Playgroud)

作为副作用,您的方法将具有一个属性__wrapped__,在本例中该属性将指向Shape.area


可以使用类装饰器来实现处理重写方法的不太标准的方法(如果您可以调用上面update_wrapper标准的使用):

from inspect import getmembers, isfunction, signature


def override(f):
    """
    Mark method overrides.
    """
    f.__override__ = True
    return f


def _is_method_override(m):
    return isfunction(m) and getattr(m, '__override__', False)


def annotate_overrides(cls):
    """
    Copy annotations of overridden methods.
    """
    bases = cls.mro()[1:]
    for name, method in getmembers(cls, _is_method_override):
        for base in bases:
            if hasattr(base, name):
                break

        else:
            raise RuntimeError(
                    'method {!r} not found in bases of {!r}'.format(
                            name, cls))

        base_method = getattr(base, name)
        method.__annotations__ = base_method.__annotations__.copy()

    return cls
Run Code Online (Sandbox Code Playgroud)

进而:

@annotate_overrides
class Rectangle(Shape):
    @override
    def area(self):
        return self.height * self.width
Run Code Online (Sandbox Code Playgroud)

同样,这不会处理带有附加参数的重写方法。