Skip to content

Commit

Permalink
propogate use of RenderConfig.project_path and ExecutionConfig.projec…
Browse files Browse the repository at this point in the history
…t_path to dbtGraph
  • Loading branch information
MrBones757 committed Nov 2, 2023
1 parent 3398387 commit 39c85dd
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 45 deletions.
1 change: 1 addition & 0 deletions cosmos/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def __init__(
dbt_graph = DbtGraph(
project=project_config,
render_config=render_config,
execution_config=execution_config,
dbt_cmd=render_config.dbt_executable_path,
profile_config=profile_config,
operator_args=operator_args,
Expand Down
73 changes: 39 additions & 34 deletions cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from subprocess import PIPE, Popen
from typing import Any

from cosmos.config import ProfileConfig, ProjectConfig, RenderConfig
from cosmos.config import ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig
from cosmos.constants import (
DBT_LOG_DIR_NAME,
DBT_LOG_FILENAME,
Expand Down Expand Up @@ -53,12 +53,12 @@ class DbtNode:
has_test: bool = False


def create_symlinks(dbt_project_path: Path, tmp_dir: Path) -> None:
def create_symlinks(project_path: Path, tmp_dir: Path) -> None:
"""Helper function to create symlinks to the dbt project files."""
ignore_paths = (DBT_LOG_DIR_NAME, DBT_TARGET_DIR_NAME, "dbt_packages", "profiles.yml")
for child_name in os.listdir(dbt_project_path):
for child_name in os.listdir(project_path):
if child_name not in ignore_paths:
os.symlink(dbt_project_path / child_name, tmp_dir / child_name)
os.symlink(project_path / child_name, tmp_dir / child_name)


def run_command(command: list[str], tmp_dir: Path, env_vars: dict[str, str]) -> str:
Expand Down Expand Up @@ -88,7 +88,7 @@ def run_command(command: list[str], tmp_dir: Path, env_vars: dict[str, str]) ->
return stdout


def parse_dbt_ls_output(dbt_project_path: Path, ls_stdout: str) -> dict[str, DbtNode]:
def parse_dbt_ls_output(project_path: Path, ls_stdout: str) -> dict[str, DbtNode]:
"""Parses the output of `dbt ls` into a dictionary of `DbtNode` instances."""
nodes = {}
for line in ls_stdout.split("\n"):
Expand All @@ -102,7 +102,7 @@ def parse_dbt_ls_output(dbt_project_path: Path, ls_stdout: str) -> dict[str, Dbt
unique_id=node_dict["unique_id"],
resource_type=DbtResourceType(node_dict["resource_type"]),
depends_on=node_dict.get("depends_on", {}).get("nodes", []),
file_path=dbt_project_path / node_dict["original_file_path"],
file_path=project_path / node_dict["original_file_path"],
tags=node_dict["tags"],
config=node_dict["config"],
)
Expand Down Expand Up @@ -135,13 +135,15 @@ def __init__(
self,
project: ProjectConfig,
render_config: RenderConfig = RenderConfig(),
execution_config: ExecutionConfig = ExecutionConfig(),
profile_config: ProfileConfig | None = None,
dbt_cmd: str = get_system_dbt(),
operator_args: dict[str, Any] | None = None,
):
self.project = project
self.render_config = render_config
self.profile_config = profile_config
self.execution_config = execution_config
self.operator_args = operator_args or {}
self.dbt_cmd = dbt_cmd

Expand Down Expand Up @@ -181,7 +183,7 @@ def load(
else:
load_method[method]()

def run_dbt_ls(self, dbt_project_path: Path, tmp_dir: Path, env_vars: dict[str, str]) -> dict[str, DbtNode]:
def run_dbt_ls(self, project_path: Path, tmp_dir: Path, env_vars: dict[str, str]) -> dict[str, DbtNode]:
"""Runs dbt ls command and returns the parsed nodes."""
ls_command = [self.dbt_cmd, "ls", "--output", "json"]

Expand All @@ -203,7 +205,7 @@ def run_dbt_ls(self, dbt_project_path: Path, tmp_dir: Path, env_vars: dict[str,
for line in logfile:
logger.debug(line.strip())

nodes = parse_dbt_ls_output(dbt_project_path, stdout)
nodes = parse_dbt_ls_output(project_path, stdout)
return nodes

def load_via_dbt_ls(self) -> None:
Expand All @@ -218,28 +220,24 @@ def load_via_dbt_ls(self) -> None:
* self.nodes
* self.filtered_nodes
"""
logger.info(
"Trying to parse the dbt project `%s` in `%s` using dbt ls...",
self.project.project_name,
self.project.dbt_project_path,
)
if self.project.dbt_project_path is None:
raise CosmosLoadDbtException("Unable to dbt ls load a project without a project path.")
logger.info(f"Trying to parse the dbt project in `{self.render_config.project_path}` using dbt ls...")
if not self.render_config.project_path or not self.execution_config.project_path:
raise CosmosLoadDbtException(
"Unable to load project via dbt ls without RenderConfig.dbt_project_path and ExecutionConfig.dbt_project_path"
)

if not self.project.dbt_project_path or not self.profile_config:
raise CosmosLoadDbtException("Unable to load dbt project without project files and a profile config")
if not self.profile_config:
raise CosmosLoadDbtException("Unable to load project via dbt ls without a profile config.")

if not shutil.which(self.dbt_cmd):
raise CosmosLoadDbtException(f"Unable to find the dbt executable: {self.dbt_cmd}")

with tempfile.TemporaryDirectory() as tmpdir:
logger.info(
"Content of the dbt project dir <%s>: `%s`",
self.project.dbt_project_path,
os.listdir(self.project.dbt_project_path),
f"Content of the dbt project dir {self.render_config.project_path}: `{os.listdir(self.render_config.project_path)}`"
)
tmpdir_path = Path(tmpdir)
create_symlinks(self.project.dbt_project_path, tmpdir_path)
create_symlinks(self.render_config.project_path, tmpdir_path)

with self.profile_config.ensure_profile(use_mock_values=True) as profile_values:
(profile_path, env_vars) = profile_values
Expand Down Expand Up @@ -267,7 +265,7 @@ def load_via_dbt_ls(self) -> None:
stdout = run_command(deps_command, tmpdir_path, env)
logger.debug("dbt deps output: %s", stdout)

nodes = self.run_dbt_ls(self.project.dbt_project_path, tmpdir_path, env)
nodes = self.run_dbt_ls(self.execution_config.project_path, tmpdir_path, env)

self.nodes = nodes
self.filtered_nodes = nodes
Expand All @@ -291,14 +289,16 @@ def load_via_custom_parser(self) -> None:
"""
logger.info("Trying to parse the dbt project `%s` using a custom Cosmos method...", self.project.project_name)

if not self.project.dbt_project_path or not self.project.models_path or not self.project.seeds_path:
raise CosmosLoadDbtException("Unable to load dbt project without project files")
if not self.render_config.project_path or not self.execution_config.project_path:
raise CosmosLoadDbtException(
"Unable to load dbt project without RenderConfig.dbt_project_path and ExecutionConfig.dbt_project_path"
)

project = LegacyDbtProject(
project_name=self.project.dbt_project_path.stem,
dbt_root_path=self.project.dbt_project_path.parent.as_posix(),
dbt_models_dir=self.project.models_path.stem,
dbt_seeds_dir=self.project.seeds_path.stem,
project_name=self.render_config.project_path.stem,
dbt_root_path=self.render_config.project_path.parent.as_posix(),
dbt_models_dir=self.project.models_path.stem if self.project.models_path else "models",
dbt_seeds_dir=self.project.seeds_path.stem if self.project.seeds_path else "seeds",
operator_args=self.operator_args,
)
nodes = {}
Expand All @@ -312,15 +312,19 @@ def load_via_custom_parser(self) -> None:
unique_id=model_name,
resource_type=DbtResourceType(model.type.value),
depends_on=list(model.config.upstream_models),
file_path=model.path,
file_path=Path(
model.path.as_posix().replace(
self.render_config.project_path.as_posix(), self.execution_config.project_path.as_posix()
)
),
tags=[],
config=config,
)
nodes[model_name] = node

self.nodes = nodes
self.filtered_nodes = select_nodes(
project_dir=self.project.dbt_project_path,
project_dir=self.execution_config.project_path,
nodes=nodes,
select=self.render_config.select,
exclude=self.render_config.exclude,
Expand Down Expand Up @@ -350,6 +354,9 @@ def load_from_dbt_manifest(self) -> None:
if not self.project.is_manifest_available():
raise CosmosLoadDbtException(f"Unable to load manifest using {self.project.manifest_path}")

if not self.execution_config.project_path:
raise CosmosLoadDbtException("Unable to load manifest without ExecutionConfig.dbt_project_path")

nodes = {}
with open(self.project.manifest_path) as fp: # type: ignore[arg-type]
manifest = json.load(fp)
Expand All @@ -361,9 +368,7 @@ def load_from_dbt_manifest(self) -> None:
unique_id=unique_id,
resource_type=DbtResourceType(node_dict["resource_type"]),
depends_on=node_dict.get("depends_on", {}).get("nodes", []),
file_path=self.project.dbt_project_path / Path(node_dict["original_file_path"])
if self.project.dbt_project_path
else Path(node_dict["original_file_path"]),
file_path=self.execution_config.project_path / Path(node_dict["original_file_path"]),
tags=node_dict["tags"],
config=node_dict["config"],
)
Expand All @@ -372,7 +377,7 @@ def load_from_dbt_manifest(self) -> None:

self.nodes = nodes
self.filtered_nodes = select_nodes(
project_dir=self.project.dbt_project_path,
project_dir=self.execution_config.project_path,
nodes=nodes,
select=self.render_config.select,
exclude=self.render_config.exclude,
Expand Down
67 changes: 56 additions & 11 deletions tests/dbt/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest

from cosmos.config import ProfileConfig, ProjectConfig, RenderConfig
from cosmos.config import ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig
from cosmos.constants import DbtResourceType, ExecutionMode
from cosmos.dbt.graph import (
CosmosLoadDbtException,
Expand Down Expand Up @@ -57,7 +57,13 @@ def test_load_via_manifest_with_exclude(project_name, manifest_filepath, model_f
profiles_yml_filepath=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME / "profiles.yml",
)
render_config = RenderConfig(exclude=["config.materialized:table"])
dbt_graph = DbtGraph(project=project_config, profile_config=profile_config, render_config=render_config)
execution_config = ExecutionConfig(dbt_project_path=project_config.dbt_project_path)
dbt_graph = DbtGraph(
project=project_config,
execution_config=execution_config,
profile_config=profile_config,
render_config=render_config,
)
dbt_graph.load_from_dbt_manifest()

assert len(dbt_graph.nodes) == 28
Expand Down Expand Up @@ -292,30 +298,48 @@ def test_load_via_dbt_ls_without_exclude(project_name):

def test_load_via_custom_without_project_path():
project_config = ProjectConfig(manifest_path=SAMPLE_MANIFEST, project_name="test")
dbt_graph = DbtGraph(dbt_cmd="/inexistent/dbt", project=project_config)
execution_config = ExecutionConfig()
render_config = RenderConfig()
dbt_graph = DbtGraph(
dbt_cmd="/inexistent/dbt",
project=project_config,
execution_config=execution_config,
render_config=render_config,
)
with pytest.raises(CosmosLoadDbtException) as err_info:
dbt_graph.load_via_custom_parser()

expected = "Unable to load dbt project without project files"
expected = "Unable to load dbt project without RenderConfig.dbt_project_path and ExecutionConfig.dbt_project_path"
assert err_info.value.args[0] == expected


def test_load_via_dbt_ls_without_profile():
project_config = ProjectConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME)
dbt_graph = DbtGraph(dbt_cmd="/inexistent/dbt", project=project_config)
execution_config = ExecutionConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME)
render_config = RenderConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME)
dbt_graph = DbtGraph(
dbt_cmd="/inexistent/dbt",
project=project_config,
execution_config=execution_config,
render_config=render_config,
)
with pytest.raises(CosmosLoadDbtException) as err_info:
dbt_graph.load_via_dbt_ls()

expected = "Unable to load dbt project without project files and a profile config"
expected = "Unable to load project via dbt ls without a profile config."
assert err_info.value.args[0] == expected


def test_load_via_dbt_ls_with_invalid_dbt_path():
project_config = ProjectConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME)
execution_config = ExecutionConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME)
render_config = RenderConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME)
with patch("pathlib.Path.exists", return_value=True):
dbt_graph = DbtGraph(
dbt_cmd="/inexistent/dbt",
project=project_config,
execution_config=execution_config,
render_config=render_config,
profile_config=ProfileConfig(
profile_name="default",
target_name="default",
Expand Down Expand Up @@ -446,12 +470,19 @@ def test_load_via_dbt_ls_with_runtime_error_in_stdout(mock_popen_communicate):
@pytest.mark.parametrize("project_name", ("jaffle_shop", "jaffle_shop_python"))
def test_load_via_load_via_custom_parser(project_name):
project_config = ProjectConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name)
execution_config = ExecutionConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME)
render_config = RenderConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME)
profile_config = ProfileConfig(
profile_name="test",
target_name="test",
profiles_yml_filepath=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME / "profiles.yml",
)
dbt_graph = DbtGraph(project=project_config, profile_config=profile_config)
dbt_graph = DbtGraph(
project=project_config,
profile_config=profile_config,
render_config=render_config,
execution_config=execution_config,
)

