Skip to content

Commit

Permalink
Resolve skipmixin deprecations in tests (apache#39971)
Browse files Browse the repository at this point in the history
  • Loading branch information
dirrao authored May 31, 2024
1 parent 5137aef commit 93e6f00
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 24 deletions.
5 changes: 0 additions & 5 deletions tests/deprecations_ignore.yml
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,6 @@
- tests/models/test_dagbag.py::TestDagBag::test_load_subdags
- tests/models/test_dagbag.py::TestDagBag::test_skip_cycle_dags
- tests/models/test_mappedoperator.py::test_expand_mapped_task_instance_with_named_index
- tests/models/test_skipmixin.py::TestSkipMixin::test_mapped_tasks_skip_all_except
- tests/models/test_skipmixin.py::TestSkipMixin::test_raise_exception_on_not_accepted_branch_task_ids_type
- tests/models/test_skipmixin.py::TestSkipMixin::test_raise_exception_on_not_accepted_iterable_branch_task_ids_type
- tests/models/test_skipmixin.py::TestSkipMixin::test_raise_exception_on_not_valid_branch_task_ids
- tests/models/test_skipmixin.py::TestSkipMixin::test_skip_all_except
- tests/models/test_skipmixin.py::TestSkipMixin::test_skip_none_dagrun
- tests/models/test_taskinstance.py::TestTaskInstance::test_context_triggering_dataset_events
- tests/models/test_taskinstance.py::TestTaskInstance::test_get_num_running_task_instances
Expand Down
44 changes: 25 additions & 19 deletions tests/models/test_skipmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from airflow import settings
from airflow.decorators import task, task_group
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
from airflow.models.skipmixin import SkipMixin
from airflow.models.taskinstance import TaskInstance as TI
from airflow.operators.empty import EmptyOperator
Expand All @@ -38,6 +38,7 @@


DEFAULT_DATE = timezone.datetime(2016, 1, 1)
DEFAULT_DAG_RUN_ID = "test1"


class TestSkipMixin:
Expand Down Expand Up @@ -85,7 +86,12 @@ def test_skip_none_dagrun(self, mock_now, dag_maker):
):
tasks = [EmptyOperator(task_id="task")]
dag_maker.create_dagrun(execution_date=now)
SkipMixin().skip(dag_run=None, execution_date=now, tasks=tasks, session=session)

with pytest.warns(
RemovedInAirflow3Warning,
match=r"Passing an execution_date to `skip\(\)` is deprecated in favour of passing a dag_run",
):
SkipMixin().skip(dag_run=None, execution_date=now, tasks=tasks, session=session)

