diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 869c2087c..249a2d0ee 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -198,7 +198,6 @@ def create_task_metadata( else: task_id = f"{node.name}_run" elif node.resource_type == DbtResourceType.SOURCE: - # if on_warning_callback is not None: extra_context["on_warning_callback"] = on_warning_callback if (source_rendering_behavior == SourceRenderingBehavior.NONE) or ( diff --git a/cosmos/core/airflow.py b/cosmos/core/airflow.py index 057056d36..f26568ee7 100644 --- a/cosmos/core/airflow.py +++ b/cosmos/core/airflow.py @@ -37,9 +37,9 @@ def get_airflow_task(task: Task, dag: DAG, task_group: TaskGroup | None = None) # Set the on_warning_callback of source node in task_kwargs on_warning_callback = task.extra_context.get("on_warning_callback") - if on_warning_callback is not None: + if on_warning_callback: task_kwargs["on_warning_callback"] = on_warning_callback - del task.extra_context["on_warning_callback"] + task.extra_context.pop("on_warning_callback", None) airflow_task = Operator( task_id=task.id, diff --git a/cosmos/dbt/parser/output.py b/cosmos/dbt/parser/output.py index 4232ab380..06def6b28 100644 --- a/cosmos/dbt/parser/output.py +++ b/cosmos/dbt/parser/output.py @@ -40,10 +40,6 @@ def parse_number_of_warnings_subprocess(result: FullOutputSubprocessResult) -> i return num -def parse_number_of_freshness_warnings_subprocess(result: FullOutputSubprocessResult) -> int: - return sum(1 for line in result.full_output if DBT_FRESHNESS_WARN_MSG in line) - - def parse_number_of_warnings_dbt_runner(result: dbtRunnerResult) -> int: """Parses a dbt runner result and returns the number of warnings found. This only works for dbtRunnerResult from invoking dbt build, compile, run, seed, snapshot, test, or run-operation. @@ -55,7 +51,8 @@ def parse_number_of_warnings_dbt_runner(result: dbtRunnerResult) -> int: return num -def extract_freshness_warn_issue(log_list: List[str]) -> Tuple[List[str], List[str]]: +def extract_freshness_warn_msg(result: FullOutputSubprocessResult) -> Tuple[List[str], List[str]]: + log_list = result.full_output test_names = [] test_results = [] diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 0fa8da23a..dee00114c 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -57,9 +57,8 @@ ) from cosmos.dbt.parser.output import ( extract_dbt_runner_issues, - extract_freshness_warn_issue, + extract_freshness_warn_msg, extract_log_issues, - parse_number_of_freshness_warnings_subprocess, parse_number_of_warnings_dbt_runner, parse_number_of_warnings_subprocess, ) @@ -712,17 +711,6 @@ def __init__(self, *args: Any, on_warning_callback: Callable[..., Any] | None = super().__init__(*args, **kwargs) self.on_warning_callback = on_warning_callback self.extract_issues: Callable[..., tuple[list[str], list[str]]] - self.parse_number_of_warnings: Callable[..., int] - - def _set_test_result_parsing_methods(self) -> None: - """Sets the extract_issues and parse_number_of_warnings methods based on the invocation mode.""" - if self.invocation_mode == InvocationMode.SUBPROCESS: - self.extract_issues = extract_freshness_warn_issue - self.parse_number_of_warnings = parse_number_of_freshness_warnings_subprocess - # TODO: FIXME - # elif self.invocation_mode == InvocationMode.DBT_RUNNER: - # self.extract_issues = extract_dbt_runner_issues - # self.parse_number_of_warnings = parse_number_of_warnings_dbt_runner def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult, context: Context) -> None: """ @@ -732,7 +720,12 @@ def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult, :param result: The result object from the build and run command. :param context: The original airflow context in which the build and run command was executed. """ - test_names, test_results = self.extract_issues(result.full_output) + if self.invocation_mode == InvocationMode.SUBPROCESS: + self.extract_issues = extract_freshness_warn_msg + elif self.invocation_mode == InvocationMode.DBT_RUNNER: + self.extract_issues = extract_dbt_runner_issues + + test_names, test_results = self.extract_issues(result) warning_context = dict(context) warning_context["test_names"] = test_names @@ -742,9 +735,7 @@ def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult, def execute(self, context: Context) -> None: result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags()) - self._set_test_result_parsing_methods() - number_of_warnings = self.parse_number_of_warnings(result) # type: ignore - if self.on_warning_callback and number_of_warnings > 0: + if self.on_warning_callback: self._handle_warnings(result, context)