为 pytest 模拟 Sqlalchemy 会话

Not*_*tMe 12 python sqlalchemy mocking pytest flask-sqlalchemy

我不知道这是否可以完成,但我正在尝试模拟我的 db.session.save。

我正在使用烧瓶和烧瓶炼金术。

数据库文件

from flask_sqlalchemy import SQLAlchemy

db = SQLAlchemy()
Run Code Online (Sandbox Code Playgroud)

单元测试

def test_post(self):
    with app.app_context():
        with app.test_client() as client:
            with mock.patch('models.db.session.save') as mock_save:
                with mock.patch('models.db.session.commit') as mock_commit:

                    data = self.gen_legend_data()
                    response = client.post('/legends', data=json.dumps([data]), headers=access_header)

                    assert response.status_code == 200
                    mock_save.assert_called()
                    mock_commit.assert_called_once()
Run Code Online (Sandbox Code Playgroud)

和方法:

def post(cls):
    legends = schemas.Legends(many=True).load(request.get_json())

    for legend in legends:
        db.session.add(legend)

    db.session.commit()

    return {'message': 'legends saved'}, 200
Run Code Online (Sandbox Code Playgroud)

我正在尝试模拟 db.session.add 和 db.session.commit。我试过db.session.savelegends.models.db.session.savemodels.db.session.save。他们都带着保存错误回来了:

ModuleNotFoundError: No module named 'models.db.session'; 'models.db' is not a package
Run Code Online (Sandbox Code Playgroud)

我没有收到错误,也不知道如何解决。

或者我在做一些想要模拟 db.session 的事情是完全错误的?

谢谢。德斯蒙德

Ste*_*hry 20

您在这里遇到的问题最好通过重组您的代码来解决,以便它更易于测试,而不是模拟每个组件,或者进行(非常)缓慢的集成测试。如果你养成了以这种方式编写测试的习惯,那么随着时间的推移,你最终会得到一个运行缓慢的构建,而且你最终会得到脆弱的测试(关于为什么的好话题)快速测试在这里很重要)。

我们来看看这条路线:

def post(cls):
    legends = schemas.Legends(many=True).load(request.get_json())

    for legend in legends:
        db.session.add(legend)

    db.session.commit()

    return {'message': 'legends saved'}, 200
Run Code Online (Sandbox Code Playgroud)

...并分解它:

import typing
from flask import jsonify

class LegendsPostService:

    def __init__(self, json_args, _session=None) -> None:
        self.json_args = json_args
        self.session = _session or db.session

    def _get_legends(self) -> Legend:
        return schemas.Legends(many=True).load(self.json_args)

    def post(self) -> typing.List[typing.Dict[str, typing.Any]]:
        legends = self._get_legends()

        for legend in legends:
            self.session.add(legend)

        self.session.commit()
        return schemas.Legends(many=True).dump(legends)

def post(cls):
    service = LegendsPostService(json_args=request.get_json())
    service.post()
    return jsonify({'message': 'legends saved'})

Run Code Online (Sandbox Code Playgroud)

请注意我们如何将几乎所有的故障点从postinto 中隔离出来LegendsPostService,此外,我们还从中删除了所有烧瓶内部结构(没有浮动的全局请求对象等)。session如果我们需要稍后进行测试,我们甚至赋予它模拟的能力。

我建议您将测试工作集中在为LegendsPostService. 一旦您获得了出色的 测试LegendsPostService,请决定您是否相信更多的测试覆盖率会增加价值。如果你这样做了,那么考虑编写一个简单的集成测试post()来将它们联系在一起。

您需要考虑的下一件事是您希望如何考虑测试中的 SQLAlchemy 对象。我建议只使用 FactoryBoy 为您自动创建“模拟”模型。这是一个完整的应用程序示例,用于说明如何以这种方式设置flask/sqlalchemy/factory-boy:How do I generate nested JSON from database query with joins?使用 Python/SQLAlchemy

这是我编写测试的方式LegendsPostService(抱歉,这有点草率,并不能完全代表您要执行的操作 - 但您应该能够为您的用例调整这些测试):


from factory.alchemy import SQLAlchemyModelFactory

class ModelFactory(SQLAlchemyModelFactory):
    class Meta:
        abstract = True
        sqlalchemy_session = db.session

# setup your factory for Legends:
class LegendsFactory(ModelFactory):
    logo_url = factory.Faker('image_url')
    class Meta(ModelFactory.Meta):
        model = Legends


from unittest.mock import MagicMock, patch


# neither of these tests even need a database connection!
# so you should be able to write HUNDREDS of similar tests
# and you should be able to run hundreds of them in seconds (not minutes)

def test_LegendsPostService_can_init():
    session = MagicMock()
    service = LegendsPostService(json_args={'foo': 'bar'}, _session=session)
    assert service.session is session
    assert service.json_args['foo'] == 'bar'


def test_LegendsPostService_can_post():
    session = MagicMock()
    service = LegendsPostService(json_args={'foo': 'bar'}, _session=session)

    # let's make some fake Legends for our service!
    legends = LegendsFactory.build_batch(2)

    with patch.object(service, '_get_legends') as _get_legends:
        _get_legends.return_value = legends
        legends_post_json = service.post()

    # look, Ma! No database connection!
    assert legends_post_json[0]['image_url'] == legends[0].image_url

Run Code Online (Sandbox Code Playgroud)

我希望这有帮助!