Skip to content

Commit

Permalink
Resolve errors occurring when dbt_project_path is str and partial sup…
Browse files Browse the repository at this point in the history
…port dbt_project_path=None (#605)

As part of the changes made #581, some downstream logic was missed
relating to the handling of a None and String-based project dir. This MR
attempts to remedy this issue by adding down steam support for the project
dir being None (including generation of exceptions and guarding), as
well as some property reference changes in the converter.

Closes: #601 

Co-authored-by: tabmra <[email protected]>
  • Loading branch information
MrBones757 and tabmra authored Oct 25, 2023
1 parent 1bf3a3e commit 2f8d0e2
Show file tree
Hide file tree
Showing 12 changed files with 410 additions and 266 deletions.
82 changes: 46 additions & 36 deletions cosmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import contextlib
import tempfile
from dataclasses import dataclass, field
from functools import cached_property
from pathlib import Path
from typing import Any, Iterator, Callable

Expand Down Expand Up @@ -43,7 +42,6 @@ class RenderConfig:
node_converters: dict[DbtResourceType, Callable[..., Any]] | None = None


@dataclass
class ProjectConfig:
"""
Class for setting project config.
Expand All @@ -58,29 +56,41 @@ class ProjectConfig:
Required if dbt_project_path is not defined. Defaults to the folder name of dbt_project_path.
"""

dbt_project_path: str | Path | None = None
models_relative_path: str | Path = "models"
seeds_relative_path: str | Path = "seeds"
snapshots_relative_path: str | Path = "snapshots"
manifest_path: str | Path | None = None
project_name: str | None = None

@cached_property
def parsed_dbt_project_path(self) -> Path | None:
return Path(self.dbt_project_path) if self.dbt_project_path else None
dbt_project_path: Path | None = None
manifest_path: Path | None = None
models_path: Path | None = None
seeds_path: Path | None = None
snapshots_path: Path | None = None
project_name: str

def __init__(
self,
dbt_project_path: str | Path | None = None,
models_relative_path: str | Path = "models",
seeds_relative_path: str | Path = "seeds",
snapshots_relative_path: str | Path = "snapshots",
manifest_path: str | Path | None = None,
project_name: str | None = None,
):
if not dbt_project_path:
if not manifest_path or not project_name:
raise CosmosValueError(
"ProjectConfig requires dbt_project_path and/or manifest_path to be defined."
" If only manifest_path is defined, project_name must also be defined."
)
if project_name:
self.project_name = project_name

@cached_property
def parsed_manifest_path(self) -> Path | None:
return Path(self.manifest_path) if self.manifest_path else None
if dbt_project_path:
self.dbt_project_path = Path(dbt_project_path)
self.models_path = self.dbt_project_path / Path(models_relative_path)
self.seeds_path = self.dbt_project_path / Path(seeds_relative_path)
self.snapshots_path = self.dbt_project_path / Path(snapshots_relative_path)
if not project_name:
self.project_name = self.dbt_project_path.stem

def __post_init__(self) -> None:
"Converts paths to `Path` objects."
if self.parsed_dbt_project_path:
self.models_relative_path = self.parsed_dbt_project_path / Path(self.models_relative_path)
self.seeds_relative_path = self.parsed_dbt_project_path / Path(self.seeds_relative_path)
self.snapshots_relative_path = self.parsed_dbt_project_path / Path(self.snapshots_relative_path)
if not self.project_name:
self.project_name = self.parsed_dbt_project_path.stem
if manifest_path:
self.manifest_path = Path(manifest_path)

def validate_project(self) -> None:
"""
Expand All @@ -94,20 +104,14 @@ def validate_project(self) -> None:

mandatory_paths = {}

if self.parsed_dbt_project_path:
project_yml_path = self.parsed_dbt_project_path / "dbt_project.yml"
if self.dbt_project_path:
project_yml_path = self.dbt_project_path / "dbt_project.yml"
mandatory_paths = {
"dbt_project.yml": project_yml_path,
"models directory ": self.models_relative_path,
"models directory ": self.models_path,
}
elif self.parsed_manifest_path:
if not self.project_name:
raise CosmosValueError(
"project_name required when manifest_path is present and dbt_project_path is not."
)
mandatory_paths = {"manifest file": self.parsed_manifest_path}
else:
raise CosmosValueError("dbt_project_path or manifest_path are required parameters.")
if self.manifest_path:
mandatory_paths["manifest"] = self.manifest_path

for name, path in mandatory_paths.items():
if path is None or not Path(path).exists():
Expand All @@ -117,10 +121,10 @@ def is_manifest_available(self) -> bool:
"""
Check if the `dbt` project manifest is set and if the file exists.
"""
if not self.parsed_manifest_path:
if not self.manifest_path:
return False

return self.parsed_manifest_path.exists()
return self.manifest_path.exists()


@dataclass
Expand Down Expand Up @@ -159,6 +163,12 @@ def validate_profile(self) -> None:
if not self.profiles_yml_filepath and not self.profile_mapping:
raise CosmosValueError("Either profiles_yml_filepath or profile_mapping must be set to render a profile")

def is_profile_yml_available(self) -> bool:
"""
Check if the `dbt` profiles.yml file exists.
"""
return Path(self.profiles_yml_filepath).exists() if self.profiles_yml_filepath else False

@contextlib.contextmanager
def ensure_profile(
self, desired_profile_path: Path | None = None, use_mock_values: bool = False
Expand Down
36 changes: 16 additions & 20 deletions cosmos/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@

import inspect
from typing import Any, Callable
from pathlib import Path

from airflow.models.dag import DAG
from airflow.utils.task_group import TaskGroup

from cosmos.airflow.graph import build_airflow_graph
from cosmos.dbt.graph import DbtGraph
from cosmos.dbt.project import DbtProject
from cosmos.dbt.selector import retrieve_by_label
from cosmos.config import ProjectConfig, ExecutionConfig, RenderConfig, ProfileConfig
from cosmos.exceptions import CosmosValueError
Expand Down Expand Up @@ -106,40 +104,38 @@ def __init__(
project_config.validate_project()

emit_datasets = render_config.emit_datasets
dbt_root_path = project_config.dbt_project_path.parent
dbt_project_name = project_config.dbt_project_path.name
dbt_models_dir = project_config.models_relative_path
dbt_seeds_dir = project_config.seeds_relative_path
dbt_snapshots_dir = project_config.snapshots_relative_path
test_behavior = render_config.test_behavior
select = render_config.select
exclude = render_config.exclude
dbt_deps = render_config.dbt_deps
execution_mode = execution_config.execution_mode
test_indirect_selection = execution_config.test_indirect_selection
load_mode = render_config.load_method
manifest_path = project_config.parsed_manifest_path
dbt_executable_path = execution_config.dbt_executable_path
node_converters = render_config.node_converters

if not project_config.dbt_project_path:
raise CosmosValueError("A Project Path in ProjectConfig is required for generating a Task Operators.")

profile_args = {}
if profile_config.profile_mapping:
profile_args = profile_config.profile_mapping.profile_args

if not operator_args:
operator_args = {}

dbt_project = DbtProject(
name=dbt_project_name,
root_dir=Path(dbt_root_path),
models_dir=Path(dbt_models_dir) if dbt_models_dir else None,
seeds_dir=Path(dbt_seeds_dir) if dbt_seeds_dir else None,
snapshots_dir=Path(dbt_snapshots_dir) if dbt_snapshots_dir else None,
manifest_path=manifest_path,
)

# Previously, we were creating a cosmos.dbt.project.DbtProject
# DbtProject has now been replaced with ProjectConfig directly
# since the interface of the two classes were effectively the same
# Under this previous implementation, we were passing:
# - name, root dir, models dir, snapshots dir and manifest path
# Internally in the dbtProject class, we were defaulting the profile_path
# To be root dir/profiles.yml
# To keep this logic working, if converter is given no ProfileConfig,
# we can create a default retaining this value to preserve this functionality.
# We may want to consider defaulting this value in our actual ProjceConfig class?
dbt_graph = DbtGraph(
project=dbt_project,
project=project_config,
exclude=exclude,
select=select,
dbt_cmd=dbt_executable_path,
Expand All @@ -152,7 +148,7 @@ def __init__(
task_args = {
**operator_args,
# the following args may be only needed for local / venv:
"project_dir": dbt_project.dir,
"project_dir": project_config.dbt_project_path,
"profile_config": profile_config,
"emit_datasets": emit_datasets,
}
Expand All @@ -169,7 +165,7 @@ def __init__(
task_args=task_args,
test_behavior=test_behavior,
test_indirect_selection=test_indirect_selection,
dbt_project_name=dbt_project.name,
dbt_project_name=project_config.project_name,
on_warning_callback=on_warning_callback,
node_converters=node_converters,
)
Loading

0 comments on commit 2f8d0e2

Please sign in to comment.