Skip to content

Commit

Permalink
Add warning callback on source freshness
Browse files Browse the repository at this point in the history
  • Loading branch information
pankajastro committed Dec 18, 2024
1 parent 9820935 commit f658f4a
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 2 deletions.
9 changes: 8 additions & 1 deletion cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def create_task_metadata(
use_task_group: bool = False,
source_rendering_behavior: SourceRenderingBehavior = SourceRenderingBehavior.NONE,
test_behavior: TestBehavior = TestBehavior.AFTER_ALL,
on_warning_callback: Callable[..., Any] | None = None,
) -> TaskMetadata | None:
"""
Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node.
Expand All @@ -176,17 +177,19 @@ def create_task_metadata(
:param dbt_dag_task_group_identifier: Identifier to refer to the DbtDAG or DbtTaskGroup in the DAG.
:param use_task_group: It determines whether to use the name as a prefix for the task id or not.
If it is False, then use the name as a prefix for the task id, otherwise do not.
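:param on_warning_callback: Optional callable stored in the task's ``extra_context`` for source nodes, so that the
source freshness operator can invoke it when freshness warnings are detected.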
:returns: The metadata necessary to instantiate the source dbt node as an Airflow task.
"""
dbt_resource_to_class = create_dbt_resource_to_class(test_behavior)

args = {**args, **{"models": node.resource_name}}

if DbtResourceType(node.resource_type) in DEFAULT_DBT_RESOURCES and node.resource_type in dbt_resource_to_class:
extra_context = {
extra_context: dict[str, Any] = {
"dbt_node_config": node.context_dict,
"dbt_dag_task_group_identifier": dbt_dag_task_group_identifier,
}

if test_behavior == TestBehavior.BUILD and node.resource_type in SUPPORTED_BUILD_RESOURCES:
task_id = f"{node.name}_{node.resource_type.value}_build"
elif node.resource_type == DbtResourceType.MODEL:
Expand All @@ -195,6 +198,9 @@ def create_task_metadata(
else:
task_id = f"{node.name}_run"
elif node.resource_type == DbtResourceType.SOURCE:
# if on_warning_callback is not None:
extra_context["on_warning_callback"] = on_warning_callback

if (source_rendering_behavior == SourceRenderingBehavior.NONE) or (
source_rendering_behavior == SourceRenderingBehavior.WITH_TESTS_OR_FRESHNESS
and node.has_freshness is False
Expand Down Expand Up @@ -262,6 +268,7 @@ def generate_task_or_group(
use_task_group=use_task_group,
source_rendering_behavior=source_rendering_behavior,
test_behavior=test_behavior,
on_warning_callback=on_warning_callback,
)

# In most cases, we'll map one DBT node to one Airflow task
Expand Down
6 changes: 6 additions & 0 deletions cosmos/core/airflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ def get_airflow_task(task: Task, dag: DAG, task_group: TaskGroup | None = None)
for k, v in task.airflow_task_config.items():
task_kwargs[k] = v

# Set the on_warning_callback of source node in task_kwargs
on_warning_callback = task.extra_context.get("on_warning_callback")
if on_warning_callback is not None:
task_kwargs["on_warning_callback"] = on_warning_callback
del task.extra_context["on_warning_callback"]

airflow_task = Operator(
task_id=task.id,
dag=dag,
Expand Down
20 changes: 20 additions & 0 deletions cosmos/dbt/parser/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

DBT_NO_TESTS_MSG = "Nothing to do"
DBT_WARN_MSG = "WARN"
DBT_FRESHNESS_WARN_MSG = "WARN freshness of"


def parse_number_of_warnings_subprocess(result: FullOutputSubprocessResult) -> int:
Expand Down Expand Up @@ -39,6 +40,10 @@ def parse_number_of_warnings_subprocess(result: FullOutputSubprocessResult) -> i
return num


def parse_number_of_freshness_warnings_subprocess(result: FullOutputSubprocessResult) -> int:
    """
    Count the source freshness warnings reported in a dbt subprocess result.

    :param result: Subprocess result whose ``full_output`` is iterated line by line.
    :returns: The number of output lines that contain ``DBT_FRESHNESS_WARN_MSG``.
    """
    warning_count = 0
    for output_line in result.full_output:
        if DBT_FRESHNESS_WARN_MSG in output_line:
            warning_count += 1
    return warning_count


def parse_number_of_warnings_dbt_runner(result: dbtRunnerResult) -> int:
"""Parses a dbt runner result and returns the number of warnings found. This only works for dbtRunnerResult
from invoking dbt build, compile, run, seed, snapshot, test, or run-operation.
Expand All @@ -50,6 +55,21 @@ def parse_number_of_warnings_dbt_runner(result: dbtRunnerResult) -> int:
return num


def extract_freshness_warn_issue(log_list: List[str]) -> Tuple[List[str], List[str]]:
    """
    Collect the source freshness warnings found in a list of dbt log lines.

    A line is treated as a freshness warning when it contains ``DBT_FRESHNESS_WARN_MSG``. The first
    whitespace-delimited token following that marker is reported as the name of the warning.

    :param log_list: The dbt log lines to inspect.
    :returns: A tuple ``(test_names, test_results)`` of equal length, where ``test_names`` holds the token
        found after the marker on each matching line (an empty string when nothing follows the marker) and
        ``test_results`` holds the matching lines themselves, unmodified.
    """
    test_names: List[str] = []
    test_results: List[str] = []

    for line in log_list:
        if DBT_FRESHNESS_WARN_MSG not in line:
            continue

        # Take the text after the first marker and split on any run of whitespace, so that a marker at
        # the very end of the line (or followed by several spaces) cannot raise IndexError or yield an
        # empty token while a real name is present.
        _, _, remainder = line.partition(DBT_FRESHNESS_WARN_MSG)
        tokens = remainder.split()

        # Always record both entries so the two returned lists stay index-aligned.
        test_names.append(tokens[0] if tokens else "")
        test_results.append(line)

    return test_names, test_results


def extract_log_issues(log_list: List[str]) -> Tuple[List[str], List[str]]:
"""
Extracts warning messages from the log list and returns them as a formatted string.
Expand Down
40 changes: 39 additions & 1 deletion cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@
)
from cosmos.dbt.parser.output import (
extract_dbt_runner_issues,
extract_freshness_warn_issue,
extract_log_issues,
parse_number_of_freshness_warnings_subprocess,
parse_number_of_warnings_dbt_runner,
parse_number_of_warnings_subprocess,
)
Expand Down Expand Up @@ -706,8 +708,44 @@ class DbtSourceLocalOperator(DbtSourceMixin, DbtLocalBaseOperator):
Executes a dbt source freshness command.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
def __init__(
    self,
    *args: Any,
    on_warning_callback: Callable[..., Any] | None = None,
    **kwargs: Any,
) -> None:
    """
    Initialise the operator and remember the optional warning callback.

    :param on_warning_callback: Optional callable; ``execute`` hands it a context dictionary when
        freshness warnings are found in the command output.
    """
    super().__init__(*args, **kwargs)
    # Declarations only: these attributes receive no value here. They are assigned by
    # _set_test_result_parsing_methods for the invocation modes it supports.
    self.extract_issues: Callable[..., tuple[list[str], list[str]]]
    self.parse_number_of_warnings: Callable[..., int]
    self.on_warning_callback = on_warning_callback

def _set_test_result_parsing_methods(self) -> None:
    """
    Select the freshness-warning parsing helpers that match the invocation mode.

    Only ``InvocationMode.SUBPROCESS`` is handled; for any other mode ``extract_issues`` and
    ``parse_number_of_warnings`` are left unassigned.
    """
    if self.invocation_mode != InvocationMode.SUBPROCESS:
        # TODO: FIXME
        # elif self.invocation_mode == InvocationMode.DBT_RUNNER:
        #     self.extract_issues = extract_dbt_runner_issues
        #     self.parse_number_of_warnings = parse_number_of_warnings_dbt_runner
        return

    self.parse_number_of_warnings = parse_number_of_freshness_warnings_subprocess
    self.extract_issues = extract_freshness_warn_issue

def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult, context: Context) -> None:
    """
    Build a warning context from the command output and pass it to ``on_warning_callback``.

    The issues are extracted from ``result.full_output`` with ``self.extract_issues``. The callback
    receives a copy of ``context`` extended with the keys ``"test_names"`` and ``"test_results"``;
    the original ``context`` object is not modified.

    :param result: The result object from the build and run command.
    :param context: The original airflow context in which the build and run command was executed.
    """
    names, results = self.extract_issues(result.full_output)

    # Copy first so the extra keys never leak into the caller's context.
    callback_context = dict(context)
    callback_context.update({"test_names": names, "test_results": results})

    if self.on_warning_callback:
        self.on_warning_callback(callback_context)

def execute(self, context: Context) -> None:
    """
    Run the dbt source freshness command and report freshness warnings to ``on_warning_callback``.

    The command output is only inspected when a callback is configured. Warning parsing is only
    available for invocation modes handled by ``_set_test_result_parsing_methods``; for any other
    mode the callback is skipped instead of failing the task after the command has already run.

    :param context: The airflow context in which the command is executed.
    """
    result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags())

    if not self.on_warning_callback:
        # Nothing to notify, so there is no reason to parse the output.
        return

    self._set_test_result_parsing_methods()

    # The parser attribute is only assigned for supported invocation modes; calling it
    # unconditionally would raise AttributeError for the others.
    parse_number_of_warnings = getattr(self, "parse_number_of_warnings", None)
    if parse_number_of_warnings is None:
        return

    if parse_number_of_warnings(result) > 0:
        self._handle_warnings(result, context)


class DbtRunLocalOperator(DbtRunMixin, DbtLocalBaseOperator):
Expand Down

0 comments on commit f658f4a

Please sign in to comment.