测试 sqlalchemy select 对象中的 where 子句

Roa*_*ord 1 python testing unit-testing sqlalchemy

我正在尝试编写一些函数来构建 sqlalchemy select 语句。例如:

import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()
metadata = Base.metadata

t_test = sa.Table(
    'test', metadata,
    sa.Column('col_1', sa.Text),
    sa.Column('col_2', sa.Float)
)

def create_test_select():
    sa_select = sa.select([t_test.c.col_1, t_test.c.col_2])
    return sa_select

def add_test_col1_where_clause(sa_select, x ):
    sa_select = sa_select.where(t_test.c.col_1 == x)
    return sa_select
Run Code Online (Sandbox Code Playgroud)

我想测试一下这些功能。

为了测试create_test_select我会写一些类似的东西

class Test(unittest.TestCase):    
    def test(self):
         self.assertIn('col1', create_test_select().columns)
         self.assertIn('col2', create_test_select().columns)
Run Code Online (Sandbox Code Playgroud)

我如何测试该功能add_test_col1_where_clause?我想知道它向选择添加了正确的 where 子句。我最初的想法是检查 sqlachemy select 对象中的 where 子句,但我无法让它工作。

Mar*_*ers 5

SQLAlchemy 不直接公开 select 的“where”子句部分;毕竟,并非所有 SELECT 语句都有一个。此外,该条款可能会变得相当复杂。就我个人而言,我只会在集成测试中测试表达式,并且只是为了确保返回正确的数据。

SQLAlchemy 确实为您提供了访问对象树的工具,尽管记录有些不足。您可以使用它从树中提取任何比较,因此<left> <op> <right>表达式 whereleftrightare columns elements

from sqlalchemy.sql import visitors, ColumnElement

def comparison_visitor(expr, callback):
    """Finds all binary operators, and calls callback(op, left, right)"""
    def visit_binary(op):
        callback(op, op.left, op.right)
    # visit each expr element, but for select clauses, ignore the column collection
    visitors.traverse(expr, {'column_collections': False}, {'binary': visit_binary})
Run Code Online (Sandbox Code Playgroud)

visitors.traverse()遍历任何 SQLAlchemy 表达式(通过重复调用对象的ColumnClause.get_children()方法,将作为第二个参数给出的映射传递给traverse()),并调用与每个对象的属性匹配的函数__visit_name__BinaryClause对象具有binary访问名称。

然后,您可以使用它来测试是否存在特定条件:

from sqlalchemy.sql import operators
from sqlalchemy.sql.expressions import BindParameter
from sqlalchemy.sql.schema import Column

def test_where(self):
    where = add_test_col1_where_clause(create_test_select(), 'foo bar')

    def test_comparison(op, left, right):
        self.assertIs(op.operator, operators.eq)    # == test

        self.assertIsInstance(left, Column)         # between the col_1 column
        self.assertEq(left.name, 'col_1')

        self.assertIsInstance(right, BindParameter) # and a parameter
        self.assertEq(right.value, 'foo bar')       # with value 'foo bar'

    comparison_visitor(where, test_comparison)
Run Code Online (Sandbox Code Playgroud)