diff --git a/cosmos/config.py b/cosmos/config.py index a33e96830..262c5e57a 100644 --- a/cosmos/config.py +++ b/cosmos/config.py @@ -42,6 +42,7 @@ class RenderConfig: :param dbt_deps: Configure to run dbt deps when using dbt ls for dag parsing :param node_converters: a dictionary mapping a ``DbtResourceType`` into a callable. Users can control how to render dbt nodes in Airflow. Only supported when using ``load_method=LoadMode.DBT_MANIFEST`` or ``LoadMode.DBT_LS``. :param dbt_executable_path: The path to the dbt executable for dag generation. Defaults to dbt if available on the path. Mutually Exclusive with ProjectConfig.dbt_project_path + :param env_vars: A dictionary of environment variables for rendering. Only supported when using ``LoadMode.DBT_LS``. :param dbt_project_path Configures the DBT project location accessible on the airflow controller for DAG rendering - Required when using ``load_method=LoadMode.DBT_LS`` or ``load_method=LoadMode.CUSTOM`` """ @@ -53,6 +54,7 @@ class RenderConfig: dbt_deps: bool = True node_converters: dict[DbtResourceType, Callable[..., Any]] | None = None dbt_executable_path: str | Path = get_system_dbt() + env_vars: dict[str, str] = field(default_factory=dict) dbt_project_path: InitVar[str | Path | None] = None project_path: Path | None = field(init=False) diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index 40154308b..86c9fc772 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -21,7 +21,7 @@ LoadMode, ) from cosmos.dbt.parser.project import LegacyDbtProject -from cosmos.dbt.project import create_symlinks +from cosmos.dbt.project import create_symlinks, environ from cosmos.dbt.selector import select_nodes from cosmos.log import get_logger @@ -234,7 +234,9 @@ def load_via_dbt_ls(self) -> None: tmpdir_path = Path(tmpdir) create_symlinks(self.render_config.project_path, tmpdir_path) - with self.profile_config.ensure_profile(use_mock_values=True) as profile_values: + with self.profile_config.ensure_profile(use_mock_values=True) as profile_values, environ( + self.render_config.env_vars + ): (profile_path, env_vars) = profile_values env = os.environ.copy() env.update(env_vars) diff --git a/cosmos/dbt/project.py b/cosmos/dbt/project.py index 63f4fc007..bc4df944d 100644 --- a/cosmos/dbt/project.py +++ b/cosmos/dbt/project.py @@ -4,6 +4,8 @@ DBT_LOG_DIR_NAME, DBT_TARGET_DIR_NAME, ) +from contextlib import contextmanager +from typing import Generator def create_symlinks(project_path: Path, tmp_dir: Path) -> None: @@ -12,3 +14,20 @@ def create_symlinks(project_path: Path, tmp_dir: Path) -> None: for child_name in os.listdir(project_path): if child_name not in ignore_paths: os.symlink(project_path / child_name, tmp_dir / child_name) + + +@contextmanager +def environ(env_vars: dict[str, str]) -> Generator[None, None, None]: + """Temporarily set environment variables inside the context manager and restore + when exiting. + """ + original_env = {key: os.getenv(key) for key in env_vars} + os.environ.update(env_vars) + try: + yield + finally: + for key, value in original_env.items(): + if value is None: + del os.environ[key] + else: + os.environ[key] = value diff --git a/dev/dags/example_cosmos_sources.py b/dev/dags/example_cosmos_sources.py index 29c70db5a..157b3adb3 100644 --- a/dev/dags/example_cosmos_sources.py +++ b/dev/dags/example_cosmos_sources.py @@ -26,7 +26,7 @@ DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt" DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) -os.environ["DBT_SQLITE_PATH"] = str(DEFAULT_DBT_ROOT_PATH / "data") +DBT_SQLITE_PATH = str(DEFAULT_DBT_ROOT_PATH / "data") profile_config = ProfileConfig( @@ -62,7 +62,8 @@ def convert_exposure(dag: DAG, task_group: TaskGroup, node: DbtNode, **kwargs): node_converters={ DbtResourceType("source"): convert_source, # known dbt node type to Cosmos (part of DbtResourceType) DbtResourceType("exposure"): convert_exposure, # dbt node type new to Cosmos (will be added to DbtResourceType) - } + }, + env_vars={"DBT_SQLITE_PATH": DBT_SQLITE_PATH}, ) @@ -73,7 +74,7 @@ def convert_exposure(dag: DAG, task_group: TaskGroup, node: DbtNode, **kwargs): ), profile_config=profile_config, render_config=render_config, - operator_args={"append_env": True}, + operator_args={"env": {"DBT_SQLITE_PATH": DBT_SQLITE_PATH}}, # normal dag parameters schedule_interval="@daily", start_date=datetime(2023, 1, 1), diff --git a/tests/dbt/test_graph.py b/tests/dbt/test_graph.py index 3e3218259..0593ba5e4 100644 --- a/tests/dbt/test_graph.py +++ b/tests/dbt/test_graph.py @@ -392,7 +392,11 @@ def test_load_via_dbt_ls_with_sources(load_method): dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name, manifest_path=SAMPLE_MANIFEST_SOURCE if load_method == "load_from_dbt_manifest" else None, ), - render_config=RenderConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name, dbt_deps=False), + render_config=RenderConfig( + dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name, + dbt_deps=False, + env_vars={"DBT_SQLITE_PATH": str(DBT_PROJECTS_ROOT_DIR / "data")}, + ), execution_config=ExecutionConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name), profile_config=ProfileConfig( profile_name="simple", diff --git a/tests/dbt/test_project.py b/tests/dbt/test_project.py index bd2555c98..0fdb2508d 100644 --- a/tests/dbt/test_project.py +++ b/tests/dbt/test_project.py @@ -1,5 +1,7 @@ from pathlib import Path -from cosmos.dbt.project import create_symlinks +from cosmos.dbt.project import create_symlinks, environ +import os +from unittest.mock import patch DBT_PROJECTS_ROOT_DIR = Path(__file__).parent.parent.parent / "dev/dags/dbt" @@ -13,3 +15,21 @@ def test_create_symlinks(tmp_path): for child in tmp_dir.iterdir(): assert child.is_symlink() assert child.name not in ("logs", "target", "profiles.yml", "dbt_packages") + + +@patch.dict(os.environ, {"ORIGINAL_VAR": "value"}) +def test_environ_context_manager(): + # Define the expected environment variables + expected_env_vars = {"DBT_PROJECTS_ROOT_DIR": "/path/to/dbt/projects", "DBT_LOG_LEVEL": "debug"} + # Use the environ context manager + with environ(expected_env_vars): + # Check if the environment variables are set correctly + for key, value in expected_env_vars.items(): + assert value == os.environ.get(key) + # Check if the original environment variables are still set + assert "value" == os.environ.get("ORIGINAL_VAR") + # Check if the environment variables are unset after exiting the context manager + for key in expected_env_vars.keys(): + assert os.environ.get(key) is None + # Check if the original environment variables are still set + assert "value" == os.environ.get("ORIGINAL_VAR")