Skip to content

Commit

Permalink
Add support for env vars in RenderConfig for dbt ls parsing (#690)
Browse files Browse the repository at this point in the history
Currently, there is a workaround to have environment variables that are
required when parsing a dbt project with the dbt ls load mode by setting
them with `os.environ` in the DAG file.

This is what is currently done in the cosmos dev dag
[here](https://github.com/astronomer/astronomer-cosmos/blob/e23a445b30ca391842dae870260cc7ce799d4d5c/dev/dags/example_cosmos_sources.py#L29)
since that env var is required for parsing with dbt ls. The problem with
setting `os.environ` in that python file is that for the sqlite
integration test it was enabling this
[test](https://github.com/astronomer/astronomer-cosmos/blob/e23a445b30ca391842dae870260cc7ce799d4d5c/tests/dbt/test_graph.py#L388)
to unexpectedly pass (which also requires that env var).

This PR adds support for `env_vars` as an argument for `RenderConfig`
and sets/unsets the environment variables in a context manager for the
dbt ls graph parsing.

Closes: #5
Closes: #646
  • Loading branch information
jbandoro authored Nov 23, 2023
1 parent 8f7a04b commit 1734365
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 7 deletions.
2 changes: 2 additions & 0 deletions cosmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class RenderConfig:
:param dbt_deps: Configure to run dbt deps when using dbt ls for dag parsing
:param node_converters: a dictionary mapping a ``DbtResourceType`` into a callable. Users can control how to render dbt nodes in Airflow. Only supported when using ``load_method=LoadMode.DBT_MANIFEST`` or ``LoadMode.DBT_LS``.
:param dbt_executable_path: The path to the dbt executable for dag generation. Defaults to dbt if available on the path.
:param env_vars: A dictionary of environment variables for rendering. Only supported when using ``LoadMode.DBT_LS``.
:param dbt_project_path Configures the DBT project location accessible on the airflow controller for DAG rendering. Mutually Exclusive with ProjectConfig.dbt_project_path. Required when using ``load_method=LoadMode.DBT_LS`` or ``load_method=LoadMode.CUSTOM``.
"""

Expand All @@ -53,6 +54,7 @@ class RenderConfig:
dbt_deps: bool = True
node_converters: dict[DbtResourceType, Callable[..., Any]] | None = None
dbt_executable_path: str | Path = get_system_dbt()
env_vars: dict[str, str] = field(default_factory=dict)
dbt_project_path: InitVar[str | Path | None] = None

project_path: Path | None = field(init=False)
Expand Down
6 changes: 4 additions & 2 deletions cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
LoadMode,
)
from cosmos.dbt.parser.project import LegacyDbtProject
from cosmos.dbt.project import create_symlinks
from cosmos.dbt.project import create_symlinks, environ
from cosmos.dbt.selector import select_nodes
from cosmos.log import get_logger

Expand Down Expand Up @@ -234,7 +234,9 @@ def load_via_dbt_ls(self) -> None:
tmpdir_path = Path(tmpdir)
create_symlinks(self.render_config.project_path, tmpdir_path)

with self.profile_config.ensure_profile(use_mock_values=True) as profile_values:
with self.profile_config.ensure_profile(use_mock_values=True) as profile_values, environ(
self.render_config.env_vars
):
(profile_path, env_vars) = profile_values
env = os.environ.copy()
env.update(env_vars)
Expand Down
21 changes: 21 additions & 0 deletions cosmos/dbt/project.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations

from pathlib import Path
import os
from cosmos.constants import (
DBT_LOG_DIR_NAME,
DBT_TARGET_DIR_NAME,
)
from contextlib import contextmanager
from typing import Generator


def create_symlinks(project_path: Path, tmp_dir: Path) -> None:
Expand All @@ -12,3 +16,20 @@ def create_symlinks(project_path: Path, tmp_dir: Path) -> None:
for child_name in os.listdir(project_path):
if child_name not in ignore_paths:
os.symlink(project_path / child_name, tmp_dir / child_name)


@contextmanager
def environ(env_vars: dict[str, str]) -> Generator[None, None, None]:
"""Temporarily set environment variables inside the context manager and restore
when exiting.
"""
original_env = {key: os.getenv(key) for key in env_vars}
os.environ.update(env_vars)
try:
yield
finally:
for key, value in original_env.items():
if value is None:
del os.environ[key]
else:
os.environ[key] = value
7 changes: 4 additions & 3 deletions dev/dags/example_cosmos_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt"
DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH))

os.environ["DBT_SQLITE_PATH"] = str(DEFAULT_DBT_ROOT_PATH / "data")
DBT_SQLITE_PATH = str(DEFAULT_DBT_ROOT_PATH / "data")


profile_config = ProfileConfig(
Expand Down Expand Up @@ -62,7 +62,8 @@ def convert_exposure(dag: DAG, task_group: TaskGroup, node: DbtNode, **kwargs):
node_converters={
DbtResourceType("source"): convert_source, # known dbt node type to Cosmos (part of DbtResourceType)
DbtResourceType("exposure"): convert_exposure, # dbt node type new to Cosmos (will be added to DbtResourceType)
}
},
env_vars={"DBT_SQLITE_PATH": DBT_SQLITE_PATH},
)


Expand All @@ -73,7 +74,7 @@ def convert_exposure(dag: DAG, task_group: TaskGroup, node: DbtNode, **kwargs):
),
profile_config=profile_config,
render_config=render_config,
operator_args={"append_env": True},
operator_args={"env": {"DBT_SQLITE_PATH": DBT_SQLITE_PATH}},
# normal dag parameters
schedule_interval="@daily",
start_date=datetime(2023, 1, 1),
Expand Down
1 change: 1 addition & 0 deletions docs/configuration/render-config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ The ``RenderConfig`` class takes the following arguments:
- ``dbt_deps``: A Boolean to run dbt deps when using dbt ls for dag parsing. Default True
- ``node_converters``: a dictionary mapping a ``DbtResourceType`` into a callable. Users can control how to render dbt nodes in Airflow. Only supported when using ``load_method=LoadMode.DBT_MANIFEST`` or ``LoadMode.DBT_LS``. Find more information below.
- ``dbt_executable_path``: The path to the dbt executable for dag generation. Defaults to dbt if available on the path.
- ``env_vars``: A dictionary of environment variables for rendering. Only supported when using ``load_method=LoadMode.DBT_LS``.
- ``dbt_project_path``: Configures the DBT project location accessible on their airflow controller for DAG rendering - Required when using ``load_method=LoadMode.DBT_LS`` or ``load_method=LoadMode.CUSTOM``

Customizing how nodes are rendered (experimental)
Expand Down
6 changes: 5 additions & 1 deletion tests/dbt/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,11 @@ def test_load_via_dbt_ls_with_sources(load_method):
dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name,
manifest_path=SAMPLE_MANIFEST_SOURCE if load_method == "load_from_dbt_manifest" else None,
),
render_config=RenderConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name, dbt_deps=False),
render_config=RenderConfig(
dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name,
dbt_deps=False,
env_vars={"DBT_SQLITE_PATH": str(DBT_PROJECTS_ROOT_DIR / "data")},
),
execution_config=ExecutionConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name),
profile_config=ProfileConfig(
profile_name="simple",
Expand Down
22 changes: 21 additions & 1 deletion tests/dbt/test_project.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from pathlib import Path
from cosmos.dbt.project import create_symlinks
from cosmos.dbt.project import create_symlinks, environ
import os
from unittest.mock import patch

DBT_PROJECTS_ROOT_DIR = Path(__file__).parent.parent.parent / "dev/dags/dbt"

Expand All @@ -13,3 +15,21 @@ def test_create_symlinks(tmp_path):
for child in tmp_dir.iterdir():
assert child.is_symlink()
assert child.name not in ("logs", "target", "profiles.yml", "dbt_packages")


@patch.dict(os.environ, {"VAR1": "value1", "VAR2": "value2"})
def test_environ_context_manager():
# Define the expected environment variables
expected_env_vars = {"VAR2": "new_value2", "VAR3": "value3"}
# Use the environ context manager
with environ(expected_env_vars):
# Check if the environment variables are set correctly
for key, value in expected_env_vars.items():
assert value == os.environ.get(key)
# Check if the original non-overlapping environment variable is still set
assert "value1" == os.environ.get("VAR1")
# Check if the environment variables are unset after exiting the context manager
assert os.environ.get("VAR3") is None
# Check if the original environment variables are still set
assert "value1" == os.environ.get("VAR1")
assert "value2" == os.environ.get("VAR2")

0 comments on commit 1734365

Please sign in to comment.