Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixing existing issues with DbtTestLocalOperator on_warning_callback #556

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 15 additions & 23 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@
FullOutputSubprocessHook,
FullOutputSubprocessResult,
)
from cosmos.dbt.parser.output import (
extract_log_issues,
)
from cosmos.dbt.parser.output import extract_log_issues

logger = get_logger(__name__)

Expand Down Expand Up @@ -350,14 +348,15 @@ def get_openlineage_facets_on_complete(self, task_instance: TaskInstance) -> Ope
job_facets=job_facets,
)

def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> None:
def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> FullOutputSubprocessResult:
dbt_cmd, env = self.build_cmd(context=context, cmd_flags=cmd_flags)
dbt_cmd = dbt_cmd or []
result = self.run_command(cmd=dbt_cmd, env=env, context=context)
logger.info(result.output)
return result

def execute(self, context: Context) -> None:
self.build_and_run_cmd(context=context)
result = self.build_and_run_cmd(context=context)
logger.info(result.output)

def on_kill(self) -> None:
if self.cancel_query_on_kill:
Expand Down Expand Up @@ -461,24 +460,10 @@ def __init__(
self.base_cmd = ["test"]
self.on_warning_callback = on_warning_callback

def _should_run_tests(
self,
result: FullOutputSubprocessResult,
no_tests_message: str = "Nothing to do",
) -> bool:
"""
Check if any tests are defined to run in the DAG. If tests are defined
and on_warning_callback is set, then function returns True.

:param result: The output from the build and run command.
"""

return self.on_warning_callback is not None and no_tests_message not in result.output

def _handle_warnings(self, result: FullOutputSubprocessResult, context: Context) -> None:
"""
Handles warnings by extracting log issues, creating additional context, and calling the
on_warning_callback with the updated context.
Handles warnings by extracting log issues, creating additional context, and calling the
on_warning_callback with the updated context.

:param result: The result object from the build and run command.
:param context: The original airflow context in which the build and run command was executed.
Expand All @@ -489,8 +474,15 @@ def _handle_warnings(self, result: FullOutputSubprocessResult, context: Context)
warning_context["test_names"] = test_names
warning_context["test_results"] = test_results

self.on_warning_callback(warning_context)

def execute(self, context: Context) -> None:
result = self.build_and_run_cmd(context=context)

if self.on_warning_callback:
self.on_warning_callback(warning_context)
self._handle_warnings(result, context)

logger.info(result.output)


class DbtRunOperationLocalOperator(DbtLocalBaseOperator):
Expand Down
37 changes: 37 additions & 0 deletions tests/operators/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,40 @@ def test_operator_execute_without_flags(mock_build_and_run_cmd, operator_class):
)
task.execute(context={})
mock_build_and_run_cmd.assert_called_once_with(context={})


@pytest.mark.integration
def test_dbt_test_local_operator_on_warning_callback():
on_warning_callback_mock = MagicMock()

with DAG("my-test-dag", start_date=datetime(2022, 1, 1)) as dag:
test_operator = DbtTestLocalOperator(
profile_config=real_profile_config,
project_dir=DBT_PROJ_DIR,
task_id="my-task",
install_deps=True,
dbt_cmd_flags=["--models", "stg_customers"],
on_warning_callback=on_warning_callback_mock,
)
test_operator

run_test_dag(dag)
on_warning_callback_mock.assert_called_once()


@pytest.mark.integration
def test_on_warning_callback_not_triggered():
on_warning_callback_mock = MagicMock()

with DAG("my-test-dag", start_date=datetime(2022, 1, 1)) as dag:
test_operator = DbtTestLocalOperator(
profile_config=real_profile_config,
project_dir=DBT_PROJ_DIR,
install_deps=True,
dbt_cmd_flags=["--models", "stg_customers"],
task_id="my-task",
)
test_operator

run_test_dag(dag)
on_warning_callback_mock.assert_not_called()
Loading