Skip to content

Commit

Permalink
Added new fields to ExecutionConfig and RenderConfig. Updated Convert…
Browse files Browse the repository at this point in the history
…er for new Fields. Simplified Graph interaction to consume RenderConfig natively. Updates tests to reflect changes
  • Loading branch information
MrBones757 committed Oct 28, 2023
1 parent c1a1406 commit cf2cf6f
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 70 deletions.
16 changes: 15 additions & 1 deletion cosmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import contextlib
import tempfile
from dataclasses import dataclass, field
from dataclasses import InitVar, dataclass, field
from pathlib import Path
from typing import Any, Iterator, Callable

Expand Down Expand Up @@ -40,6 +40,13 @@ class RenderConfig:
exclude: list[str] = field(default_factory=list)
dbt_deps: bool = True
node_converters: dict[DbtResourceType, Callable[..., Any]] | None = None
dbt_executable_path: str | Path = get_system_dbt()
dbt_project_path: InitVar[str | Path | None] = None

project_path: Path | None = field(init=False)

def __post_init__(self, dbt_project_path: str | Path | None) -> None:
self.project_path = Path(dbt_project_path) if dbt_project_path else None


class ProjectConfig:
Expand Down Expand Up @@ -217,3 +224,10 @@ class ExecutionConfig:
execution_mode: ExecutionMode = ExecutionMode.LOCAL
test_indirect_selection: TestIndirectSelection = TestIndirectSelection.EAGER
dbt_executable_path: str | Path = get_system_dbt()

dbt_project_path: InitVar[str | Path | None] = None

project_path: Path | None = field(init=False)

def __post_init__(self, dbt_project_path: str | Path | None) -> None:
self.project_path = Path(dbt_project_path) if dbt_project_path else None
79 changes: 48 additions & 31 deletions cosmos/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from cosmos.exceptions import CosmosValueError
from cosmos.log import get_logger


logger = get_logger(__name__)


