Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add warning callback on source freshness #1400

Merged
merged 5 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def create_task_metadata(
use_task_group: bool = False,
source_rendering_behavior: SourceRenderingBehavior = SourceRenderingBehavior.NONE,
test_behavior: TestBehavior = TestBehavior.AFTER_ALL,
on_warning_callback: Callable[..., Any] | None = None,
) -> TaskMetadata | None:
"""
Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node.
Expand All @@ -176,17 +177,20 @@ def create_task_metadata(
:param dbt_dag_task_group_identifier: Identifier to refer to the DbtDAG or DbtTaskGroup in the DAG.
:param use_task_group: It determines whether to use the name as a prefix for the task id or not.
If it is False, then use the name as a prefix for the task id, otherwise do not.
:param on_warning_callback: A callback function called on warnings with additional Context variables “test_names”
and “test_results” of type List. This is param available for dbt test and dbt source freshness command.
:returns: The metadata necessary to instantiate the source dbt node as an Airflow task.
"""
dbt_resource_to_class = create_dbt_resource_to_class(test_behavior)

args = {**args, **{"models": node.resource_name}}

if DbtResourceType(node.resource_type) in DEFAULT_DBT_RESOURCES and node.resource_type in dbt_resource_to_class:
extra_context = {
extra_context: dict[str, Any] = {
"dbt_node_config": node.context_dict,
"dbt_dag_task_group_identifier": dbt_dag_task_group_identifier,
}

if test_behavior == TestBehavior.BUILD and node.resource_type in SUPPORTED_BUILD_RESOURCES:
task_id = f"{node.name}_{node.resource_type.value}_build"
elif node.resource_type == DbtResourceType.MODEL:
Expand All @@ -195,6 +199,8 @@ def create_task_metadata(
else:
task_id = f"{node.name}_run"
elif node.resource_type == DbtResourceType.SOURCE:
args["on_warning_callback"] = on_warning_callback

if (source_rendering_behavior == SourceRenderingBehavior.NONE) or (
source_rendering_behavior == SourceRenderingBehavior.WITH_TESTS_OR_FRESHNESS
and node.has_freshness is False
Expand Down Expand Up @@ -262,6 +268,7 @@ def generate_task_or_group(
use_task_group=use_task_group,
source_rendering_behavior=source_rendering_behavior,
test_behavior=test_behavior,
on_warning_callback=on_warning_callback,
)

# In most cases, we'll map one DBT node to one Airflow task
Expand Down
17 changes: 17 additions & 0 deletions cosmos/dbt/parser/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

DBT_NO_TESTS_MSG = "Nothing to do"
DBT_WARN_MSG = "WARN"
DBT_FRESHNESS_WARN_MSG = "WARN freshness of"


def parse_number_of_warnings_subprocess(result: FullOutputSubprocessResult) -> int:
Expand Down Expand Up @@ -50,6 +51,22 @@ def parse_number_of_warnings_dbt_runner(result: dbtRunnerResult) -> int:
return num


def extract_freshness_warn_msg(result: FullOutputSubprocessResult) -> Tuple[List[str], List[str]]:
log_list = result.full_output

node_names = []
node_results = []

for line in log_list:

if DBT_FRESHNESS_WARN_MSG in line:
node_name = line.split(DBT_FRESHNESS_WARN_MSG)[1].split(" ")[1]
node_names.append(node_name)
node_results.append(line)

return node_names, node_results


def extract_log_issues(log_list: List[str]) -> Tuple[List[str], List[str]]:
"""
Extracts warning messages from the log list and returns them as a formatted string.
Expand Down
31 changes: 30 additions & 1 deletion cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
)
from cosmos.dbt.parser.output import (
extract_dbt_runner_issues,
extract_freshness_warn_msg,
extract_log_issues,
parse_number_of_warnings_dbt_runner,
parse_number_of_warnings_subprocess,
Expand Down Expand Up @@ -706,8 +707,36 @@ class DbtSourceLocalOperator(DbtSourceMixin, DbtLocalBaseOperator):
Executes a dbt source freshness command.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
def __init__(self, *args: Any, on_warning_callback: Callable[..., Any] | None = None, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.on_warning_callback = on_warning_callback
self.extract_issues: Callable[..., tuple[list[str], list[str]]]

def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult, context: Context) -> None:
"""
Handles warnings by extracting log issues, creating additional context, and calling the
on_warning_callback with the updated context.

:param result: The result object from the build and run command.
:param context: The original airflow context in which the build and run command was executed.
"""
if self.invocation_mode == InvocationMode.SUBPROCESS:
self.extract_issues = extract_freshness_warn_msg
elif self.invocation_mode == InvocationMode.DBT_RUNNER:
self.extract_issues = extract_dbt_runner_issues

test_names, test_results = self.extract_issues(result)

warning_context = dict(context)
warning_context["test_names"] = test_names
warning_context["test_results"] = test_results

self.on_warning_callback and self.on_warning_callback(warning_context)

def execute(self, context: Context) -> None:
result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags())
if self.on_warning_callback:
self._handle_warnings(result, context)


class DbtRunLocalOperator(DbtRunMixin, DbtLocalBaseOperator):
Expand Down
4 changes: 4 additions & 0 deletions dev/dags/example_source_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
),
)

# [START cosmos_source_node_example]

source_rendering_dag = DbtDag(
# dbt/cosmos-specific parameters
project_config=ProjectConfig(
Expand All @@ -40,4 +42,6 @@
catchup=False,
dag_id="source_rendering_dag",
default_args={"retries": 2},
on_warning_callback=lambda context: print(context),
)
# [END cosmos_source_node_example]
13 changes: 13 additions & 0 deletions docs/configuration/source-nodes-rendering.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,16 @@ Example:
source_rendering_behavior=SourceRenderingBehavior.WITH_TESTS_OR_FRESHNESS,
)
)


on_warning_callback Callback
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The ``on_warning_callback`` is a callback parameter available on the ``DbtSourceLocalOperator``. This callback is triggered when a warning occurs during the execution of the ``dbt source freshness`` command. The callback accepts the task context, which includes additional parameters: test_names and test_results

Example:

.. literalinclude:: ../../dev/dags/example_source_rendering.py/
:language: python
:start-after: [START cosmos_source_node_example]
:end-before: [END cosmos_source_node_example]
22 changes: 22 additions & 0 deletions tests/dbt/parser/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@

from cosmos.dbt.parser.output import (
extract_dbt_runner_issues,
extract_freshness_warn_msg,
extract_log_issues,
parse_number_of_warnings_dbt_runner,
parse_number_of_warnings_subprocess,
)
from cosmos.hooks.subprocess import FullOutputSubprocessResult


@pytest.mark.parametrize(
Expand Down Expand Up @@ -112,3 +114,23 @@ def test_extract_dbt_runner_issues_with_status_levels():

assert node_names == ["node1", "node2"]
assert node_results == ["An error message", "A failure message"]


def test_extract_freshness_warn_msg():
result = FullOutputSubprocessResult(
full_output=[
"Info: some other log message",
"INFO - 11:50:42 1 of 1 WARN freshness of postgres_db.raw_orders ................................ [WARN in 0.01s]",
"INFO - 11:50:42",
"INFO - 11:50:42 Finished running 1 source in 0 hours 0 minutes and 0.04 seconds (0.04s).",
"INFO - 11:50:42 Done.",
],
output="INFO - 11:50:42 Done.",
exit_code=0,
)
node_names, node_results = extract_freshness_warn_msg(result)

assert node_names == ["postgres_db.raw_orders"]
assert node_results == [
"INFO - 11:50:42 1 of 1 WARN freshness of postgres_db.raw_orders ................................ [WARN in 0.01s]"
]
32 changes: 32 additions & 0 deletions tests/operators/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -1159,6 +1159,38 @@ def test_store_freshness_not_store_compiled_sql(mock_context, mock_session):
assert instance.freshness == ""


@pytest.mark.parametrize(
"invocation_mode, expected_extract_function",
[
(InvocationMode.SUBPROCESS, "extract_freshness_warn_msg"),
(InvocationMode.DBT_RUNNER, "extract_dbt_runner_issues"),
],
)
def test_handle_warnings(invocation_mode, expected_extract_function, mock_context):
result = MagicMock()

instance = DbtSourceLocalOperator(
task_id="test",
profile_config=None,
project_dir="my/dir",
on_warning_callback=lambda context: print(context),
invocation_mode=invocation_mode,
)

with patch(f"cosmos.operators.local.{expected_extract_function}") as mock_extract_issues, patch.object(
instance, "on_warning_callback"
) as mock_on_warning_callback:
mock_extract_issues.return_value = (["test_name1", "test_name2"], ["test_name1", "test_name2"])

instance._handle_warnings(result, mock_context)

mock_extract_issues.assert_called_once_with(result)

mock_on_warning_callback.assert_called_once_with(
{**mock_context, "test_names": ["test_name1", "test_name2"], "test_results": ["test_name1", "test_name2"]}
)


def test_dbt_compile_local_operator_initialisation():
operator = DbtCompileLocalOperator(
task_id="fake-task",
Expand Down
Loading