Skip to content

Commit

Permalink
support env vars for render config for dbt ls parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
jbandoro committed Nov 18, 2023
1 parent e23a445 commit 8980bf4
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 7 deletions.
2 changes: 2 additions & 0 deletions cosmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class RenderConfig:
:param dbt_deps: Configure to run dbt deps when using dbt ls for dag parsing
:param node_converters: a dictionary mapping a ``DbtResourceType`` into a callable. Users can control how to render dbt nodes in Airflow. Only supported when using ``load_method=LoadMode.DBT_MANIFEST`` or ``LoadMode.DBT_LS``.
:param dbt_executable_path: The path to the dbt executable for dag generation. Defaults to dbt if available on the path. Mutually Exclusive with ProjectConfig.dbt_project_path
:param env_vars: A dictionary of environment variables for rendering. Only supported when using ``LoadMode.DBT_LS``.
:param dbt_project_path Configures the DBT project location accessible on the airflow controller for DAG rendering - Required when using ``load_method=LoadMode.DBT_LS`` or ``load_method=LoadMode.CUSTOM``
"""

Expand All @@ -53,6 +54,7 @@ class RenderConfig:
dbt_deps: bool = True
node_converters: dict[DbtResourceType, Callable[..., Any]] | None = None
dbt_executable_path: str | Path = get_system_dbt()
env_vars: dict[str, str] = field(default_factory=dict)
dbt_project_path: InitVar[str | Path | None] = None

project_path: Path | None = field(init=False)
Expand Down
6 changes: 4 additions & 2 deletions cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
LoadMode,
)
from cosmos.dbt.parser.project import LegacyDbtProject
from cosmos.dbt.project import create_symlinks
from cosmos.dbt.project import create_symlinks, environ
from cosmos.dbt.selector import select_nodes
from cosmos.log import get_logger

Expand Down Expand Up @@ -234,7 +234,9 @@ def load_via_dbt_ls(self) -> None:
tmpdir_path = Path(tmpdir)
create_symlinks(self.render_config.project_path, tmpdir_path)

with self.profile_config.ensure_profile(use_mock_values=True) as profile_values:
with self.profile_config.ensure_profile(use_mock_values=True) as profile_values, environ(
self.render_config.env_vars
):
(profile_path, env_vars) = profile_values
env = os.environ.copy()
env.update(env_vars)
Expand Down
19 changes: 19 additions & 0 deletions cosmos/dbt/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
DBT_LOG_DIR_NAME,
DBT_TARGET_DIR_NAME,
)
from contextlib import contextmanager
from typing import Generator


def create_symlinks(project_path: Path, tmp_dir: Path) -> None:
Expand All @@ -12,3 +14,20 @@ def create_symlinks(project_path: Path, tmp_dir: Path) -> None:
for child_name in os.listdir(project_path):
if child_name not in ignore_paths:
os.symlink(project_path / child_name, tmp_dir / child_name)


@contextmanager
def environ(env_vars: dict[str, str]) -> Generator[None, None, None]:
"""Temporarily set environment variables inside the context manager and restore
when exiting.
"""
original_env = {key: os.getenv(key) for key in env_vars}
os.environ.update(env_vars)
try:
yield
finally:
for key, value in original_env.items():
if value is None:
del os.environ[key]
else:
os.environ[key] = value
7 changes: 4 additions & 3 deletions dev/dags/example_cosmos_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt"
DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH))

os.environ["DBT_SQLITE_PATH"] = str(DEFAULT_DBT_ROOT_PATH / "data")
DBT_SQLITE_PATH = str(DEFAULT_DBT_ROOT_PATH / "data")


profile_config = ProfileConfig(
Expand Down Expand Up @@ -62,7 +62,8 @@ def convert_exposure(dag: DAG, task_group: TaskGroup, node: DbtNode, **kwargs):
node_converters={
DbtResourceType("source"): convert_source, # known dbt node type to Cosmos (part of DbtResourceType)
DbtResourceType("exposure"): convert_exposure, # dbt node type new to Cosmos (will be added to DbtResourceType)
}
},
env_vars={"DBT_SQLITE_PATH": DBT_SQLITE_PATH},
)


Expand All @@ -73,7 +74,7 @@ def convert_exposure(dag: DAG, task_group: TaskGroup, node: DbtNode, **kwargs):
),
profile_config=profile_config,
render_config=render_config,
operator_args={"append_env": True},
operator_args={"env": {"DBT_SQLITE_PATH": DBT_SQLITE_PATH}},
# normal dag parameters
schedule_interval="@daily",
start_date=datetime(2023, 1, 1),
Expand Down
6 changes: 5 additions & 1 deletion tests/dbt/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,11 @@ def test_load_via_dbt_ls_with_sources(load_method):
dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name,
manifest_path=SAMPLE_MANIFEST_SOURCE if load_method == "load_from_dbt_manifest" else None,
),
render_config=RenderConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name, dbt_deps=False),
render_config=RenderConfig(
dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name,
dbt_deps=False,
env_vars={"DBT_SQLITE_PATH": str(DBT_PROJECTS_ROOT_DIR / "data")},
),
execution_config=ExecutionConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name),
profile_config=ProfileConfig(
profile_name="simple",
Expand Down
22 changes: 21 additions & 1 deletion tests/dbt/test_project.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from pathlib import Path
from cosmos.dbt.project import create_symlinks
from cosmos.dbt.project import create_symlinks, environ
import os
from unittest.mock import patch

DBT_PROJECTS_ROOT_DIR = Path(__file__).parent.parent.parent / "dev/dags/dbt"

Expand All @@ -13,3 +15,21 @@ def test_create_symlinks(tmp_path):
for child in tmp_dir.iterdir():
assert child.is_symlink()
assert child.name not in ("logs", "target", "profiles.yml", "dbt_packages")


@patch.dict(os.environ, {"ORIGINAL_VAR": "value"})
def test_environ_context_manager():
# Define the expected environment variables
expected_env_vars = {"DBT_PROJECTS_ROOT_DIR": "/path/to/dbt/projects", "DBT_LOG_LEVEL": "debug"}
# Use the environ context manager
with environ(expected_env_vars):
# Check if the environment variables are set correctly
for key, value in expected_env_vars.items():
assert value == os.environ.get(key)
# Check if the original environment variables are still set
assert "value" == os.environ.get("ORIGINAL_VAR")
# Check if the environment variables are unset after exiting the context manager
for key in expected_env_vars.keys():
assert os.environ.get(key) is None
# Check if the original environment variables are still set
assert "value" == os.environ.get("ORIGINAL_VAR")

0 comments on commit 8980bf4

Please sign in to comment.