diff --git a/cosmos/operators/kubernetes.py b/cosmos/operators/kubernetes.py index 996bbc9dd..af0988a6a 100644 --- a/cosmos/operators/kubernetes.py +++ b/cosmos/operators/kubernetes.py @@ -4,12 +4,17 @@ from typing import Any, Callable, Sequence import yaml -from airflow.utils.context import Context +from airflow.utils.context import Context, context_merge from cosmos.log import get_logger from cosmos.config import ProfileConfig from cosmos.operators.base import DbtBaseOperator +from airflow.models import TaskInstance +from cosmos.dbt.parser.output import extract_log_issues + +DBT_NO_TESTS_MSG = "Nothing to do" +DBT_WARN_MSG = "WARN" logger = get_logger(__name__) @@ -19,6 +24,7 @@ convert_env_vars, ) from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator + from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction except ImportError: try: # apache-airflow-providers-cncf-kubernetes < 7.4.0 @@ -158,10 +164,96 @@ class DbtTestKubernetesOperator(DbtKubernetesBaseOperator): ui_color = "#8194E0" def __init__(self, on_warning_callback: Callable[..., Any] | None = None, **kwargs: Any) -> None: - super().__init__(**kwargs) + if not on_warning_callback: + super().__init__(**kwargs) + else: + self.on_warning_callback = on_warning_callback + self.is_delete_operator_pod_original = kwargs.get("is_delete_operator_pod", None) + if self.is_delete_operator_pod_original is not None: + self.on_finish_action_original = ( + OnFinishAction.DELETE_POD if self.is_delete_operator_pod_original else OnFinishAction.KEEP_POD + ) + else: + self.on_finish_action_original = OnFinishAction(kwargs.get("on_finish_action", "delete_pod")) + self.is_delete_operator_pod_original = self.on_finish_action_original == OnFinishAction.DELETE_POD + # In order to read the pod logs, we need to keep the pod around. + # Depending on the on_finish_action & is_delete_operator_pod settings, + # we will clean up the pod later in the _handle_warnings method, which + # is called in on_success_callback. + kwargs["is_delete_operator_pod"] = False + kwargs["on_finish_action"] = OnFinishAction.KEEP_POD + + # Add an additional callback to both success and failure callbacks. + # In case of success, check for a warning in the logs and clean up the pod. + self.on_success_callback = kwargs.get("on_success_callback", None) or [] + if isinstance(self.on_success_callback, list): + self.on_success_callback += [self._handle_warnings] + else: + self.on_success_callback = [self.on_success_callback, self._handle_warnings] + kwargs["on_success_callback"] = self.on_success_callback + # In case of failure, clean up the pod. + self.on_failure_callback = kwargs.get("on_failure_callback", None) or [] + if isinstance(self.on_failure_callback, list): + self.on_failure_callback += [self._cleanup_pod] + else: + self.on_failure_callback = [self.on_failure_callback, self._cleanup_pod] + kwargs["on_failure_callback"] = self.on_failure_callback + + super().__init__(**kwargs) + self.base_cmd = ["test"] - # as of now, on_warning_callback in kubernetes executor does nothing - self.on_warning_callback = on_warning_callback + + def _handle_warnings(self, context: Context) -> None: + """ + Handles warnings by extracting log issues, creating additional context, and calling the + on_warning_callback with the updated context. + + :param context: The original airflow context in which the build and run command was executed. + """ + if not ( + isinstance(context["task_instance"], TaskInstance) + and isinstance(context["task_instance"].task, DbtTestKubernetesOperator) + ): + return + task = context["task_instance"].task + logs = [ + log.decode("utf-8") for log in task.pod_manager.read_pod_logs(task.pod, "base") if log.decode("utf-8") != "" + ] + + should_trigger_callback = all( + [ + logs, + self.on_warning_callback, + DBT_NO_TESTS_MSG not in logs[-1], + DBT_WARN_MSG in logs[-1], + ] + ) + + if should_trigger_callback: + warnings = int(logs[-1].split(f"{DBT_WARN_MSG}=")[1].split()[0]) + if warnings > 0: + test_names, test_results = extract_log_issues(logs) + context_merge(context, test_names=test_names, test_results=test_results) + self.on_warning_callback(context) + + self._cleanup_pod(context) + + def _cleanup_pod(self, context: Context) -> None: + """ + Handles the cleaning up of the pod after success or failure, if + there is a on_warning_callback function defined. + + :param context: The original airflow context in which the build and run command was executed. + """ + if not ( + isinstance(context["task_instance"], TaskInstance) + and isinstance(context["task_instance"].task, DbtTestKubernetesOperator) + ): + return + task = context["task_instance"].task + if task.pod: + task.on_finish_action = self.on_finish_action_original + task.cleanup(pod=task.pod, remote_pod=task.remote_pod) class DbtRunOperationKubernetesOperator(DbtKubernetesBaseOperator): diff --git a/tests/operators/test_kubernetes.py b/tests/operators/test_kubernetes.py index 7ef606cfe..585b1ab32 100644 --- a/tests/operators/test_kubernetes.py +++ b/tests/operators/test_kubernetes.py @@ -1,7 +1,7 @@ from pathlib import Path from unittest.mock import MagicMock, patch -from airflow.utils.context import Context +import pytest from pendulum import datetime from cosmos.operators.kubernetes import ( @@ -12,6 +12,16 @@ DbtTestKubernetesOperator, ) +from airflow.utils.context import Context, context_merge +from airflow.models import TaskInstance + +try: + from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction + + module_available = True +except ImportError: + module_available = False + def test_dbt_kubernetes_operator_add_global_flags() -> None: dbt_kube_operator = DbtKubernetesBaseOperator( @@ -103,6 +113,113 @@ def test_dbt_kubernetes_build_command(): ] +@pytest.mark.parametrize( + "additional_kwargs,expected_results", + [ + ({"on_success_callback": None, "is_delete_operator_pod": True}, (1, 1, True, "delete_pod")), + ( + {"on_success_callback": (lambda **kwargs: None), "is_delete_operator_pod": False}, + (2, 1, False, "keep_pod"), + ), + ( + {"on_success_callback": [(lambda **kwargs: None), (lambda **kwargs: None)], "is_delete_operator_pod": None}, + (3, 1, True, "delete_pod"), + ), + ( + {"on_failure_callback": None, "is_delete_operator_pod": True, "on_finish_action": "keep_pod"}, + (1, 1, True, "delete_pod"), + ), + ( + { + "on_failure_callback": (lambda **kwargs: None), + "is_delete_operator_pod": None, + "on_finish_action": "delete_pod", + }, + (1, 2, True, "delete_pod"), + ), + ( + { + "on_failure_callback": [(lambda **kwargs: None), (lambda **kwargs: None)], + "is_delete_operator_pod": None, + "on_finish_action": "delete_succeeded_pod", + }, + (1, 3, False, "delete_succeeded_pod"), + ), + ({"is_delete_operator_pod": None, "on_finish_action": "keep_pod"}, (1, 1, False, "keep_pod")), + ({}, (1, 1, True, "delete_pod")), + ], +) +@pytest.mark.skipif( + not module_available, reason="Kubernetes module `airflow.providers.cncf.kubernetes.utils.pod_manager` not available" +) +def test_dbt_test_kubernetes_operator_constructor(additional_kwargs, expected_results): + test_operator = DbtTestKubernetesOperator( + on_warning_callback=(lambda **kwargs: None), **additional_kwargs, **base_kwargs + ) + + print(additional_kwargs, test_operator.__dict__) + + assert isinstance(test_operator.on_success_callback, list) + assert isinstance(test_operator.on_failure_callback, list) + assert test_operator._handle_warnings in test_operator.on_success_callback + assert test_operator._cleanup_pod in test_operator.on_failure_callback + assert len(test_operator.on_success_callback) == expected_results[0] + assert len(test_operator.on_failure_callback) == expected_results[1] + assert test_operator.is_delete_operator_pod_original == expected_results[2] + assert test_operator.on_finish_action_original == OnFinishAction(expected_results[3]) + + +class FakePodManager: + def read_pod_logs(self, pod, container): + assert pod == "pod" + assert container == "base" + log_string = """ +19:48:25 Concurrency: 4 threads (target='target') +19:48:25 +19:48:25 1 of 2 START test dbt_utils_accepted_range_table_col__12__0 ................... [RUN] +19:48:25 2 of 2 START test unique_table__uuid .......................................... [RUN] +19:48:27 1 of 2 WARN 252 dbt_utils_accepted_range_table_col__12__0 ..................... [WARN 117 in 1.83s] +19:48:27 2 of 2 PASS unique_table__uuid ................................................ [PASS in 1.85s] +19:48:27 +19:48:27 Finished running 2 tests, 1 hook in 0 hours 0 minutes and 12.86 seconds (12.86s). +19:48:27 +19:48:27 Completed with 1 warning: +19:48:27 +19:48:27 Warning in test dbt_utils_accepted_range_table_col__12__0 (models/ads/ads.yaml) +19:48:27 Got 252 results, configured to warn if >0 +19:48:27 +19:48:27 compiled Code at target/compiled/model/models/table/table.yaml/dbt_utils_accepted_range_table_col__12__0.sql +19:48:27 +19:48:27 Done. PASS=1 WARN=1 ERROR=0 SKIP=0 TOTAL=2 +""" + return (log.encode("utf-8") for log in log_string.split("\n")) + + +@pytest.mark.skipif( + not module_available, reason="Kubernetes module `airflow.providers.cncf.kubernetes.utils.pod_manager` not available" +) +def test_dbt_test_kubernetes_operator_handle_warnings_and_cleanup_pod(): + def on_warning_callback(context: Context): + assert context["test_names"] == ["dbt_utils_accepted_range_table_col__12__0"] + assert context["test_results"] == ["Got 252 results, configured to warn if >0"] + + def cleanup(pod: str, remote_pod: str): + assert pod == remote_pod + + test_operator = DbtTestKubernetesOperator( + is_delete_operator_pod=True, on_warning_callback=on_warning_callback, **base_kwargs + ) + task_instance = TaskInstance(test_operator) + task_instance.task.pod_manager = FakePodManager() + task_instance.task.pod = task_instance.task.remote_pod = "pod" + task_instance.task.cleanup = cleanup + + context = Context() + context_merge(context, task_instance=task_instance) + + test_operator._handle_warnings(context) + + @patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.hook") def test_created_pod(test_hook): test_hook.is_in_cluster = False