Skip to content

Commit

Permalink
deprecated DbtProject in favor of ProjectConfig
Browse files Browse the repository at this point in the history
Stopped using a hard-coded profile path for tests and relying on the default in the old DbtProject - source it from a ProfileConfig instead.
Don't pass project_path to DbtNode if it doesn't exist - allow DbtNode to compare paths on a relative basis too.
Renamed the other DbtProject to LegacyDbtProject consistently.
Updated tests to pass the fully qualified project path, rather than root and name.
  • Loading branch information
tabmra committed Oct 24, 2023
1 parent 9542235 commit 9e11bea
Show file tree
Hide file tree
Showing 9 changed files with 287 additions and 253 deletions.
83 changes: 47 additions & 36 deletions cosmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ class RenderConfig:
node_converters: dict[DbtResourceType, Callable[..., Any]] | None = None


@dataclass
class ProjectConfig:
"""
Class for setting project config.
Expand All @@ -58,33 +57,45 @@ class ProjectConfig:
Required if dbt_project_path is not defined. Defaults to the folder name of dbt_project_path.
"""

dbt_project_path: str | Path | None = None
models_relative_path: str | Path = "models"
seeds_relative_path: str | Path = "seeds"
snapshots_relative_path: str | Path = "snapshots"
manifest_path: str | Path | None = None
project_name: str | None = None
dbt_project_path: Path | None = None
manifest_path: Path | None = None
models_path: Path | None = None
seeds_path: Path | None = None
snapshots_path: Path | None = None
project_name: str

def __init__(
self,
dbt_project_path: str | Path | None = None,
models_relative_path: str | Path = "models",
seeds_relative_path: str | Path = "seeds",
snapshots_relative_path: str | Path = "snapshots",
manifest_path: str | Path | None = None,
project_name: str | None = None,
):
if not dbt_project_path:
if not manifest_path or not project_name:
raise CosmosValueError(
"ProjectConfig requires dbt_project_path and/or manifest_path to be defined."
" If only manifest_path is defined, project_name must also be defined."
)
if project_name:
self.project_name = project_name

@cached_property
def parsed_dbt_project_path(self) -> Path | None:
return Path(self.dbt_project_path) if self.dbt_project_path else None
if dbt_project_path:
self.dbt_project_path = Path(dbt_project_path)
self.models_path = self.dbt_project_path / Path(models_relative_path)
self.seeds_path = self.dbt_project_path / Path(seeds_relative_path)
self.snapshots_path = self.dbt_project_path / Path(snapshots_relative_path)
if not project_name:
self.project_name = self.dbt_project_path.stem

@cached_property
def parsed_manifest_path(self) -> Path | None:
return Path(self.manifest_path) if self.manifest_path else None
if manifest_path:
self.manifest_path = Path(manifest_path)

@cached_property
def dbt_project_path_parent(self) -> Path | None:
return self.parsed_dbt_project_path.parent if self.parsed_dbt_project_path else None

def __post_init__(self) -> None:
"Converts paths to `Path` objects."
if self.parsed_dbt_project_path:
self.models_relative_path = self.parsed_dbt_project_path / Path(self.models_relative_path)
self.seeds_relative_path = self.parsed_dbt_project_path / Path(self.seeds_relative_path)
self.snapshots_relative_path = self.parsed_dbt_project_path / Path(self.snapshots_relative_path)
if not self.project_name:
self.project_name = self.parsed_dbt_project_path.stem
return self.dbt_project_path.parent if self.dbt_project_path else None