Expand Down Expand Up @@ -92,8 +91,8 @@ def __init__(
self,
project_config: ProjectConfig,
profile_config: ProfileConfig,
execution_config: ExecutionConfig = ExecutionConfig(),
render_config: RenderConfig = RenderConfig(),
execution_config: ExecutionConfig | None = None,
render_config: RenderConfig | None = None,
dag: DAG | None = None,
task_group: TaskGroup | None = None,
operator_args: dict[str, Any] | None = None,
Expand All @@ -103,19 +102,40 @@ def __init__(
) -> None:
project_config.validate_project()

emit_datasets = render_config.emit_datasets
test_behavior = render_config.test_behavior
select = render_config.select
exclude = render_config.exclude
dbt_deps = render_config.dbt_deps
execution_mode = execution_config.execution_mode
test_indirect_selection = execution_config.test_indirect_selection
load_mode = render_config.load_method
dbt_executable_path = execution_config.dbt_executable_path
node_converters = render_config.node_converters

if not project_config.dbt_project_path:
raise CosmosValueError("A Project Path in ProjectConfig is required for generating a Task Operators.")
if not execution_config:
execution_config = ExecutionConfig()
if not render_config:
render_config = RenderConfig()

# Since we now support both project_config.dbt_project_path, render_config.project_path and execution_config.project_path
# We need to ensure that only one interface is being used.
if project_config.dbt_project_path and (render_config.project_path or execution_config.project_path):
print(f"RenderConfig: {render_config.project_path}")
print(f"ExecutionConfig: {execution_config.project_path}")
print(f"ProjectConfig: {project_config.dbt_project_path}")
raise CosmosValueError(
"ProjectConfig.dbt_project_path is mutually exclusive with RenderConfig.dbt_project_path and ExecutionConfig.dbt_project_path."
+ "If using RenderConfig.dbt_project_path or ExecutionConfig.dbt_project_path, ProjectConfig.dbt_project_path should be None"
)

# If we are using the old interface, we should migrate it to the new interface
# This is safe to do now since we have validated which config interface we're using
if project_config.dbt_project_path:
render_config.project_path = project_config.dbt_project_path
execution_config.project_path = project_config.dbt_project_path

# At this point, execution_config.project_path should always be non-null
if not execution_config.project_path:
raise CosmosValueError(
"ExecutionConfig.dbt_project_path is required for the execution of dbt tasks in all execution modes."
)

# We now have a guaranteed execution_config.project_path, but still need to process render_config.project_path
# We require render_config.project_path when we dont have a manifest
if not project_config.manifest_path and not render_config.project_path:
raise CosmosValueError(
"RenderConfig.dbt_project_path is required for rendering an airflow DAG from a DBT Graph if no manifest is provided."
)

profile_args = {}
if profile_config.profile_mapping:
Expand All @@ -136,36 +156,33 @@ def __init__(
# We may want to consider defaulting this value in our actual ProjceConfig class?
dbt_graph = DbtGraph(
project=project_config,
exclude=exclude,
select=select,
dbt_cmd=dbt_executable_path,
render_config=render_config,
dbt_cmd=render_config.dbt_executable_path,
profile_config=profile_config,
operator_args=operator_args,
dbt_deps=dbt_deps,
)
dbt_graph.load(method=load_mode, execution_mode=execution_mode)
dbt_graph.load(method=render_config.load_method, execution_mode=execution_config.execution_mode)

task_args = {
**operator_args,
# the following args may be only needed for local / venv:
"project_dir": project_config.dbt_project_path,
"project_dir": execution_config.project_path,
"profile_config": profile_config,
"emit_datasets": emit_datasets,
"emit_datasets": render_config.emit_datasets,
}
if dbt_executable_path:
task_args["dbt_executable_path"] = dbt_executable_path
if execution_config.dbt_executable_path:
task_args["dbt_executable_path"] = execution_config.dbt_executable_path

validate_arguments(select, exclude, profile_args, task_args)
validate_arguments(render_config.select, render_config.exclude, profile_args, task_args)

build_airflow_graph(
nodes=dbt_graph.filtered_nodes,
dag=dag or (task_group and task_group.dag),
task_group=task_group,
execution_mode=execution_mode,
execution_mode=execution_config.execution_mode,
task_args=task_args,
test_behavior=test_behavior,
test_indirect_selection=test_indirect_selection,
test_behavior=render_config.test_behavior,
test_indirect_selection=execution_config.test_indirect_selection,
dbt_project_name=project_config.project_name,
on_warning_callback=on_warning_callback,
node_converters=node_converters,
node_converters=render_config.node_converters,
)
35 changes: 17 additions & 18 deletions cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from subprocess import PIPE, Popen
from typing import Any

from cosmos.config import ProfileConfig, ProjectConfig
from cosmos.config import ProfileConfig, ProjectConfig, RenderConfig
from cosmos.constants import (
DBT_LOG_DIR_NAME,
DBT_LOG_FILENAME,
Expand Down Expand Up @@ -77,22 +77,15 @@ class DbtGraph:
def __init__(
self,
project: ProjectConfig,
render_config: RenderConfig = RenderConfig(),
profile_config: ProfileConfig | None = None,
exclude: list[str] | None = None,
select: list[str] | None = None,
dbt_cmd: str = get_system_dbt(),
operator_args: dict[str, Any] | None = None,
dbt_deps: bool | None = True,
):
self.project = project
self.exclude = exclude or []
self.select = select or []
self.render_config = render_config
self.profile_config = profile_config
self.operator_args = operator_args or {}
self.dbt_deps = dbt_deps

# specific to loading using ls
self.dbt_deps = dbt_deps
self.dbt_cmd = dbt_cmd

def load(
Expand Down Expand Up @@ -191,7 +184,7 @@ def load_via_dbt_ls(self) -> None:
env[DBT_LOG_PATH_ENVVAR] = str(log_dir)
env[DBT_TARGET_PATH_ENVVAR] = str(target_dir)

if self.dbt_deps:
if self.render_config.dbt_deps:
deps_command = [self.dbt_cmd, "deps"]
deps_command.extend(local_flags)
logger.info("Running command: `%s`", " ".join(deps_command))
Expand All @@ -214,11 +207,11 @@ def load_via_dbt_ls(self) -> None:

ls_command = [self.dbt_cmd, "ls", "--output", "json"]

if self.exclude:
ls_command.extend(["--exclude", *self.exclude])
if self.render_config.exclude:
ls_command.extend(["--exclude", *self.render_config.exclude])

if self.select:
ls_command.extend(["--select", *self.select])
if self.render_config.select:
ls_command.extend(["--select", *self.render_config.select])

ls_command.extend(local_flags)

Expand Down Expand Up @@ -247,7 +240,7 @@ def load_via_dbt_ls(self) -> None:

if 'Run "dbt deps" to install package dependencies' in stdout:
raise CosmosLoadDbtException(
"Unable to run dbt ls command due to missing dbt_packages. Set render_config.dbt_deps=True."
"Unable to run dbt ls command due to missing dbt_packages. Set RenderConfig.dbt_deps=True."
)

if returncode or "Error" in stdout:
Expand Down Expand Up @@ -324,7 +317,10 @@ def load_via_custom_parser(self) -> None:

self.nodes = nodes
self.filtered_nodes = select_nodes(
project_dir=self.project.dbt_project_path, nodes=nodes, select=self.select, exclude=self.exclude
project_dir=self.project.dbt_project_path,
nodes=nodes,
select=self.render_config.select,
exclude=self.render_config.exclude,
)

self.update_node_dependency()
Expand Down Expand Up @@ -373,7 +369,10 @@ def load_from_dbt_manifest(self) -> None:

self.nodes = nodes
self.filtered_nodes = select_nodes(
project_dir=self.project.dbt_project_path, nodes=nodes, select=self.select, exclude=self.exclude
project_dir=self.project.dbt_project_path,
nodes=nodes,
select=self.render_config.select,
exclude=self.render_config.exclude,
)

self.update_node_dependency()
Expand Down
8 changes: 5 additions & 3 deletions tests/dbt/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest

from cosmos.config import ProfileConfig, ProjectConfig
from cosmos.config import ProfileConfig, ProjectConfig, RenderConfig
from cosmos.constants import DbtResourceType, ExecutionMode
from cosmos.dbt.graph import CosmosLoadDbtException, DbtGraph, LoadMode
from cosmos.profiles import PostgresUserPasswordProfileMapping
Expand Down Expand Up @@ -48,7 +48,8 @@ def test_load_via_manifest_with_exclude(project_name, manifest_filepath, model_f
target_name="test",
profiles_yml_filepath=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME / "profiles.yml",
)
dbt_graph = DbtGraph(project=project_config, profile_config=profile_config, exclude=["config.materialized:table"])
render_config = RenderConfig(exclude=["config.materialized:table"])
dbt_graph = DbtGraph(project=project_config, profile_config=profile_config, render_config=render_config)
dbt_graph.load_from_dbt_manifest()

assert len(dbt_graph.nodes) == 28
Expand Down Expand Up @@ -497,7 +498,8 @@ def test_update_node_dependency_test_not_exist():
target_name="test",
profiles_yml_filepath=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME / "profiles.yml",
)
dbt_graph = DbtGraph(project=project_config, profile_config=profile_config, exclude=["config.materialized:test"])
render_config = RenderConfig(exclude=["config.materialized:test"])
dbt_graph = DbtGraph(project=project_config, profile_config=profile_config, render_config=render_config)
dbt_graph.load_from_dbt_manifest()

for _, nodes in dbt_graph.filtered_nodes.items():
Expand Down
26 changes: 11 additions & 15 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,10 @@ def test_init_with_no_params():
"""
with pytest.raises(CosmosValueError) as err_info:
ProjectConfig()
print(err_info.value.args[0])
assert err_info.value.args[0] == (
"ProjectConfig requires dbt_project_path and/or manifest_path to be defined."
"If only manifest_path is defined, project_name must also be defined."
)
assert err_info.value.args[0] == (
"ProjectConfig requires dbt_project_path and/or manifest_path to be defined. "
"If only manifest_path is defined, project_name must also be defined."
)


def test_init_with_manifest_path_and_not_project_path_and_not_project_name_fails():
Expand All @@ -55,11 +54,10 @@ def test_init_with_manifest_path_and_not_project_path_and_not_project_name_fails
"""
with pytest.raises(CosmosValueError) as err_info:
ProjectConfig(manifest_path=DBT_PROJECTS_ROOT_DIR / "manifest.json")
print(err_info.value.args[0])
assert err_info.value.args[0] == (
"ProjectConfig requires dbt_project_path and/or manifest_path to be defined."
"If only manifest_path is defined, project_name must also be defined."
)
assert err_info.value.args[0] == (
"ProjectConfig requires dbt_project_path and/or manifest_path to be defined. "
"If only manifest_path is defined, project_name must also be defined."
)


def test_validate_with_project_path_and_manifest_path_succeeds():
Expand Down Expand Up @@ -95,7 +93,7 @@ def test_validate_project_missing_fails():
project_config = ProjectConfig(dbt_project_path=Path("/tmp"))
with pytest.raises(CosmosValueError) as err_info:
assert project_config.validate_project() is None
assert err_info.value.args[0] == "Could not find dbt_project.yml at /tmp/dbt_project.yml"
assert err_info.value.args[0] == "Could not find dbt_project.yml at /tmp/dbt_project.yml"


def test_is_manifest_available_is_true():
Expand All @@ -118,13 +116,11 @@ def test_project_name():
def test_profile_config_post_init():
with pytest.raises(CosmosValueError) as err_info:
ProfileConfig(profiles_yml_filepath="/tmp/some-profile", profile_name="test", target_name="test")
assert err_info.value.args[0] == "The file /tmp/some-profile does not exist."
assert err_info.value.args[0] == "The file /tmp/some-profile does not exist."


def test_profile_config_validate():
with pytest.raises(CosmosValueError) as err_info:
profile_config = ProfileConfig(profile_name="test", target_name="test")
assert profile_config.validate_profile() is None
assert (
err_info.value.args[0] == "Either profiles_yml_filepath or profile_mapping must be set to render a profile"
)
assert err_info.value.args[0] == "Either profiles_yml_filepath or profile_mapping must be set to render a profile"
7 changes: 5 additions & 2 deletions tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_converter_creates_dag_with_project_path_str(mock_load_dbt_graph, execut
)
@patch("cosmos.converter.DbtGraph.filtered_nodes", nodes)
@patch("cosmos.converter.DbtGraph.load")
def test_converter_fails_no_project_dir(mock_load_dbt_graph, execution_mode, operator_args):
def test_converter_fails_execution_config_no_project_dir(mock_load_dbt_graph, execution_mode, operator_args):
"""
This test validates that a project, given a manifest path and project name, with seeds
is able to successfully generate a converter
Expand All @@ -135,4 +135,7 @@ def test_converter_fails_no_project_dir(mock_load_dbt_graph, execution_mode, ope
render_config=render_config,
operator_args=operator_args,
)
assert err_info.value.args[0] == "A Project Path in ProjectConfig is required for generating a Task Operators."
assert (
err_info.value.args[0]
== "ExecutionConfig.project_path is required for the execution of dbt tasks in all execution modes."
)

0 comments on commit cf2cf6f

Please sign in to comment.