Skip to content

Commit

Permalink
Fix on_warning_callback issue
Browse files Browse the repository at this point in the history
Since 1.1.0, the on_warning_callback functionality no longer works, it worked on 1.0.5

Closes: #549
  • Loading branch information
tatiana committed Sep 26, 2023
1 parent d5ba070 commit eebf89a
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 6 deletions.
15 changes: 10 additions & 5 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@
FullOutputSubprocessHook,
FullOutputSubprocessResult,
)
from cosmos.dbt.parser.output import (
extract_log_issues,
)
from cosmos.dbt.parser.output import extract_log_issues, parse_output

logger = get_logger(__name__)

Expand Down Expand Up @@ -350,11 +348,12 @@ def get_openlineage_facets_on_complete(self, task_instance: TaskInstance) -> Ope
job_facets=job_facets,
)

def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> None:
def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> FullOutputSubprocessResult:
dbt_cmd, env = self.build_cmd(context=context, cmd_flags=cmd_flags)
dbt_cmd = dbt_cmd or []
result = self.run_command(cmd=dbt_cmd, env=env, context=context)
logger.info(result.output)
return result

def execute(self, context: Context) -> None:
self.build_and_run_cmd(context=context)
Expand Down Expand Up @@ -472,7 +471,6 @@ def _should_run_tests(
:param result: The output from the build and run command.
"""

return self.on_warning_callback is not None and no_tests_message not in result.output

def _handle_warnings(self, result: FullOutputSubprocessResult, context: Context) -> None:
Expand All @@ -492,6 +490,13 @@ def _handle_warnings(self, result: FullOutputSubprocessResult, context: Context)
if self.on_warning_callback:
self.on_warning_callback(warning_context)

def execute(self, context: Context) -> None:
result = self.build_and_run_cmd(context=context)
if self._should_run_tests(result):
warnings = parse_output(result, "WARN")
if warnings > 0:
self._handle_warnings(result, context)


class DbtRunOperationLocalOperator(DbtLocalBaseOperator):
"""
Expand Down
44 changes: 43 additions & 1 deletion tests/operators/test_local.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import os
import shutil
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -28,6 +31,8 @@


DBT_PROJ_DIR = Path(__file__).parent.parent.parent / "dev/dags/dbt/jaffle_shop"
SCHEMA_FAILING_TEST = Path(__file__).parent.parent / "sample/schema_failing_test.yml"


profile_config = ProfileConfig(
profile_name="default",
Expand All @@ -45,6 +50,18 @@
)


@pytest.fixture
def failing_test_dbt_project(tmp_path):
tmp_dir = tempfile.TemporaryDirectory()
tmp_dir_path = Path(tmp_dir.name) / "jaffle_shop"
shutil.copytree(DBT_PROJ_DIR, tmp_dir_path)
target_schema = tmp_dir_path / "models/schema.yml"
os.remove(target_schema)
shutil.copy(SCHEMA_FAILING_TEST, target_schema)
yield tmp_dir_path
tmp_dir.cleanup()


def test_dbt_base_operator_add_global_flags() -> None:
dbt_base_operator = DbtLocalBaseOperator(
profile_config=profile_config,
Expand Down Expand Up @@ -175,14 +192,39 @@ def test_run_operator_dataset_inlets_and_outlets():
dbt_cmd_flags=["--models", "stg_customers"],
install_deps=True,
)
run_operator
run_operator >> test_operator
run_test_dag(dag)
assert run_operator.inlets == []
assert run_operator.outlets == [Dataset(uri="postgres://0.0.0.0:5432/postgres.public.stg_customers", extra=None)]
assert test_operator.inlets == [Dataset(uri="postgres://0.0.0.0:5432/postgres.public.stg_customers", extra=None)]
assert test_operator.outlets == []


@pytest.mark.integration
def test_run_test_operator_with_callback(failing_test_dbt_project):
on_warning_callback = MagicMock()

with DAG("test-id-2", start_date=datetime(2022, 1, 1)) as dag:
run_operator = DbtRunLocalOperator(
profile_config=real_profile_config,
project_dir=failing_test_dbt_project,
task_id="run",
dbt_cmd_flags=["--models", "orders"],
install_deps=True,
)
test_operator = DbtTestLocalOperator(
profile_config=real_profile_config,
project_dir=failing_test_dbt_project,
task_id="test",
dbt_cmd_flags=["--models", "orders"],
install_deps=True,
on_warning_callback=on_warning_callback,
)
run_operator >> test_operator
run_test_dag(dag)
assert on_warning_callback.called


@pytest.mark.integration
def test_run_operator_emits_events():
class MockRun:
Expand Down
83 changes: 83 additions & 0 deletions tests/sample/schema_failing_test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
version: 2

models:
- name: customers
description: This table has basic information about a customer, as well as some derived facts based on a customer's orders

columns:
- name: customer_id
description: This is a unique identifier for a customer
tests:
- unique
- not_null

- name: first_name
description: Customer's first name. PII.

- name: last_name
description: Customer's last name. PII.

- name: first_order
description: Date (UTC) of a customer's first order

- name: most_recent_order
description: Date (UTC) of a customer's most recent order

- name: number_of_orders
description: Count of the number of orders a customer has placed

- name: total_order_amount
description: Total value (AUD) of a customer's orders

- name: orders
description: This table has basic information about orders, as well as some derived facts based on payments

columns:
- name: order_id
tests:
- unique
- not_null
description: This is a unique identifier for an order

- name: customer_id
description: Foreign key to the customers table
tests:
- not_null
- relationships:
to: ref('customers')
field: customer_id

- name: order_date
description: Date (UTC) that the order was placed

- name: status
description: '{{ doc("orders_status") }}'
tests:
- accepted_values:
# this test will fail, since this column has more values
values: ['placed']

- name: amount
description: Total amount (AUD) of the order
tests:
- not_null

- name: credit_card_amount
description: Amount of the order (AUD) paid for by credit card
tests:
- not_null

- name: coupon_amount
description: Amount of the order (AUD) paid for by coupon
tests:
- not_null

- name: bank_transfer_amount
description: Amount of the order (AUD) paid for by bank transfer
tests:
- not_null

- name: gift_card_amount
description: Amount of the order (AUD) paid for by gift card
tests:
- not_null

0 comments on commit eebf89a

Please sign in to comment.