def validate_project(self) -> None:
"""
Expand All @@ -98,20 +109,14 @@ def validate_project(self) -> None:

mandatory_paths = {}

if self.parsed_dbt_project_path:
project_yml_path = self.parsed_dbt_project_path / "dbt_project.yml"
if self.dbt_project_path:
project_yml_path = self.dbt_project_path / "dbt_project.yml"
mandatory_paths = {
"dbt_project.yml": project_yml_path,
"models directory ": self.models_relative_path,
"models directory ": self.models_path,
}
elif self.parsed_manifest_path:
if not self.project_name:
raise CosmosValueError(
"project_name required when manifest_path is present and dbt_project_path is not."
)
mandatory_paths = {"manifest file": self.parsed_manifest_path}
else:
raise CosmosValueError("dbt_project_path or manifest_path are required parameters.")
if self.manifest_path:
mandatory_paths["manifest"] = self.manifest_path

for name, path in mandatory_paths.items():
if path is None or not Path(path).exists():
Expand All @@ -121,10 +126,10 @@ def is_manifest_available(self) -> bool:
"""
Check if the `dbt` project manifest is set and if the file exists.
"""
if not self.parsed_manifest_path:
if not self.manifest_path:
return False

return self.parsed_manifest_path.exists()
return self.manifest_path.exists()


@dataclass
Expand Down Expand Up @@ -163,6 +168,12 @@ def validate_profile(self) -> None:
if not self.profiles_yml_filepath and not self.profile_mapping:
raise CosmosValueError("Either profiles_yml_filepath or profile_mapping must be set to render a profile")

def is_profile_yml_available(self) -> bool:
    """
    Tell whether a `dbt` profiles.yml path is configured and exists on disk.

    :returns: False when ``profiles_yml_filepath`` is unset or empty; otherwise
        whether that path currently exists on the filesystem.
    """
    configured_path = self.profiles_yml_filepath
    if not configured_path:
        # Nothing configured, so there is no file to look for.
        return False
    return Path(configured_path).exists()

@contextlib.contextmanager
def ensure_profile(
self, desired_profile_path: Path | None = None, use_mock_values: bool = False
Expand Down
37 changes: 17 additions & 20 deletions cosmos/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@

import inspect
from typing import Any, Callable
from pathlib import Path

from airflow.models.dag import DAG
from airflow.utils.task_group import TaskGroup

from cosmos.airflow.graph import build_airflow_graph
from cosmos.dbt.graph import DbtGraph
from cosmos.dbt.project import DbtProject
from cosmos.dbt.selector import retrieve_by_label
from cosmos.config import ProjectConfig, ExecutionConfig, RenderConfig, ProfileConfig
from cosmos.exceptions import CosmosValueError
Expand Down Expand Up @@ -93,7 +91,7 @@ class DbtToAirflowConverter:
def __init__(
self,
project_config: ProjectConfig,
profile_config: ProfileConfig,
profile_config: ProfileConfig = None,
execution_config: ExecutionConfig = ExecutionConfig(),
render_config: RenderConfig = RenderConfig(),
dag: DAG | None = None,
Expand All @@ -106,39 +104,38 @@ def __init__(
project_config.validate_project()

emit_datasets = render_config.emit_datasets
dbt_project_name = project_config.project_name
dbt_models_dir = project_config.models_relative_path
dbt_seeds_dir = project_config.seeds_relative_path
dbt_snapshots_dir = project_config.snapshots_relative_path
test_behavior = render_config.test_behavior
select = render_config.select
exclude = render_config.exclude
dbt_deps = render_config.dbt_deps
execution_mode = execution_config.execution_mode
test_indirect_selection = execution_config.test_indirect_selection
load_mode = render_config.load_method
manifest_path = project_config.parsed_manifest_path
dbt_executable_path = execution_config.dbt_executable_path
node_converters = render_config.node_converters

if not profile_config:
profile_config = ProfileConfig(profiles_yml_filepath=project_config.dbt_project_path / "profiles.yml")

profile_args = {}
if profile_config.profile_mapping:
profile_args = profile_config.profile_mapping.profile_args

if not operator_args:
operator_args = {}

dbt_project = DbtProject(
name=dbt_project_name,
root_dir=project_config.dbt_project_path_parent,
models_dir=Path(dbt_models_dir) if dbt_models_dir else None,
seeds_dir=Path(dbt_seeds_dir) if dbt_seeds_dir else None,
snapshots_dir=Path(dbt_snapshots_dir) if dbt_snapshots_dir else None,
manifest_path=manifest_path,
)

# Previously, we created a cosmos.dbt.project.DbtProject here.
# DbtProject has now been replaced with ProjectConfig directly,
# since the interfaces of the two classes were effectively the same.
# Under the previous implementation, we passed:
# - name, root dir, models dir, snapshots dir and manifest path
# Internally, the DbtProject class defaulted the profile_path
# to root dir/profiles.yml.
# To keep this logic working, if the converter is given no ProfileConfig,
# we create a default one retaining this value to preserve this functionality.
# We may want to consider defaulting this value in the actual ProjectConfig class.
dbt_graph = DbtGraph(
project=dbt_project,
project=project_config,
exclude=exclude,
select=select,
dbt_cmd=dbt_executable_path,
Expand All @@ -151,7 +148,7 @@ def __init__(
task_args = {
**operator_args,
# the following args may be only needed for local / venv:
"project_dir": dbt_project.dir,
"project_dir": project_config.dbt_project_path,
"profile_config": profile_config,
"emit_datasets": emit_datasets,
}
Expand All @@ -168,7 +165,7 @@ def __init__(
task_args=task_args,
test_behavior=test_behavior,
test_indirect_selection=test_indirect_selection,
dbt_project_name=dbt_project.name,
dbt_project_name=project_config.project_name,
on_warning_callback=on_warning_callback,
node_converters=node_converters,
)
61 changes: 37 additions & 24 deletions cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from subprocess import PIPE, Popen
from typing import Any

from cosmos.config import ProfileConfig
from cosmos.config import ProfileConfig, ProjectConfig
from cosmos.constants import (
DBT_LOG_DIR_NAME,
DBT_LOG_FILENAME,
Expand All @@ -22,8 +22,7 @@
LoadMode,
)
from cosmos.dbt.executable import get_system_dbt
from cosmos.dbt.parser.project import DbtProject as LegacyDbtProject
from cosmos.dbt.project import DbtProject
from cosmos.dbt.parser.project import LegacyDbtProject
from cosmos.dbt.selector import select_nodes
from cosmos.log import get_logger

Expand Down Expand Up @@ -64,7 +63,7 @@ class DbtGraph:
Example of how to use:
dbt_graph = DbtGraph(
project=DbtProject(name="jaffle_shop", root_dir=DBT_PROJECTS_ROOT_DIR),
project=ProjectConfig(dbt_project_path=DBT_PROJECT_PATH),
exclude=["*orders*"],
select=[],
dbt_cmd="/usr/local/bin/dbt",
Expand All @@ -77,11 +76,11 @@ class DbtGraph:

def __init__(
self,
project: DbtProject,
project: ProjectConfig,
profile_config: ProfileConfig | None = None,
exclude: list[str] | None = None,
select: list[str] | None = None,
dbt_cmd: str = get_system_dbt(),
profile_config: ProfileConfig | None = None,
operator_args: dict[str, Any] | None = None,
dbt_deps: bool | None = True,
):
Expand Down Expand Up @@ -122,7 +121,11 @@ def load(
if self.project.is_manifest_available():
self.load_from_dbt_manifest()
else:
if execution_mode == ExecutionMode.LOCAL and self.project.is_profile_yml_available():
if (
execution_mode == ExecutionMode.LOCAL
and self.profile_config
and self.profile_config.is_profile_yml_available()
):
try:
self.load_via_dbt_ls()
except FileNotFoundError:
Expand All @@ -144,9 +147,13 @@ def load_via_dbt_ls(self) -> None:
* self.nodes
* self.filtered_nodes
"""
logger.info("Trying to parse the dbt project `%s` in `%s` using dbt ls...", self.project.name, self.project.dir)
logger.info(
"Trying to parse the dbt project `%s` in `%s` using dbt ls...",
self.project.project_name,
self.project.dbt_project_path,
)

