use*_*299 29 python algebraic-data-types
我知道Python不是Haskell或Ocaml,但这是在Python(2或3)中定义代数数据类型的最佳方法吗?谢谢!
小智 56
这是Brent 答案的 Python 3.10 版本,具有模式匹配和更漂亮的联合类型语法:
from dataclasses import dataclass
@dataclass
class Point:
x: float
y: float
@dataclass
class Circle:
x: float
y: float
r: float
@dataclass
class Rectangle:
x: float
y: float
w: float
h: float
Shape = Point | Circle | Rectangle
def print_shape(shape: Shape):
match shape:
case Point(x, y):
print(f"Point {x} {y}")
case Circle(x, y, r):
print(f"Circle {x} {y} {r}")
case Rectangle(x, y, w, h):
print(f"Rectangle {x} {y} {w} {h}")
print_shape(Point(1, 2))
print_shape(Circle(3, 5, 7))
print_shape(Rectangle(11, 13, 17, 19))
print_shape(4) # mypy type error
Run Code Online (Sandbox Code Playgroud)
您甚至可以执行递归类型:
from __future__ import annotations
from dataclasses import dataclass
@dataclass
class Branch:
value: int
left: Tree
right: Tree
Tree = Branch | None
def contains(tree: Tree, value: int):
match tree:
case None:
return False
case Branch(x, left, right):
return x == value or contains(left, value) or contains(right, value)
tree = Branch(1, Branch(2, None, None), Branch(3, None, Branch(4, None, None)))
assert contains(tree, 1)
assert contains(tree, 2)
assert contains(tree, 3)
assert contains(tree, 4)
assert not contains(tree, 5)
Run Code Online (Sandbox Code Playgroud)
请注意,需要from __future__ import annotations使用尚未定义的类型进行注释。
mypyADT 的详尽检查可以在 Python 3.11+ 中使用或作为旧版本 Python 向后移植的typing.assert_never()一部分来强制执行。typing-extensions
from dataclasses import dataclass
@dataclass
class Point:
x: float
y: float
@dataclass
class Circle:
x: float
y: float
r: float
@dataclass
class Rectangle:
x: float
y: float
w: float
h: float
Shape = Point | Circle | Rectangle
def print_shape(shape: Shape):
match shape:
case Point(x, y):
print(f"Point {x} {y}")
case Circle(x, y, r):
print(f"Circle {x} {y} {r}")
case Rectangle(x, y, w, h):
print(f"Rectangle {x} {y} {w} {h}")
print_shape(Point(1, 2))
print_shape(Circle(3, 5, 7))
print_shape(Rectangle(11, 13, 17, 19))
print_shape(4) # mypy type error
Run Code Online (Sandbox Code Playgroud)
该typing模块提供了Union与 C 不同的 sum 类型。您需要使用 mypy 进行静态类型检查,并且明显缺乏模式匹配,但结合元组(产品类型),这是两种常见的代数类型。
from dataclasses import dataclass
from typing import Union
@dataclass
class Point:
x: float
y: float
@dataclass
class Circle:
x: float
y: float
r: float
@dataclass
class Rectangle:
x: float
y: float
w: float
h: float
Shape = Union[Point, Circle, Rectangle]
def print_shape(shape: Shape):
if isinstance(shape, Point):
print(f"Point {shape.x} {shape.y}")
elif isinstance(shape, Circle):
print(f"Circle {shape.x} {shape.y} {shape.r}")
elif isinstance(shape, Rectangle):
print(f"Rectangle {shape.x} {shape.y} {shape.w} {shape.h}")
print_shape(Point(1, 2))
print_shape(Circle(3, 5, 7))
print_shape(Rectangle(11, 13, 17, 19))
# print_shape(4) # mypy type error
Run Code Online (Sandbox Code Playgroud)
kel*_*oti -3
这是 sum 类型以相对 Pythonic 的方式实现。
import attr
@attr.s(frozen=True)
class CombineMode(object):
kind = attr.ib(type=str)
params = attr.ib(factory=list)
def match(self, expected_kind, f):
if self.kind == expected_kind:
return f(*self.params)
else:
return None
@classmethod
def join(cls):
return cls("join")
@classmethod
def select(cls, column: str):
return cls("select", params=[column])
Run Code Online (Sandbox Code Playgroud)
打开解释器,你会看到熟悉的行为:
>>> CombineMode.join()
CombineMode(kind='join_by_entity', params=[])
>>> CombineMode.select('a') == CombineMode.select('b')
False
>>> CombineMode.select('a') == CombineMode.select('a')
True
>>> CombineMode.select('foo').match('select', print)
foo
Run Code Online (Sandbox Code Playgroud)
注意:@attr.s装饰器来自attrs 库,它实现了__init__、__repr__、 和__eq__,但它也冻结了对象。我包含它是因为它减少了实现大小,但它也广泛可用并且非常稳定。
总和类型有时称为标记联合。这里我使用kindmember来实现标签。附加的每个变体参数是通过列表实现的。在真正的 Python 风格中,这是在输入和输出端采用鸭子类型的,但在内部并未严格执行。
我还包含了一个match执行基本模式匹配的函数。类型安全也是通过鸭子类型实现的,TypeError如果传递的 lambda 函数签名与您尝试匹配的实际变体不一致,则会引发 a 。
list这些求和类型可以与乘积类型 (或)组合tuple,并且仍然保留代数数据类型所需的许多关键功能。
问题
这并不严格限制变体集。