Commit df3bbc5

Merge branch 'main' into address-pr-1347-review-comments

pankajkoti authored Dec 20, 2024
2 parents 1a1c7d8 + 3b92421
Showing 33 changed files with 2,096 additions and 1,223 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -354,7 +354,7 @@ jobs:
      matrix:
        python-version: ["3.11"]
        airflow-version: ["2.8"]
-        num-models: [1, 10, 50, 100, 500, 1000]
+        num-models: [1, 10, 50, 100, 500]
    services:
      postgres:
        image: postgres
57 changes: 45 additions & 12 deletions cosmos/airflow/graph.py
@@ -132,6 +132,27 @@ def create_test_task_metadata(
)


def _get_task_id_and_args(
    node: DbtNode,
    args: dict[str, Any],
    use_task_group: bool,
    normalize_task_id: Callable[..., Any] | None,
    resource_suffix: str,
) -> tuple[str, dict[str, Any]]:
    """
    Generate task ID and update args with display name if needed.
    """
    args_update = args
    if use_task_group:
        task_id = resource_suffix
    elif normalize_task_id:
        task_id = normalize_task_id(node)
        args_update["task_display_name"] = f"{node.name}_{resource_suffix}"
    else:
        task_id = f"{node.name}_{resource_suffix}"
    return task_id, args_update


def create_dbt_resource_to_class(test_behavior: TestBehavior) -> dict[str, str]:
"""
Return the map from dbt node type to Cosmos class prefix that should be used
@@ -164,7 +185,9 @@ def create_task_metadata(
dbt_dag_task_group_identifier: str,
use_task_group: bool = False,
source_rendering_behavior: SourceRenderingBehavior = SourceRenderingBehavior.NONE,
normalize_task_id: Callable[..., Any] | None = None,
test_behavior: TestBehavior = TestBehavior.AFTER_ALL,
on_warning_callback: Callable[..., Any] | None = None,
) -> TaskMetadata | None:
"""
Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node.
@@ -176,43 +199,48 @@
:param dbt_dag_task_group_identifier: Identifier to refer to the DbtDAG or DbtTaskGroup in the DAG.
:param use_task_group: Determines whether to use the node name as a prefix for the task id.
    If False, the name is used as a prefix; if True, it is not.
:param on_warning_callback: A callback function called on warnings, with the additional context variables
    "test_names" and "test_results" of type List. This parameter is available for the dbt test and dbt source freshness commands.
:returns: The metadata necessary to instantiate the source dbt node as an Airflow task.
"""
dbt_resource_to_class = create_dbt_resource_to_class(test_behavior)

args = {**args, **{"models": node.resource_name}}

if DbtResourceType(node.resource_type) in DEFAULT_DBT_RESOURCES and node.resource_type in dbt_resource_to_class:
-extra_context: dict[str, Any] = {
+extra_context: dict[str, Any] = {
"dbt_node_config": node.context_dict,
"dbt_dag_task_group_identifier": dbt_dag_task_group_identifier,
}

if test_behavior == TestBehavior.BUILD and node.resource_type in SUPPORTED_BUILD_RESOURCES:
task_id = f"{node.name}_{node.resource_type.value}_build"
elif node.resource_type == DbtResourceType.MODEL:
-if use_task_group:
-    task_id = "run"
-else:
-    task_id = f"{node.name}_run"
+task_id, args = _get_task_id_and_args(node, args, use_task_group, normalize_task_id, "run")
elif node.resource_type == DbtResourceType.SOURCE:
args["on_warning_callback"] = on_warning_callback

if (source_rendering_behavior == SourceRenderingBehavior.NONE) or (
source_rendering_behavior == SourceRenderingBehavior.WITH_TESTS_OR_FRESHNESS
and node.has_freshness is False
and node.has_test is False
):
return None
-task_id = f"{node.name}_source"
args["select"] = f"source:{node.resource_name}"
args.pop("models")
-if use_task_group is True:
-    task_id = node.resource_type.value
+task_id, args = _get_task_id_and_args(node, args, use_task_group, normalize_task_id, "source")
if node.has_freshness is False and source_rendering_behavior == SourceRenderingBehavior.ALL:
# render sources without freshness as empty operators
-return TaskMetadata(id=task_id, operator_class="airflow.operators.empty.EmptyOperator")
+# The EmptyOperator does not accept custom parameters (e.g., profile_args), so recreate the args.
+if "task_display_name" in args:
+    args = {"task_display_name": args["task_display_name"]}
+else:
+    args = {}
+return TaskMetadata(id=task_id, operator_class="airflow.operators.empty.EmptyOperator", arguments=args)
else:
-task_id = f"{node.name}_{node.resource_type.value}"
-if use_task_group is True:
-    task_id = node.resource_type.value
+task_id, args = _get_task_id_and_args(
+    node, args, use_task_group, normalize_task_id, node.resource_type.value
+)

task_metadata = TaskMetadata(
id=task_id,
@@ -244,6 +272,7 @@ def generate_task_or_group(
source_rendering_behavior: SourceRenderingBehavior,
test_indirect_selection: TestIndirectSelection,
on_warning_callback: Callable[..., Any] | None,
normalize_task_id: Callable[..., Any] | None = None,
**kwargs: Any,
) -> BaseOperator | TaskGroup | None:
task_or_group: BaseOperator | TaskGroup | None = None
@@ -261,7 +290,9 @@
dbt_dag_task_group_identifier=_get_dbt_dag_task_group_identifier(dag, task_group),
use_task_group=use_task_group,
source_rendering_behavior=source_rendering_behavior,
normalize_task_id=normalize_task_id,
test_behavior=test_behavior,
on_warning_callback=on_warning_callback,
)

# In most cases, we'll map one DBT node to one Airflow task
@@ -364,6 +395,7 @@ def build_airflow_graph(
node_converters = render_config.node_converters or {}
test_behavior = render_config.test_behavior
source_rendering_behavior = render_config.source_rendering_behavior
normalize_task_id = render_config.normalize_task_id
tasks_map: dict[str, Union[TaskGroup, BaseOperator]] = {}
task_or_group: TaskGroup | BaseOperator

@@ -385,6 +417,7 @@
source_rendering_behavior=source_rendering_behavior,
test_indirect_selection=test_indirect_selection,
on_warning_callback=on_warning_callback,
normalize_task_id=normalize_task_id,
node=node,
)
if task_or_group is not None:
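
The helper's three naming paths, exercised in a minimal sketch (not part of this commit). DbtNode is stubbed with a dataclass carrying only the name attribute the helper touches; the stub and values are hypothetical:

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class StubNode:
    name: str  # hypothetical stand-in for cosmos' DbtNode; only `name` is used here

def _get_task_id_and_args(
    node: StubNode,
    args: dict[str, Any],
    use_task_group: bool,
    normalize_task_id: Callable[..., Any] | None,
    resource_suffix: str,
) -> tuple[str, dict[str, Any]]:
    # Same logic as the helper added to cosmos/airflow/graph.py above.
    args_update = args
    if use_task_group:
        task_id = resource_suffix
    elif normalize_task_id:
        task_id = normalize_task_id(node)
        args_update["task_display_name"] = f"{node.name}_{resource_suffix}"
    else:
        task_id = f"{node.name}_{resource_suffix}"
    return task_id, args_update

node = StubNode(name="stg_orders")

# Inside a task group: the resource suffix alone becomes the task id.
print(_get_task_id_and_args(node, {}, True, None, "run"))   # ('run', {})

# With a normalizer: the callable picks the id, and the default
# "<name>_<suffix>" is kept as the Airflow display name.
print(_get_task_id_and_args(node, {}, False, lambda n: f"custom_{n.name}", "run"))
# ('custom_stg_orders', {'task_display_name': 'stg_orders_run'})

# Default: "<name>_<suffix>".
print(_get_task_id_and_args(node, {}, False, None, "run"))  # ('stg_orders_run', {})
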
3 changes: 3 additions & 0 deletions cosmos/config.py
@@ -62,6 +62,8 @@ class RenderConfig:
:param dbt_ls_path: Configures the location of an output of ``dbt ls``. Required when using ``load_method=LoadMode.DBT_LS_FILE``.
:param enable_mock_profile: Allows enabling or disabling the mock profile. Enabled by default. Mock profiles are useful for parsing Cosmos DAGs in CI, but should be disabled to benefit from partial parsing (since Cosmos 1.4).
:param source_rendering_behavior: Determines how source nodes are rendered when using cosmos default source node rendering (ALL, NONE, WITH_TESTS_OR_FRESHNESS). Defaults to "NONE" (since Cosmos 1.6).
:param airflow_vars_to_purge_dbt_ls_cache: Specify Airflow variables that will affect the LoadMode.DBT_LS cache.
:param normalize_task_id: A callable that takes a dbt node as input and returns the task ID. This allows users to assign a custom node ID separate from the display name.
"""

emit_datasets: bool = True
@@ -80,6 +82,7 @@
enable_mock_profile: bool = True
source_rendering_behavior: SourceRenderingBehavior = SourceRenderingBehavior.NONE
airflow_vars_to_purge_dbt_ls_cache: list[str] = field(default_factory=list)
normalize_task_id: Callable[..., Any] | None = None

def __post_init__(self, dbt_project_path: str | Path | None) -> None:
if self.env_vars:
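
From the user side, opting in only requires the new RenderConfig field. A hedged usage sketch (the callable receives the DbtNode; per the graph.py diff above, node.name is what cosmos uses for default ids, and the "<name>_<suffix>" display name is preserved automatically):

from cosmos import RenderConfig

def normalize(node) -> str:
    # Hypothetical policy: strip a team prefix from dbt node names when building task ids.
    return node.name.removeprefix("finance__")

render_config = RenderConfig(
    normalize_task_id=normalize,
)
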
4 changes: 2 additions & 2 deletions cosmos/converter.py
@@ -229,8 +229,8 @@ def __init__(

validate_changed_config_paths(execution_config, project_config, render_config)

-env_vars = copy.deepcopy(project_config.env_vars or operator_args.get("env"))
-dbt_vars = copy.deepcopy(project_config.dbt_vars or operator_args.get("vars"))
+env_vars = project_config.env_vars or operator_args.get("env")
+dbt_vars = project_config.dbt_vars or operator_args.get("vars")

if execution_config.execution_mode != ExecutionMode.VIRTUALENV and execution_config.virtualenv_dir is not None:
logger.warning(
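
Dropping the copy.deepcopy calls means env_vars and dbt_vars now hold references to the caller's objects rather than copies. One practical consequence, shown with a hypothetical handle-like object (a sketch, not from the commit): values that refuse deep copying no longer break the converter.

import copy

class Handle:
    # Hypothetical stand-in for a value (e.g. a client object) that cannot be deep-copied.
    def __deepcopy__(self, memo):
        raise TypeError("Handle cannot be deep-copied")

env = {"HANDLE": Handle()}

env_vars = env or None   # new behavior: plain reference, always safe
try:
    copy.deepcopy(env)   # old behavior: raises for objects like Handle
except TypeError as exc:
    print(f"deepcopy failed: {exc}")
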
2 changes: 1 addition & 1 deletion cosmos/dbt/graph.py
@@ -180,7 +180,7 @@ def run_command(command: list[str], tmp_dir: Path, env_vars: dict[str, str]) ->
)

if returncode or "Error" in stdout.replace("WarnErrorOptions", ""):
-details = stderr or stdout
+details = f"stderr: {stderr}\nstdout: {stdout}"
raise CosmosLoadDbtException(f"Unable to run {command} due to the error:\n{details}")

return stdout
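
The old `details = stderr or stdout` discarded stdout whenever stderr was non-empty, hiding dbt errors that are printed to stdout. The new format surfaces both streams. Roughly how the raised message now reads, with illustrative output (a sketch, not actual dbt logs):

command = ["dbt", "ls"]
stderr = "some unrelated warning"
stdout = "Compilation Error in model stg_orders ..."

details = f"stderr: {stderr}\nstdout: {stdout}"
print(f"Unable to run {command} due to the error:\n{details}")
# Unable to run ['dbt', 'ls'] due to the error:
# stderr: some unrelated warning
# stdout: Compilation Error in model stg_orders ...
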
17 changes: 17 additions & 0 deletions cosmos/dbt/parser/output.py
@@ -11,6 +11,7 @@

DBT_NO_TESTS_MSG = "Nothing to do"
DBT_WARN_MSG = "WARN"
DBT_FRESHNESS_WARN_MSG = "WARN freshness of"


def parse_number_of_warnings_subprocess(result: FullOutputSubprocessResult) -> int:
@@ -50,6 +51,22 @@ def parse_number_of_warnings_dbt_runner(result: dbtRunnerResult) -> int:
return num


def extract_freshness_warn_msg(result: FullOutputSubprocessResult) -> Tuple[List[str], List[str]]:
    """Extract the node names and corresponding log lines for every source freshness WARN in the output."""
    log_list = result.full_output

    node_names = []
    node_results = []

    for line in log_list:
        if DBT_FRESHNESS_WARN_MSG in line:
            node_name = line.split(DBT_FRESHNESS_WARN_MSG)[1].split(" ")[1]
            node_names.append(node_name)
            node_results.append(line)

    return node_names, node_results


def extract_log_issues(log_list: List[str]) -> Tuple[List[str], List[str]]:
"""
Extracts warning messages from the log list and returns the test names and messages as two lists.
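
The split chain relies on the node name being the first token after the "WARN freshness of" marker: index [1] of the second split skips the empty string produced by the leading space. A self-contained check with a representative log line (the exact dbt log wording here is illustrative):

from typing import List, Tuple

DBT_FRESHNESS_WARN_MSG = "WARN freshness of"

def extract_freshness_warn_msg(log_list: List[str]) -> Tuple[List[str], List[str]]:
    # Same parsing as the new helper, fed raw log lines for the example.
    node_names: List[str] = []
    node_results: List[str] = []
    for line in log_list:
        if DBT_FRESHNESS_WARN_MSG in line:
            node_names.append(line.split(DBT_FRESHNESS_WARN_MSG)[1].split(" ")[1])
            node_results.append(line)
    return node_names, node_results

logs = ["1 of 1 WARN freshness of source.jaffle_shop.raw_orders ... [WARN in 0.03s]"]
names, results = extract_freshness_warn_msg(logs)
print(names)  # ['source.jaffle_shop.raw_orders']
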
31 changes: 30 additions & 1 deletion cosmos/operators/local.py
@@ -57,6 +57,7 @@
)
from cosmos.dbt.parser.output import (
extract_dbt_runner_issues,
extract_freshness_warn_msg,
extract_log_issues,
parse_number_of_warnings_dbt_runner,
parse_number_of_warnings_subprocess,
@@ -706,8 +707,36 @@ class DbtSourceLocalOperator(DbtSourceMixin, DbtLocalBaseOperator):
Executes a dbt source freshness command.
"""

-def __init__(self, *args: Any, **kwargs: Any) -> None:
+def __init__(self, *args: Any, on_warning_callback: Callable[..., Any] | None = None, **kwargs: Any) -> None:
     super().__init__(*args, **kwargs)
+    self.on_warning_callback = on_warning_callback
+    self.extract_issues: Callable[..., tuple[list[str], list[str]]]
+
+def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult, context: Context) -> None:
+    """
+    Handle warnings by extracting log issues, creating additional context, and calling the
+    on_warning_callback with the updated context.
+
+    :param result: The result object from the build and run command.
+    :param context: The original Airflow context in which the build and run command was executed.
+    """
+    if self.invocation_mode == InvocationMode.SUBPROCESS:
+        self.extract_issues = extract_freshness_warn_msg
+    elif self.invocation_mode == InvocationMode.DBT_RUNNER:
+        self.extract_issues = extract_dbt_runner_issues
+
+    test_names, test_results = self.extract_issues(result)
+
+    warning_context = dict(context)
+    warning_context["test_names"] = test_names
+    warning_context["test_results"] = test_results
+
+    if self.on_warning_callback:
+        self.on_warning_callback(warning_context)
+
+def execute(self, context: Context) -> None:
+    result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags())
+    if self.on_warning_callback:
+        self._handle_warnings(result, context)


class DbtRunLocalOperator(DbtRunMixin, DbtLocalBaseOperator):
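
Wiring the new hook up: a hedged usage sketch (paths are illustrative and profile_config is assumed to be defined elsewhere). The callback receives the Airflow context extended with the two extra keys:

from cosmos.operators.local import DbtSourceLocalOperator

def alert_on_stale_sources(context: dict) -> None:
    names = context["test_names"]      # source nodes that emitted a freshness WARN
    results = context["test_results"]  # the matching raw warning lines
    print(f"{len(names)} stale source(s): {', '.join(names)}")

check_freshness = DbtSourceLocalOperator(
    task_id="dbt_source_freshness",
    project_dir="/path/to/dbt/project",  # illustrative
    profile_config=profile_config,       # assumed defined elsewhere
    on_warning_callback=alert_on_stale_sources,
)
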
3 changes: 3 additions & 0 deletions cosmos/profiles/__init__.py
@@ -13,6 +13,7 @@
from .databricks.oauth import DatabricksOauthProfileMapping
from .databricks.token import DatabricksTokenProfileMapping
from .exasol.user_pass import ExasolUserPasswordProfileMapping
from .oracle.user_pass import OracleUserPasswordProfileMapping
from .postgres.user_pass import PostgresUserPasswordProfileMapping
from .redshift.user_pass import RedshiftUserPasswordProfileMapping
from .snowflake.user_encrypted_privatekey_env_variable import SnowflakeEncryptedPrivateKeyPemProfileMapping
@@ -34,6 +35,7 @@
GoogleCloudOauthProfileMapping,
DatabricksTokenProfileMapping,
DatabricksOauthProfileMapping,
OracleUserPasswordProfileMapping,
PostgresUserPasswordProfileMapping,
RedshiftUserPasswordProfileMapping,
SnowflakeUserPasswordProfileMapping,
@@ -77,6 +79,7 @@ def get_automatic_profile_mapping(
"DatabricksTokenProfileMapping",
"DatabricksOauthProfileMapping",
"DbtProfileConfigVars",
"OracleUserPasswordProfileMapping",
"PostgresUserPasswordProfileMapping",
"RedshiftUserPasswordProfileMapping",
"SnowflakeUserPasswordProfileMapping",
5 changes: 5 additions & 0 deletions cosmos/profiles/oracle/__init__.py
@@ -0,0 +1,5 @@
"""Oracle Airflow connection -> dbt profile mappings"""

from .user_pass import OracleUserPasswordProfileMapping

__all__ = ["OracleUserPasswordProfileMapping"]
89 changes: 89 additions & 0 deletions cosmos/profiles/oracle/user_pass.py
@@ -0,0 +1,89 @@
"""Maps Airflow Oracle connections using user + password authentication to dbt profiles."""

from __future__ import annotations

import re
from typing import Any

from ..base import BaseProfileMapping


class OracleUserPasswordProfileMapping(BaseProfileMapping):
"""
Maps Airflow Oracle connections using user + password authentication to dbt profiles.
https://docs.getdbt.com/reference/warehouse-setups/oracle-setup
https://airflow.apache.org/docs/apache-airflow-providers-oracle/stable/connections/oracle.html
"""

airflow_connection_type: str = "oracle"
dbt_profile_type: str = "oracle"
is_community: bool = True

required_fields = [
"user",
"password",
]
secret_fields = [
"password",
]
airflow_param_mapping = {
"host": "host",
"port": "port",
"service": "extra.service_name",
"user": "login",
"password": "password",
"database": "extra.service_name",
"connection_string": "extra.dsn",
}

@property
def env_vars(self) -> dict[str, str]:
"""Set oracle thick mode."""
env_vars = super().env_vars
if self._get_airflow_conn_field("extra.thick_mode"):
env_vars["ORA_PYTHON_DRIVER_TYPE"] = "thick"
return env_vars

@property
def profile(self) -> dict[str, Any | None]:
"""Gets profile. The password is stored in an environment variable."""
profile = {
"protocol": "tcp",
"port": 1521,
**self.mapped_params,
**self.profile_args,
# password should always get set as env var
"password": self.get_env_var_format("password"),
}

if "schema" not in profile and "user" in profile:
proxy = re.search(r"\[([^]]+)\]", profile["user"])
if proxy:
profile["schema"] = proxy.group(1)
else:
profile["schema"] = profile["user"]
if "schema" in self.profile_args:
profile["schema"] = self.profile_args["schema"]

return self.filter_null(profile)

@property
def mock_profile(self) -> dict[str, Any | None]:
"""Gets mock profile. Defaults port to 1521."""
profile_dict = {
"protocol": "tcp",
"port": 1521,
**super().mock_profile,
}

if "schema" not in profile_dict and "user" in profile_dict:
proxy = re.search(r"\[([^]]+)\]", profile_dict["user"])
if proxy:
profile_dict["schema"] = proxy.group(1)
else:
profile_dict["schema"] = profile_dict["user"]

user_defined_schema = self.profile_args.get("schema")
if user_defined_schema:
profile_dict["schema"] = user_defined_schema
return profile_dict
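
End to end, the mapping turns an oracle Airflow connection into a dbt profile: protocol and port defaults, the service name from the connection extras, the password deferred to an environment variable, and the schema inferred from proxy syntax in the user. A hedged sketch of the expected translation (connection values are illustrative; the exact env-var reference follows cosmos' standard secret handling and is an assumption here):

# Airflow connection (illustrative):
#   conn_type: oracle
#   host:      oracle.example.com
#   login:     analytics[reporting]          <- proxy syntax "user[schema]"
#   password:  s3cret
#   extra:     {"service_name": "ORCLPDB1"}
#
# dbt profile the mapping is expected to produce (abridged):
expected_profile = {
    "protocol": "tcp",        # default added by the mapping
    "port": 1521,             # default when the connection does not set one
    "host": "oracle.example.com",
    "service": "ORCLPDB1",    # from extra.service_name
    "user": "analytics[reporting]",
    "password": "{{ env_var('COSMOS_CONN_ORACLE_PASSWORD') }}",  # assumed env-var name
    "schema": "reporting",    # bracketed proxy name wins over the bare user
}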