if not self.project.dir or not self.profile_config:
if not self.project.dbt_project_path or not self.profile_config:
raise CosmosLoadDbtException("Unable to load dbt project without project files and a profile config")

if not shutil.which(self.dbt_cmd):
Expand All @@ -158,16 +165,20 @@ def load_via_dbt_ls(self) -> None:
env.update(env_vars)

with tempfile.TemporaryDirectory() as tmpdir:
logger.info("Content of the dbt project dir <%s>: `%s`", self.project.dir, os.listdir(self.project.dir))
logger.info("Creating symlinks from %s to `%s`", self.project.dir, tmpdir)
logger.info(
"Content of the dbt project dir <%s>: `%s`",
self.project.dbt_project_path,
os.listdir(self.project.dbt_project_path),
)
logger.info("Creating symlinks from %s to `%s`", self.project.dbt_project_path, tmpdir)
# We create symbolic links to the original directory files and directories.
# This allows us to run the dbt command from within the temporary directory, outputting any necessary
# artifact and also allow us to run `dbt deps`
tmpdir_path = Path(tmpdir)
ignore_paths = (DBT_LOG_DIR_NAME, DBT_TARGET_DIR_NAME, "dbt_packages", "profiles.yml")
for child_name in os.listdir(self.project.dir):
for child_name in os.listdir(self.project.dbt_project_path):
if child_name not in ignore_paths:
os.symlink(self.project.dir / child_name, tmpdir_path / child_name)
os.symlink(self.project.dbt_project_path / child_name, tmpdir_path / child_name)

