SQLAlchemy set membership for very large sets

use*_*114 · 6 · python · postgresql · sqlalchemy

My SQL query can be written very simply as:

    result = session.query(Table).filter(Table.my_key.in_(key_set))

The integer my_key column is indexed (it is the primary key), but key_set may be really large, with tens of millions of values.

What is the recommended SQLAlchemy pattern for filtering against such a large set?

Is there something built in that is more efficient than the pedestrian:

    result = [session.query(Table).get(key) for key in key_set]

Ilj*_*ilä · 7

In an extreme case like this you would be better off first thinking about what the recommended SQL solution is, and then implementing that in SQLAlchemy, even using raw SQL if need be. One such solution is to create a temporary table for the key_set data and populate it.
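
In SQL terms the plan is roughly the following (a sketch only, using the table and column names of the model defined just below; the measured implementations come after):

    # Sketch: stage the keys in a temporary table, then join against it.
    session.execute("CREATE TEMPORARY TABLE x (k INTEGER NOT NULL)")
    # ... bulk-load key_set into x, ideally with COPY (see below) ...
    session.execute("ANALYZE x")
    result = session.execute(
        "SELECT t.* FROM mytable t JOIN x ON x.k = t.my_key").fetchall()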


To test something resembling your setup, I created the following model

    class Table(Base):
        __tablename__ = 'mytable'
        my_key = Column(Integer, primary_key=True)

and populated it with 20,000,000 rows:

    In [1]: engine.execute("""
       ...:     insert into mytable
       ...:     select generate_series(1, 20000001)
       ...:     """)

I also created some helpers for testing different combinations of temporary table, population, and querying. Note that the queries use the Core table in order to bypass the ORM and its machinery, whose contribution to the timings would be constant anyway:

    # testdb is just your usual SQLAlchemy imports, and some
    # preconfigured engine options.
    from testdb import *
    from sqlalchemy.ext.compiler import compiles
    from sqlalchemy.sql.expression import Executable, ClauseElement
    from io import StringIO
    from itertools import product

    class Table(Base):
        __tablename__ = "mytable"
        my_key = Column(Integer, primary_key=True)

    def with_session(f):
        def wrapper(*a, **kw):
            session = Session(bind=engine)
            try:
                return f(session, *a, **kw)

            finally:
                session.close()
        return wrapper

    def all(_, query):
        return query.all()

    def explain(analyze=False):
        def cont(session, query):
            results = session.execute(Explain(query.statement, analyze))
            return [l for l, in results]

        return cont

    class Explain(Executable, ClauseElement):
        def __init__(self, stmt, analyze=False):
            self.stmt = stmt
            self.analyze = analyze

    @compiles(Explain)
    def visit_explain(element, compiler, **kw):
        stmt = "EXPLAIN "

        if element.analyze:
            stmt += "ANALYZE "

        stmt += compiler.process(element.stmt, **kw)
        return stmt

    def create_tmp_tbl_w_insert(session, key_set, unique=False):
        session.execute("CREATE TEMPORARY TABLE x (k INTEGER NOT NULL)")
        x = table("x", column("k"))
        session.execute(x.insert().values([(k,) for k in key_set]))

        if unique:
            session.execute("CREATE UNIQUE INDEX ON x (k)")

        session.execute("ANALYZE x")
        return x

    def create_tmp_tbl_w_copy(session, key_set, unique=False):
        session.execute("CREATE TEMPORARY TABLE x (k INTEGER NOT NULL)")
        # This assumes that the string representation of the Python values
        # is a valid representation for Postgresql as well. If this is not
        # the case, `cur.mogrify()` should be used.
        file = StringIO("".join([f"{k}\n" for k in key_set]))
        # HACK ALERT, get the DB-API connection object
        with session.connection().connection.connection.cursor() as cur:
            cur.copy_from(file, "x")

        if unique:
            session.execute("CREATE UNIQUE INDEX ON x (k)")

        session.execute("ANALYZE x")
        return table("x", column("k"))

    tmp_tbl_factories = {
        "insert": create_tmp_tbl_w_insert,
        "insert (uniq)": lambda session, key_set: create_tmp_tbl_w_insert(session, key_set, unique=True),
        "copy": create_tmp_tbl_w_copy,
        "copy (uniq)": lambda session, key_set: create_tmp_tbl_w_copy(session, key_set, unique=True),
    }

    query_factories = {
        "in": lambda session, _, x: session.query(Table.__table__).
            filter(Table.my_key.in_(x.select().as_scalar())),
        "exists": lambda session, _, x: session.query(Table.__table__).
            filter(exists().where(x.c.k == Table.my_key)),
        "join": lambda session, _, x: session.query(Table.__table__).
            join(x, x.c.k == Table.my_key)
    }

    tests = {
        "test in": (
            lambda _s, _ks: None,
            lambda session, key_set, _: session.query(Table.__table__).
                filter(Table.my_key.in_(key_set))
        ),
        "test in expanding": (
            lambda _s, _kw: None,
            lambda session, key_set, _: session.query(Table.__table__).
                filter(Table.my_key.in_(bindparam('key_set', key_set, expanding=True)))
        ),
        **{
            f"test {ql} w/ {tl}": (tf, qf)
            for (tl, tf), (ql, qf)
            in product(tmp_tbl_factories.items(), query_factories.items())
        }
    }

    @with_session
    def run_test(session, key_set, tmp_tbl_factory, query_factory, *, cont=all):
        x = tmp_tbl_factory(session, key_set)
        return cont(session, query_factory(session, key_set, x))

For small key sets the plain IN query you already have is as fast as anything else, but with a key_set of 100,000 the more involved solutions start winning:

    In [10]: for test, steps in tests.items():
        ...:     print(f"{test:<28}", end=" ")
        ...:     %timeit -r2 -n2 run_test(range(100000), *steps)
        ...:     
    test in                      2.21 s ± 7.31 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
    test in expanding            630 ms ± 929 µs per loop (mean ± std. dev. of 2 runs, 2 loops each)
    test in w/ insert            1.83 s ± 3.73 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
    test exists w/ insert        1.83 s ± 3.99 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
    test join w/ insert          1.86 s ± 3.76 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
    test in w/ insert (uniq)     1.87 s ± 6.67 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
    test exists w/ insert (uniq) 1.84 s ± 125 µs per loop (mean ± std. dev. of 2 runs, 2 loops each)
    test join w/ insert (uniq)   1.85 s ± 2.8 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
    test in w/ copy              246 ms ± 1.18 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
    test exists w/ copy          243 ms ± 2.31 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
    test join w/ copy            258 ms ± 3.05 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
    test in w/ copy (uniq)       261 ms ± 1.39 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
    test exists w/ copy (uniq)   267 ms ± 8.24 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
    test join w/ copy (uniq)     264 ms ± 1.16 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
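
The "test in expanding" row above uses SQLAlchemy's expanding bind parameter (available in SQLAlchemy 1.2 and later), which renders the parameter list at execution time instead of compiling tens of thousands of individual bind parameters. On its own, as an ORM query, the pattern looks like this:

    from sqlalchemy import bindparam

    # Equivalent of "test in expanding": the expanding bindparam is
    # expanded to a list of parameters when the query is executed.
    result = session.query(Table).filter(
        Table.my_key.in_(bindparam("key_set", key_set, expanding=True))).all()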

Raising the key_set size to 1,000,000:

    In [11]: for test, steps in tests.items():
        ...:     print(f"{test:<28}", end=" ")
        ...:     %timeit -r2 -n1 run_test(range(1000000), *steps)
        ...:     
    test in                      23.8 s ± 158 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
    test in expanding            6.96 s ± 3.02 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
    test in w/ insert            19.6 s ± 79.3 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
    test exists w/ insert        20.1 s ± 114 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
    test join w/ insert          19.5 s ± 7.93 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
    test in w/ insert (uniq)     19.5 s ± 45.4 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
    test exists w/ insert (uniq) 19.6 s ± 73.6 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
    test join w/ insert (uniq)   20 s ± 57.5 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
    test in w/ copy              2.53 s ± 49.9 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
    test exists w/ copy          2.56 s ± 1.96 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
    test join w/ copy            2.61 s ± 26.8 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
    test in w/ copy (uniq)       2.63 s ± 3.79 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
    test exists w/ copy (uniq)   2.61 s ± 916 µs per loop (mean ± std. dev. of 2 runs, 1 loop each)
    test join w/ copy (uniq)     2.6 s ± 5.31 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)

With key sets of 10,000,000, COPY solutions only, since the others ate up all my RAM and were deep in swap before getting killed, hinting that they would never finish on this machine:

    In [12]: for test, steps in tests.items():
        ...:     if "copy" in test:
        ...:         print(f"{test:<28}", end=" ")
        ...:         %timeit -r1 -n1 run_test(range(10000000), *steps)
        ...:     
    test in w/ copy              28.9 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
    test exists w/ copy          29.3 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
    test join w/ copy            29.7 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
    test in w/ copy (uniq)       28.3 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
    test exists w/ copy (uniq)   27.5 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
    test join w/ copy (uniq)     28.4 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

So, for smallish key sets (~100,000 or fewer) it does not matter much what you use, though the expanding bindparam is the clear winner in both time and ease of use; for much larger key sets you may want to consider using a temporary table and COPY.
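
Distilled out of the test harness above, an application-level version of the temporary table + COPY pattern might look roughly like this. This is a sketch only, assuming the psycopg2 driver as in create_tmp_tbl_w_copy, and filter_by_key_set is a hypothetical helper name:

    from io import StringIO
    from sqlalchemy import column, table

    def filter_by_key_set(session, key_set):
        # Stage the keys with COPY, then join against them.
        session.execute("CREATE TEMPORARY TABLE x (k INTEGER NOT NULL)")
        payload = StringIO("".join(f"{k}\n" for k in key_set))
        # Same hack as above to reach the raw psycopg2 cursor.
        with session.connection().connection.connection.cursor() as cur:
            cur.copy_from(payload, "x")
        session.execute("ANALYZE x")
        x = table("x", column("k"))
        return session.query(Table).join(x, x.c.k == Table.my_key).all()

Note that a Postgresql temporary table is scoped to the database session, so the staging and the query must run on the same connection, and the keys must have a text representation that Postgresql accepts (otherwise use cur.mogrify(), as noted above).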


It is worth noting that for the large sets the query plans are identical when the unique index is used:

    In [13]: print(*run_test(range(10000000),
        ...:                 tmp_tbl_factories["copy (uniq)"],
        ...:                 query_factories["in"],
        ...:                 cont=explain()), sep="\n")
    Merge Join  (cost=45.44..760102.11 rows=9999977 width=4)
      Merge Cond: (mytable.my_key = x.k)
      ->  Index Only Scan using mytable_pkey on mytable  (cost=0.44..607856.88 rows=20000096 width=4)
      ->  Index Only Scan using x_k_idx on x  (cost=0.43..303939.09 rows=9999977 width=4)

    In [14]: print(*run_test(range(10000000),
        ...:                 tmp_tbl_factories["copy (uniq)"],
        ...:                 query_factories["exists"],
        ...:                 cont=explain()), sep="\n")
    Merge Join  (cost=44.29..760123.36 rows=9999977 width=4)
      Merge Cond: (mytable.my_key = x.k)
      ->  Index Only Scan using mytable_pkey on mytable  (cost=0.44..607856.88 rows=20000096 width=4)
      ->  Index Only Scan using x_k_idx on x  (cost=0.43..303939.09 rows=9999977 width=4)

    In [15]: print(*run_test(range(10000000),
        ...:                 tmp_tbl_factories["copy (uniq)"],
        ...:                 query_factories["join"],
        ...:                 cont=explain()), sep="\n")
    Merge Join  (cost=39.06..760113.29 rows=9999977 width=4)
      Merge Cond: (mytable.my_key = x.k)
      ->  Index Only Scan using mytable_pkey on mytable  (cost=0.44..607856.88 rows=20000096 width=4)
      ->  Index Only Scan using x_k_idx on x  (cost=0.43..303939.09 rows=9999977 width=4)

Since the test table is somewhat artificial, it is able to use index-only scans.


Finally, here are the timings of the "pedestrian" method, for a rough comparison:

    In [3]: for ksl in [100000, 1000000]:
       ...:     %time [session.query(Table).get(k) for k in range(ksl)]
       ...:     session.rollback()
       ...:     
    CPU times: user 1min, sys: 1.76 s, total: 1min 1s
    Wall time: 1min 13s
    CPU times: user 9min 48s, sys: 17.3 s, total: 10min 5s
    Wall time: 12min 1s

The problem is that using Query.get() necessarily involves the ORM, while the original comparisons did not. Still, it should be clear that separate round trips to the database are expensive, even against a local database.