dbt_graph.load_via_custom_parser()

Expand All @@ -464,12 +495,13 @@ def test_update_node_dependency_called(mock_update_node_dependency):
project_config = ProjectConfig(
dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME, manifest_path=SAMPLE_MANIFEST
)
execution_config = ExecutionConfig(dbt_project_path=project_config.dbt_project_path)
profile_config = ProfileConfig(
profile_name="test",
target_name="test",
profiles_yml_filepath=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME / "profiles.yml",
)
dbt_graph = DbtGraph(project=project_config, profile_config=profile_config)
dbt_graph = DbtGraph(project=project_config, execution_config=execution_config, profile_config=profile_config)
dbt_graph.load()

assert mock_update_node_dependency.called
Expand All @@ -484,7 +516,8 @@ def test_update_node_dependency_target_exist():
target_name="test",
profiles_yml_filepath=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME / "profiles.yml",
)
dbt_graph = DbtGraph(project=project_config, profile_config=profile_config)
execution_config = ExecutionConfig(dbt_project_path=project_config.dbt_project_path)
dbt_graph = DbtGraph(project=project_config, execution_config=execution_config, profile_config=profile_config)
dbt_graph.load()

for _, nodes in dbt_graph.nodes.items():
Expand All @@ -503,7 +536,13 @@ def test_update_node_dependency_test_not_exist():
profiles_yml_filepath=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME / "profiles.yml",
)
render_config = RenderConfig(exclude=["config.materialized:test"])
dbt_graph = DbtGraph(project=project_config, profile_config=profile_config, render_config=render_config)
execution_config = ExecutionConfig(dbt_project_path=project_config.dbt_project_path)
dbt_graph = DbtGraph(
project=project_config,
execution_config=execution_config,
profile_config=profile_config,
render_config=render_config,
)
dbt_graph.load_from_dbt_manifest()

for _, nodes in dbt_graph.filtered_nodes.items():
Expand All @@ -520,7 +559,13 @@ def test_tag_selected_node_test_exist():
profiles_yml_filepath=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME / "profiles.yml",
)
render_config = RenderConfig(select=["tag:test_tag"])
dbt_graph = DbtGraph(project=project_config, profile_config=profile_config, render_config=render_config)
execution_config = ExecutionConfig(dbt_project_path=project_config.dbt_project_path)
dbt_graph = DbtGraph(
project=project_config,
execution_config=execution_config,
profile_config=profile_config,
render_config=render_config,
)
dbt_graph.load_from_dbt_manifest()

assert len(dbt_graph.filtered_nodes) > 0
Expand Down

0 comments on commit 39c85dd

Please sign in to comment.