Airflow: run a task when some upstream tasks are skipped by ShortCircuitOperator

Jus*_*zas 10 airflow

I have a task that I'll call final, which has multiple upstream connections. When one of its upstreams gets skipped by a ShortCircuitOperator, this task gets skipped as well. I don't want the final task to be skipped, as it has to report on the DAG's success.

To avoid it being skipped, I used trigger_rule='all_done', but it still gets skipped.

If I use a BranchPythonOperator instead of the ShortCircuitOperator, the final task is not skipped. The branching workflow looks like it could be a solution, even if not optimal, but now final will not respect failures of upstream tasks.

How do I get final to run only when its upstreams have succeeded or been skipped?

Sample short-circuit DAG:

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import ShortCircuitOperator
from datetime import datetime
from random import randint

default_args = {
    'owner': 'airflow',
    'start_date': datetime(2018, 8, 1)}

dag = DAG(
    'shortcircuit_test',
    default_args=default_args,
    schedule_interval='* * * * *',
    catchup=False)

def shortcircuit_fn():
    # randomly proceed or short-circuit, to exercise both paths
    return randint(0, 1) == 1

task_1 = DummyOperator(dag=dag, task_id='task_1')
task_2 = DummyOperator(dag=dag, task_id='task_2')

work = DummyOperator(dag=dag, task_id='work')
short = ShortCircuitOperator(dag=dag, task_id='short_circuit', python_callable=shortcircuit_fn)
final = DummyOperator(dag=dag, task_id="final", trigger_rule="all_done")

task_1 >> short >> work >> final
task_1 >> task_2 >> final

DAG with the short-circuit operator

Sample branching DAG:

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import BranchPythonOperator
from datetime import datetime
from random import randint

default_args = {
    'owner': 'airflow',
    'start_date': datetime(2018, 8, 1)}

dag = DAG(
    'branch_test',
    default_args=default_args,
    schedule_interval='* * * * *',
    catchup=False)

# these two exist only to protect downstream tasks from being skipped
# as direct dependencies of the branch operator
to_do_work = DummyOperator(dag=dag, task_id='to_do_work')
to_skip_work = DummyOperator(dag=dag, task_id='to_skip_work')

def branch_fn():
    # randomly pick which branch to follow
    return to_do_work.task_id if randint(0, 1) == 1 else to_skip_work.task_id

task_1 = DummyOperator(dag=dag, task_id='task_1')
task_2 = DummyOperator(dag=dag, task_id='task_2')

work = DummyOperator(dag=dag, task_id='work')
branch = BranchPythonOperator(dag=dag, task_id='branch', python_callable=branch_fn)
final = DummyOperator(dag=dag, task_id="final", trigger_rule="all_done")

task_1 >> branch >> to_do_work >> work >> final
branch >> to_skip_work >> final
task_1 >> task_2 >> final

DAG with the branch operator

Mic*_*tor 13

I ended up developing a custom ShortCircuitOperator, based on the original one:

from airflow.models import SkipMixin
from airflow.operators.python_operator import PythonOperator


class ShortCircuitOperator(PythonOperator, SkipMixin):
    """
    Allows a workflow to continue only if a condition is met. Otherwise, the
    workflow "short-circuits" and downstream tasks that only rely on this operator
    are skipped.

    The ShortCircuitOperator is derived from the PythonOperator. It evaluates a
    condition and short-circuits the workflow if the condition is False. Any
    downstream tasks that only rely on this operator are marked with a state of "skipped".
    If the condition is True, downstream tasks proceed as normal.

    The condition is determined by the result of `python_callable`.
    """

    def find_tasks_to_skip(self, task, found_tasks=None):
        if not found_tasks:
            found_tasks = []
        direct_relatives = task.get_direct_relatives(upstream=False)
        for t in direct_relatives:
            if len(t.upstream_task_ids) == 1:
                found_tasks.append(t)
                self.find_tasks_to_skip(t, found_tasks)
        return found_tasks

    def execute(self, context):
        condition = super(ShortCircuitOperator, self).execute(context)
        self.log.info("Condition result is %s", condition)

        if condition:
            self.log.info('Proceeding with downstream tasks...')
            return

        self.log.info(
            'Skipping downstream tasks that only rely on this path...')

        tasks_to_skip = self.find_tasks_to_skip(context['task'])
        self.log.debug("Tasks to skip: %s", tasks_to_skip)

        if tasks_to_skip:
            self.skip(context['dag_run'], context['ti'].execution_date,
                      tasks_to_skip)

        self.log.info("Done.")

This operator makes sure that no downstream task that depends on multiple paths gets skipped because of one skipped task.
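For reference, a minimal usage sketch, assuming the class above is defined in the same file as the question's shortcircuit_test DAG (it shadows the stock import, so the wiring stays unchanged):

# with the custom operator, only tasks whose every path leads back to the
# short circuit get skipped: here that is just 'work', whose sole upstream
# is the short-circuit task, while 'final' (also fed by task_2) still runs
short = ShortCircuitOperator(dag=dag, task_id='short_circuit',
                             python_callable=shortcircuit_fn)
final = DummyOperator(dag=dag, task_id='final', trigger_rule='all_done')

task_1 >> short >> work >> final
task_1 >> task_2 >> final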

  • This is exactly what I needed. The built-in ShortCircuitOperator seems to skip all downstream tasks without checking anything, such as trigger rules or other considerations. (5 upvotes)

Rob*_*hts 6

I'm posting another possible workaround for this, since it's a method that does not require a custom operator implementation.

I was influenced by the solution in the blog post below, which uses a PythonOperator that raises AirflowSkipException: this skips the task itself, and its downstream tasks are then skipped individually.

https://godatadriven.com/blog/the-zen-of-python-and-apache-airflow/

This then respects the trigger_rule of the final downstream task, which in my case I set to trigger_rule='none_failed'.

Modified example from the blog to include a final task:

from airflow.exceptions import AirflowSkipException
from airflow.operators.python_operator import PythonOperator

def fn_short_circuit(**context):
    if <<<some condition>>>:
        raise AirflowSkipException(
            "Skip this task and individual downstream tasks "
            "while respecting trigger rules.")

task_1 = DummyOperator(task_id="task_1", dag=dag)
task_2 = DummyOperator(task_id="task_2", dag=dag)
work = DummyOperator(dag=dag, task_id='work')

# a plain PythonOperator replaces the ShortCircuitOperator: raising
# AirflowSkipException skips only this task, and the skip then cascades
# downstream task by task, honoring each task's trigger_rule
short = PythonOperator(
    task_id="short_circuit",
    python_callable=fn_short_circuit,
    provide_context=True,
    dag=dag)

final_task = DummyOperator(
    task_id="final_task",
    trigger_rule="none_failed",
    dag=dag)

task_1 >> short >> work >> final_task
task_1 >> task_2 >> final_task


Jus*_*zas 1

I got this working by making the final task check the states of the upstream task instances. It is not pretty, as the only way I found to access their states was by querying the Airflow DB.

# # additional imports to ones in question code
# from airflow import AirflowException
# from airflow.models import TaskInstance
# from airflow.operators.python_operator import PythonOperator
# from airflow.settings import Session
# from airflow.utils.state import State
# from airflow.utils.trigger_rule import TriggerRule

def all_upstreams_either_succeeded_or_skipped(dag, task, task_instance, **context):
    """
    Find the directly upstream task instances and count how many are not in the
    preferred states. Return True if no instances are in non-preferred states.
    """
    upstream_task_ids = [t.task_id for t in task.get_direct_relatives(upstream=True)]
    session = Session()
    query = (session
        .query(TaskInstance)
        .filter(
            TaskInstance.dag_id == dag.dag_id,
            TaskInstance.execution_date == task_instance.execution_date,
            TaskInstance.task_id.in_(upstream_task_ids)
        )
    )
    upstream_task_instances = query.all()
    unhappy_task_instances = [ti for ti in upstream_task_instances if ti.state not in [State.SUCCESS, State.SKIPPED]]
    print(unhappy_task_instances)
    return len(unhappy_task_instances) == 0

def final_fn(**context):
    """
    fail if upstream task instances have unwanted statuses
    """
    if not all_upstreams_either_succeeded_or_skipped(**context):
        raise AirflowException("Not all upstream tasks succeeded.")
    # Do things

# will run when upstream task instances are done, including failed
final = PythonOperator(
    dag=dag,
    task_id="final",
    trigger_rule=TriggerRule.ALL_DONE,
    python_callable=final_fn,
    provide_context=True)
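For completeness, a sketch of wiring this final task into the question's short-circuit DAG, assuming the same dag, task_1, task_2, short and work objects as defined in the question:

# replaces the DummyOperator version of 'final' from the question;
# the dependency layout itself stays exactly the same
task_1 >> short >> work >> final
task_1 >> task_2 >> final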