diff --git a/cosmos/config.py b/cosmos/config.py index 5c64193c1..40756d2bb 100644 --- a/cosmos/config.py +++ b/cosmos/config.py @@ -42,6 +42,7 @@ class RenderConfig: :param dbt_deps: Configure to run dbt deps when using dbt ls for dag parsing :param node_converters: a dictionary mapping a ``DbtResourceType`` into a callable. Users can control how to render dbt nodes in Airflow. Only supported when using ``load_method=LoadMode.DBT_MANIFEST`` or ``LoadMode.DBT_LS``. :param dbt_executable_path: The path to the dbt executable for dag generation. Defaults to dbt if available on the path. + :param env_vars: A dictionary of environment variables for rendering. Only supported when using ``LoadMode.DBT_LS``. :param dbt_project_path Configures the DBT project location accessible on the airflow controller for DAG rendering. Mutually Exclusive with ProjectConfig.dbt_project_path. Required when using ``load_method=LoadMode.DBT_LS`` or ``load_method=LoadMode.CUSTOM``. """ @@ -53,6 +54,7 @@ class RenderConfig: dbt_deps: bool = True node_converters: dict[DbtResourceType, Callable[..., Any]] | None = None dbt_executable_path: str | Path = get_system_dbt() + env_vars: dict[str, str] = field(default_factory=dict) dbt_project_path: InitVar[str | Path | None] = None project_path: Path | None = field(init=False) diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index 556ebe09a..a890c137c 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -21,7 +21,7 @@ LoadMode, ) from cosmos.dbt.parser.project import LegacyDbtProject -from cosmos.dbt.project import create_symlinks +from cosmos.dbt.project import create_symlinks, environ from cosmos.dbt.selector import select_nodes from cosmos.log import get_logger @@ -234,7 +234,9 @@ def load_via_dbt_ls(self) -> None: tmpdir_path = Path(tmpdir) create_symlinks(self.render_config.project_path, tmpdir_path) - with self.profile_config.ensure_profile(use_mock_values=True) as profile_values: + with self.profile_config.ensure_profile(use_mock_values=True) as profile_values, environ( + self.render_config.env_vars + ): (profile_path, env_vars) = profile_values env = os.environ.copy() env.update(env_vars) diff --git a/cosmos/dbt/project.py b/cosmos/dbt/project.py index 63f4fc007..14b2f5e4b 100644 --- a/cosmos/dbt/project.py +++ b/cosmos/dbt/project.py @@ -1,9 +1,13 @@ +from __future__ import annotations + from pathlib import Path import os from cosmos.constants import ( DBT_LOG_DIR_NAME, DBT_TARGET_DIR_NAME, ) +from contextlib import contextmanager +from typing import Generator def create_symlinks(project_path: Path, tmp_dir: Path) -> None: @@ -12,3 +16,20 @@ def create_symlinks(project_path: Path, tmp_dir: Path) -> None: for child_name in os.listdir(project_path): if child_name not in ignore_paths: os.symlink(project_path / child_name, tmp_dir / child_name) + + +@contextmanager +def environ(env_vars: dict[str, str]) -> Generator[None, None, None]: + """Temporarily set environment variables inside the context manager and restore + when exiting. + """ + original_env = {key: os.getenv(key) for key in env_vars} + os.environ.update(env_vars) + try: + yield + finally: + for key, value in original_env.items(): + if value is None: + del os.environ[key] + else: + os.environ[key] = value diff --git a/dev/dags/example_cosmos_sources.py b/dev/dags/example_cosmos_sources.py index 29c70db5a..157b3adb3 100644 --- a/dev/dags/example_cosmos_sources.py +++ b/dev/dags/example_cosmos_sources.py @@ -26,7 +26,7 @@ DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt" DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) -os.environ["DBT_SQLITE_PATH"] = str(DEFAULT_DBT_ROOT_PATH / "data") +DBT_SQLITE_PATH = str(DEFAULT_DBT_ROOT_PATH / "data") profile_config = ProfileConfig( @@ -62,7 +62,8 @@ def convert_exposure(dag: DAG, task_group: TaskGroup, node: DbtNode, **kwargs): node_converters={ DbtResourceType("source"): convert_source, # known dbt node type to Cosmos (part of DbtResourceType) DbtResourceType("exposure"): convert_exposure, # dbt node type new to Cosmos (will be added to DbtResourceType) - } + }, + env_vars={"DBT_SQLITE_PATH": DBT_SQLITE_PATH}, ) @@ -73,7 +74,7 @@ def convert_exposure(dag: DAG, task_group: TaskGroup, node: DbtNode, **kwargs): ), profile_config=profile_config, render_config=render_config, - operator_args={"append_env": True}, + operator_args={"env": {"DBT_SQLITE_PATH": DBT_SQLITE_PATH}}, # normal dag parameters schedule_interval="@daily", start_date=datetime(2023, 1, 1), diff --git a/docs/configuration/render-config.rst b/docs/configuration/render-config.rst index de0a08cdb..5e1c23824 100644 --- a/docs/configuration/render-config.rst +++ b/docs/configuration/render-config.rst @@ -14,6 +14,7 @@ The ``RenderConfig`` class takes the following arguments: - ``dbt_deps``: A Boolean to run dbt deps when using dbt ls for dag parsing. Default True - ``node_converters``: a dictionary mapping a ``DbtResourceType`` into a callable. Users can control how to render dbt nodes in Airflow. Only supported when using ``load_method=LoadMode.DBT_MANIFEST`` or ``LoadMode.DBT_LS``. Find more information below. - ``dbt_executable_path``: The path to the dbt executable for dag generation. Defaults to dbt if available on the path. +- ``env_vars``: A dictionary of environment variables for rendering. Only supported when using ``load_method=LoadMode.DBT_LS``. - ``dbt_project_path``: Configures the DBT project location accessible on their airflow controller for DAG rendering - Required when using ``load_method=LoadMode.DBT_LS`` or ``load_method=LoadMode.CUSTOM`` Customizing how nodes are rendered (experimental) diff --git a/tests/dbt/test_graph.py b/tests/dbt/test_graph.py index 224aff56e..a424976a1 100644 --- a/tests/dbt/test_graph.py +++ b/tests/dbt/test_graph.py @@ -392,7 +392,11 @@ def test_load_via_dbt_ls_with_sources(load_method): dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name, manifest_path=SAMPLE_MANIFEST_SOURCE if load_method == "load_from_dbt_manifest" else None, ), - render_config=RenderConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name, dbt_deps=False), + render_config=RenderConfig( + dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name, + dbt_deps=False, + env_vars={"DBT_SQLITE_PATH": str(DBT_PROJECTS_ROOT_DIR / "data")}, + ), execution_config=ExecutionConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name), profile_config=ProfileConfig( profile_name="simple", diff --git a/tests/dbt/test_project.py b/tests/dbt/test_project.py index bd2555c98..ec5612904 100644 --- a/tests/dbt/test_project.py +++ b/tests/dbt/test_project.py @@ -1,5 +1,7 @@ from pathlib import Path -from cosmos.dbt.project import create_symlinks +from cosmos.dbt.project import create_symlinks, environ +import os +from unittest.mock import patch DBT_PROJECTS_ROOT_DIR = Path(__file__).parent.parent.parent / "dev/dags/dbt" @@ -13,3 +15,21 @@ def test_create_symlinks(tmp_path): for child in tmp_dir.iterdir(): assert child.is_symlink() assert child.name not in ("logs", "target", "profiles.yml", "dbt_packages") + + +@patch.dict(os.environ, {"VAR1": "value1", "VAR2": "value2"}) +def test_environ_context_manager(): + # Define the expected environment variables + expected_env_vars = {"VAR2": "new_value2", "VAR3": "value3"} + # Use the environ context manager + with environ(expected_env_vars): + # Check if the environment variables are set correctly + for key, value in expected_env_vars.items(): + assert value == os.environ.get(key) + # Check if the original non-overlapping environment variable is still set + assert "value1" == os.environ.get("VAR1") + # Check if the environment variables are unset after exiting the context manager + assert os.environ.get("VAR3") is None + # Check if the original environment variables are still set + assert "value1" == os.environ.get("VAR1") + assert "value2" == os.environ.get("VAR2")