diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 4888583bb..36c676017 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -40,9 +40,7 @@ FullOutputSubprocessHook, FullOutputSubprocessResult, ) -from cosmos.dbt.parser.output import ( - extract_log_issues, -) +from cosmos.dbt.parser.output import extract_log_issues, parse_output logger = get_logger(__name__) @@ -350,11 +348,12 @@ def get_openlineage_facets_on_complete(self, task_instance: TaskInstance) -> Ope job_facets=job_facets, ) - def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> None: + def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> FullOutputSubprocessResult: dbt_cmd, env = self.build_cmd(context=context, cmd_flags=cmd_flags) dbt_cmd = dbt_cmd or [] result = self.run_command(cmd=dbt_cmd, env=env, context=context) logger.info(result.output) + return result def execute(self, context: Context) -> None: self.build_and_run_cmd(context=context) @@ -472,7 +471,6 @@ def _should_run_tests( :param result: The output from the build and run command. """ - return self.on_warning_callback is not None and no_tests_message not in result.output def _handle_warnings(self, result: FullOutputSubprocessResult, context: Context) -> None: @@ -492,6 +490,13 @@ def _handle_warnings(self, result: FullOutputSubprocessResult, context: Context) if self.on_warning_callback: self.on_warning_callback(warning_context) + def execute(self, context: Context) -> None: + result = self.build_and_run_cmd(context=context) + if self._should_run_tests(result): + warnings = parse_output(result, "WARN") + if warnings > 0: + self._handle_warnings(result, context) + class DbtRunOperationLocalOperator(DbtLocalBaseOperator): """ diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index 94b0f8e27..223305fbe 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -1,3 +1,6 @@ +import os +import shutil +import tempfile from pathlib import Path from unittest.mock import MagicMock, patch @@ -28,6 +31,8 @@ DBT_PROJ_DIR = Path(__file__).parent.parent.parent / "dev/dags/dbt/jaffle_shop" +SCHEMA_FAILING_TEST = Path(__file__).parent.parent / "sample/schema_failing_test.yml" + profile_config = ProfileConfig( profile_name="default", @@ -45,6 +50,18 @@ ) +@pytest.fixture +def failing_test_dbt_project(tmp_path): + tmp_dir = tempfile.TemporaryDirectory() + tmp_dir_path = Path(tmp_dir.name) / "jaffle_shop" + shutil.copytree(DBT_PROJ_DIR, tmp_dir_path) + target_schema = tmp_dir_path / "models/schema.yml" + os.remove(target_schema) + shutil.copy(SCHEMA_FAILING_TEST, target_schema) + yield tmp_dir_path + tmp_dir.cleanup() + + def test_dbt_base_operator_add_global_flags() -> None: dbt_base_operator = DbtLocalBaseOperator( profile_config=profile_config, @@ -175,7 +192,7 @@ def test_run_operator_dataset_inlets_and_outlets(): dbt_cmd_flags=["--models", "stg_customers"], install_deps=True, ) - run_operator + run_operator >> test_operator run_test_dag(dag) assert run_operator.inlets == [] assert run_operator.outlets == [Dataset(uri="postgres://0.0.0.0:5432/postgres.public.stg_customers", extra=None)] @@ -183,6 +200,31 @@ def test_run_operator_dataset_inlets_and_outlets(): assert test_operator.outlets == [] +@pytest.mark.integration +def test_run_test_operator_with_callback(failing_test_dbt_project): + on_warning_callback = MagicMock() + + with DAG("test-id-2", start_date=datetime(2022, 1, 1)) as dag: + run_operator = DbtRunLocalOperator( + profile_config=real_profile_config, + project_dir=failing_test_dbt_project, + task_id="run", + dbt_cmd_flags=["--models", "orders"], + install_deps=True, + ) + test_operator = DbtTestLocalOperator( + profile_config=real_profile_config, + project_dir=failing_test_dbt_project, + task_id="test", + dbt_cmd_flags=["--models", "orders"], + install_deps=True, + on_warning_callback=on_warning_callback, + ) + run_operator >> test_operator + run_test_dag(dag) + assert on_warning_callback.called + + @pytest.mark.integration def test_run_operator_emits_events(): class MockRun: diff --git a/tests/sample/schema_failing_test.yml b/tests/sample/schema_failing_test.yml new file mode 100644 index 000000000..c75df8152 --- /dev/null +++ b/tests/sample/schema_failing_test.yml @@ -0,0 +1,83 @@ +version: 2 + +models: + - name: customers + description: This table has basic information about a customer, as well as some derived facts based on a customer's orders + + columns: + - name: customer_id + description: This is a unique identifier for a customer + tests: + - unique + - not_null + + - name: first_name + description: Customer's first name. PII. + + - name: last_name + description: Customer's last name. PII. + + - name: first_order + description: Date (UTC) of a customer's first order + + - name: most_recent_order + description: Date (UTC) of a customer's most recent order + + - name: number_of_orders + description: Count of the number of orders a customer has placed + + - name: total_order_amount + description: Total value (AUD) of a customer's orders + + - name: orders + description: This table has basic information about orders, as well as some derived facts based on payments + + columns: + - name: order_id + tests: + - unique + - not_null + description: This is a unique identifier for an order + + - name: customer_id + description: Foreign key to the customers table + tests: + - not_null + - relationships: + to: ref('customers') + field: customer_id + + - name: order_date + description: Date (UTC) that the order was placed + + - name: status + description: '{{ doc("orders_status") }}' + tests: + - accepted_values: + # this test will fail, since this column has more values + values: ['placed'] + + - name: amount + description: Total amount (AUD) of the order + tests: + - not_null + + - name: credit_card_amount + description: Amount of the order (AUD) paid for by credit card + tests: + - not_null + + - name: coupon_amount + description: Amount of the order (AUD) paid for by coupon + tests: + - not_null + + - name: bank_transfer_amount + description: Amount of the order (AUD) paid for by bank transfer + tests: + - not_null + + - name: gift_card_amount + description: Amount of the order (AUD) paid for by gift card + tests: + - not_null