如何为 @task 装饰的 Airflow 任务编写单元测试?

Sad*_*had 5 python-unittest airflow airflow-taskflow

我正在尝试为使用Airflow TaskFlow API构建的一些任务编写单元测试。我尝试了多种方法,例如,通过创建 dagrun 或仅运行任务函数,但没有任何帮助。

这是我从 S3 下载文件的任务,还有更多内容,但我在本示例中删除了它。

@task()
def updates_process(files):
    context = get_current_context()
    try:
        updates_file_path = utils.download_file_from_s3_bucket(files.get("updates_file"))
    except FileNotFoundError as e:
        log.error(e)
        return

    # Do something else
Run Code Online (Sandbox Code Playgroud)

现在我试图编写一个测试用例,我可以在其中检查这个 except 子句。以下是我开始的例子

class TestAccountLinkUpdatesProcess(TestCase):
    @mock.patch("dags.delta_load.updates.log")
    @mock.patch("dags.delta_load.updates.get_current_context")
    @mock.patch("dags.delta_load.updates.utils.download_file_from_s3_bucket")
    def test_file_not_found_error(self, download_file_from_s3_bucket, get_current_context, log):
        download_file_from_s3_bucket.side_effect = FileNotFoundError
        task = account_link_updates_process({"updates_file": "path/to/file.csv"})
        get_current_context.assert_called_once()
        log.error.assert_called_once()
Run Code Online (Sandbox Code Playgroud)

我还尝试创建一个 dagrun(如文档中的示例所示)并从 dagrun 获取任务,但这也没有帮助。

Sad*_*had 0

这就是我能弄清楚的。不确定这是否正确,但它确实有效。

class TestAccountLinkUpdatesProcess(TestCase):
    TASK_ID = "updates_process"

    @classmethod
    def setUpClass(cls) -> None:
        cls.dag = dag_delta_load()

    @mock.patch("dags.delta_load.updates.log")
    @mock.patch("dags.delta_load.updates.get_current_context")
    @mock.patch("dags.delta_load.updates.utils.download_file_from_s3_bucket")
    def test_file_not_found_error(self, download_file_from_s3_bucket, get_current_context, log):
        download_file_from_s3_bucket.side_effect = FileNotFoundError
        task = self.dag.get_task(task_id=self.TASK_ID)
        task.op_args = [{"updates_file": "file.csv"}]
        task.execute(context={})
        log.error.assert_called_once()
Run Code Online (Sandbox Code Playgroud)

更新:根据@AetherUnbound的回答,我做了一些调查,发现我们可以用来task.__wrapped__()调用实际的python函数。

class TestAccountLinkUpdatesProcess(TestCase):
    @mock.patch("dags.delta_load.updates.log")
    @mock.patch("dags.delta_load.updates.get_current_context")
    @mock.patch("dags.delta_load.updates.utils.download_file_from_s3_bucket")
    def test_file_not_found_error(self, download_file_from_s3_bucket, get_current_context, log):
        download_file_from_s3_bucket.side_effect = FileNotFoundError
        update_process.__wrapped__({"updates_file": "file.csv"})
        log.error.assert_called_once()
Run Code Online (Sandbox Code Playgroud)