local_flags = [
"--project-dir",
Expand Down Expand Up @@ -259,7 +270,7 @@ def load_via_dbt_ls(self) -> None:
unique_id=node_dict["unique_id"],
resource_type=DbtResourceType(node_dict["resource_type"]),
depends_on=node_dict.get("depends_on", {}).get("nodes", []),
file_path=self.project.dir / node_dict["original_file_path"],
file_path=self.project.dbt_project_path / node_dict["original_file_path"],
tags=node_dict["tags"],
config=node_dict["config"],
)
Expand All @@ -286,16 +297,16 @@ def load_via_custom_parser(self) -> None:
* self.nodes
* self.filtered_nodes
"""
logger.info("Trying to parse the dbt project `%s` using a custom Cosmos method...", self.project.name)
logger.info("Trying to parse the dbt project `%s` using a custom Cosmos method...", self.project.project_name)

if not self.project.dir:
if not self.project.dbt_project_path or not self.project.models_path or not self.project.seeds_path:
raise CosmosLoadDbtException("Unable to load dbt project without project files")

project = LegacyDbtProject(
dbt_root_path=str(self.project.root_dir),
dbt_models_dir=self.project.models_dir.stem if self.project.models_dir else None,
dbt_seeds_dir=self.project.seeds_dir.stem if self.project.seeds_dir else None,
project_name=self.project.name,
project_name=self.project.dbt_project_path.stem,
dbt_root_path=self.project.dbt_project_path.parent.as_posix(),
dbt_models_dir=self.project.models_path.stem,
dbt_seeds_dir=self.project.seeds_path.stem,
operator_args=self.operator_args,
)
nodes = {}
Expand All @@ -317,7 +328,7 @@ def load_via_custom_parser(self) -> None:

self.nodes = nodes
self.filtered_nodes = select_nodes(
project_dir=self.project.dir, nodes=nodes, select=self.select, exclude=self.exclude
project_dir=self.project.dbt_project_path, nodes=nodes, select=self.select, exclude=self.exclude
)

self.update_node_dependency()
Expand All @@ -339,7 +350,7 @@ def load_from_dbt_manifest(self) -> None:
* self.nodes
* self.filtered_nodes
"""
logger.info("Trying to parse the dbt project `%s` using a dbt manifest...", self.project.name)
logger.info("Trying to parse the dbt project `%s` using a dbt manifest...", self.project.project_name)

if not self.project.is_manifest_available():
raise CosmosLoadDbtException(f"Unable to load manifest using {self.project.manifest_path}")
Expand All @@ -355,7 +366,9 @@ def load_from_dbt_manifest(self) -> None:
unique_id=unique_id,
resource_type=DbtResourceType(node_dict["resource_type"]),
depends_on=node_dict.get("depends_on", {}).get("nodes", []),
file_path=self.project.dir / node_dict["original_file_path"],
file_path=self.project.dbt_project_path / Path(node_dict["original_file_path"])
if self.project.dbt_project_path
else Path(node_dict["original_file_path"]),
tags=node_dict["tags"],
config=node_dict["config"],
)
Expand All @@ -364,7 +377,7 @@ def load_from_dbt_manifest(self) -> None:

self.nodes = nodes
self.filtered_nodes = select_nodes(
project_dir=self.project.dir, nodes=nodes, select=self.select, exclude=self.exclude
project_dir=self.project.dbt_project_path, nodes=nodes, select=self.select, exclude=self.exclude
)

self.update_node_dependency()
Expand Down
2 changes: 1 addition & 1 deletion cosmos/dbt/parser/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def __repr__(self) -> str:


@dataclass
class DbtProject:
class LegacyDbtProject:
"""
Represents a single dbt project.
"""
Expand Down
Loading

0 comments on commit 9e11bea

Please sign in to comment.