diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 4888583bb..6fcc4293a 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -40,9 +40,7 @@ FullOutputSubprocessHook, FullOutputSubprocessResult, ) -from cosmos.dbt.parser.output import ( - extract_log_issues, -) +from cosmos.dbt.parser.output import extract_log_issues logger = get_logger(__name__) @@ -350,14 +348,15 @@ def get_openlineage_facets_on_complete(self, task_instance: TaskInstance) -> Ope job_facets=job_facets, ) - def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> None: + def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> FullOutputSubprocessResult: dbt_cmd, env = self.build_cmd(context=context, cmd_flags=cmd_flags) dbt_cmd = dbt_cmd or [] result = self.run_command(cmd=dbt_cmd, env=env, context=context) - logger.info(result.output) + return result def execute(self, context: Context) -> None: - self.build_and_run_cmd(context=context) + result = self.build_and_run_cmd(context=context) + logger.info(result.output) def on_kill(self) -> None: if self.cancel_query_on_kill: @@ -461,24 +460,10 @@ def __init__( self.base_cmd = ["test"] self.on_warning_callback = on_warning_callback - def _should_run_tests( - self, - result: FullOutputSubprocessResult, - no_tests_message: str = "Nothing to do", - ) -> bool: - """ - Check if any tests are defined to run in the DAG. If tests are defined - and on_warning_callback is set, then function returns True. - - :param result: The output from the build and run command. - """ - - return self.on_warning_callback is not None and no_tests_message not in result.output - def _handle_warnings(self, result: FullOutputSubprocessResult, context: Context) -> None: """ - Handles warnings by extracting log issues, creating additional context, and calling the - on_warning_callback with the updated context. + Handles warnings by extracting log issues, creating additional context, and calling the + on_warning_callback with the updated context. :param result: The result object from the build and run command. :param context: The original airflow context in which the build and run command was executed. @@ -489,8 +474,15 @@ def _handle_warnings(self, result: FullOutputSubprocessResult, context: Context) warning_context["test_names"] = test_names warning_context["test_results"] = test_results + self.on_warning_callback(warning_context) + + def execute(self, context: Context) -> None: + result = self.build_and_run_cmd(context=context) + if self.on_warning_callback: - self.on_warning_callback(warning_context) + self._handle_warnings(result, context) + + logger.info(result.output) class DbtRunOperationLocalOperator(DbtLocalBaseOperator): diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index 94b0f8e27..36462c0b6 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -302,3 +302,40 @@ def test_operator_execute_without_flags(mock_build_and_run_cmd, operator_class): ) task.execute(context={}) mock_build_and_run_cmd.assert_called_once_with(context={}) + + +@pytest.mark.integration +def test_dbt_test_local_operator_on_warning_callback(): + on_warning_callback_mock = MagicMock() + + with DAG("my-test-dag", start_date=datetime(2022, 1, 1)) as dag: + test_operator = DbtTestLocalOperator( + profile_config=real_profile_config, + project_dir=DBT_PROJ_DIR, + task_id="my-task", + install_deps=True, + dbt_cmd_flags=["--models", "stg_customers"], + on_warning_callback=on_warning_callback_mock, + ) + test_operator + + run_test_dag(dag) + on_warning_callback_mock.assert_called_once() + + +@pytest.mark.integration +def test_on_warning_callback_not_triggered(): + on_warning_callback_mock = MagicMock() + + with DAG("my-test-dag", start_date=datetime(2022, 1, 1)) as dag: + test_operator = DbtTestLocalOperator( + profile_config=real_profile_config, + project_dir=DBT_PROJ_DIR, + install_deps=True, + dbt_cmd_flags=["--models", "stg_customers"], + task_id="my-task", + ) + test_operator + + run_test_dag(dag) + on_warning_callback_mock.assert_not_called()