session.query(TI).filter(
TI.dag_id == "dag",
Expand Down Expand Up @@ -121,11 +127,11 @@ def test_skip_all_except(self, dag_maker, branch_task_ids, expected_states):
task3 = EmptyOperator(task_id="task3")

task1 >> [task2, task3]
dag_maker.create_dagrun()
dag_maker.create_dagrun(run_id=DEFAULT_DAG_RUN_ID)

ti1 = TI(task1, execution_date=DEFAULT_DATE)
ti2 = TI(task2, execution_date=DEFAULT_DATE)
ti3 = TI(task3, execution_date=DEFAULT_DATE)
ti1 = TI(task1, run_id=DEFAULT_DAG_RUN_ID)
ti2 = TI(task2, run_id=DEFAULT_DAG_RUN_ID)
ti3 = TI(task3, run_id=DEFAULT_DAG_RUN_ID)

SkipMixin().skip_all_except(ti=ti1, branch_task_ids=branch_task_ids)

Expand All @@ -151,13 +157,13 @@ def task_group_op(k):

task_group_op.expand(k=[0, 1])

dag_maker.create_dagrun()
branch_op_ti_0 = TI(dag.get_task("task_group_op.branch_op"), execution_date=DEFAULT_DATE, map_index=0)
branch_op_ti_1 = TI(dag.get_task("task_group_op.branch_op"), execution_date=DEFAULT_DATE, map_index=1)
branch_a_ti_0 = TI(dag.get_task("task_group_op.branch_a"), execution_date=DEFAULT_DATE, map_index=0)
branch_a_ti_1 = TI(dag.get_task("task_group_op.branch_a"), execution_date=DEFAULT_DATE, map_index=1)
branch_b_ti_0 = TI(dag.get_task("task_group_op.branch_b"), execution_date=DEFAULT_DATE, map_index=0)
branch_b_ti_1 = TI(dag.get_task("task_group_op.branch_b"), execution_date=DEFAULT_DATE, map_index=1)
dag_maker.create_dagrun(run_id=DEFAULT_DAG_RUN_ID)
branch_op_ti_0 = TI(dag.get_task("task_group_op.branch_op"), run_id=DEFAULT_DAG_RUN_ID, map_index=0)
branch_op_ti_1 = TI(dag.get_task("task_group_op.branch_op"), run_id=DEFAULT_DAG_RUN_ID, map_index=1)
branch_a_ti_0 = TI(dag.get_task("task_group_op.branch_a"), run_id=DEFAULT_DAG_RUN_ID, map_index=0)
branch_a_ti_1 = TI(dag.get_task("task_group_op.branch_a"), run_id=DEFAULT_DAG_RUN_ID, map_index=1)
branch_b_ti_0 = TI(dag.get_task("task_group_op.branch_b"), run_id=DEFAULT_DAG_RUN_ID, map_index=0)
branch_b_ti_1 = TI(dag.get_task("task_group_op.branch_b"), run_id=DEFAULT_DAG_RUN_ID, map_index=1)

SkipMixin().skip_all_except(ti=branch_op_ti_0, branch_task_ids="task_group_op.branch_a")
SkipMixin().skip_all_except(ti=branch_op_ti_1, branch_task_ids="task_group_op.branch_b")
Expand All @@ -174,8 +180,8 @@ def get_state(ti):
def test_raise_exception_on_not_accepted_branch_task_ids_type(self, dag_maker):
with dag_maker("dag_test_skip_all_except_wrong_type"):
task = EmptyOperator(task_id="task")
dag_maker.create_dagrun()
ti1 = TI(task, execution_date=DEFAULT_DATE)
dag_maker.create_dagrun(run_id=DEFAULT_DAG_RUN_ID)
ti1 = TI(task, run_id=DEFAULT_DAG_RUN_ID)
error_message = (
r"'branch_task_ids' must be either None, a task ID, or an Iterable of IDs, but got 'int'\."
)
Expand All @@ -185,8 +191,8 @@ def test_raise_exception_on_not_accepted_branch_task_ids_type(self, dag_maker):
def test_raise_exception_on_not_accepted_iterable_branch_task_ids_type(self, dag_maker):
with dag_maker("dag_test_skip_all_except_wrong_type"):
task = EmptyOperator(task_id="task")
dag_maker.create_dagrun()
ti1 = TI(task, execution_date=DEFAULT_DATE)
dag_maker.create_dagrun(run_id=DEFAULT_DAG_RUN_ID)
ti1 = TI(task, run_id=DEFAULT_DAG_RUN_ID)
error_message = (
r"'branch_task_ids' expected all task IDs are strings. "
r"Invalid tasks found: \{\(42, 'int'\)\}\."
Expand All @@ -209,9 +215,9 @@ def test_raise_exception_on_not_valid_branch_task_ids(self, dag_maker, branch_ta
task3 = EmptyOperator(task_id="task3")

task1 >> [task2, task3]
dag_maker.create_dagrun()
dag_maker.create_dagrun(run_id=DEFAULT_DAG_RUN_ID)

ti1 = TI(task1, execution_date=DEFAULT_DATE)
ti1 = TI(task1, run_id=DEFAULT_DAG_RUN_ID)

error_message = r"'branch_task_ids' must contain only valid task_ids. Invalid tasks found: .*"
with pytest.raises(AirflowException, match=error_message):
Expand Down

0 comments on commit 93e6f00

Please sign in to comment.