diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py index 6d276013d..c46b95c59 100644 --- a/cosmos/operators/base.py +++ b/cosmos/operators/base.py @@ -2,6 +2,7 @@ import os from typing import Any, Sequence, Tuple +from abc import ABCMeta, abstractmethod import yaml from airflow.models.baseoperator import BaseOperator @@ -15,14 +16,13 @@ logger = get_logger(__name__) -class DbtBaseOperator(BaseOperator): +class AbstractDbtBaseOperator(BaseOperator, metaclass=ABCMeta): """ Executes a dbt core cli command. :param project_dir: Which directory to look in for the dbt_project.yml file. Default is the current working directory and its parents. :param conn_id: The airflow connection to use as the target - :param base_cmd: dbt sub-command to run (i.e ls, seed, run, test, etc.) :param select: dbt optional argument that specifies which nodes to include. :param exclude: dbt optional argument that specifies which models to exclude. :param selector: dbt optional argument - the selector name to use, as defined in selectors.yml @@ -78,11 +78,15 @@ class DbtBaseOperator(BaseOperator): intercept_flag = True + @property + @abstractmethod + def base_cmd(self) -> list[str]: + """Override this property to set the dbt sub-command (i.e ls, seed, run, test, etc.) 
for the operator""" + def __init__( self, project_dir: str, conn_id: str | None = None, - base_cmd: list[str] | None = None, select: str | None = None, exclude: str | None = None, selector: str | None = None, @@ -109,7 +113,6 @@ def __init__( ) -> None: self.project_dir = project_dir self.conn_id = conn_id - self.base_cmd = base_cmd self.select = select self.exclude = exclude self.selector = selector @@ -203,6 +206,10 @@ def add_global_flags(self) -> list[str]: flags.append(f"--{global_boolean_flag.replace('_', '-')}") return flags + def add_cmd_flags(self) -> list[str]: + """Allows subclasses to override to add flags for their dbt command""" + return [] + def build_cmd( self, context: Context, @@ -212,8 +219,7 @@ def build_cmd( dbt_cmd.extend(self.dbt_cmd_global_flags) - if self.base_cmd: - dbt_cmd.extend(self.base_cmd) + dbt_cmd.extend(self.base_cmd) if self.indirect_selection: dbt_cmd += ["--indirect-selection", self.indirect_selection] @@ -231,3 +237,104 @@ def build_cmd( env = self.get_env(context) return dbt_cmd, env + + +class DbtLSMixin: + """ + Executes a dbt core ls command. + """ + + base_cmd = ["ls"] + ui_color = "#DBCDF6" + + +class DbtSeedMixin: + """ + Mixin for dbt seed operation command. + + :param full_refresh: whether to add the flag --full-refresh to the dbt seed command + """ + + base_cmd = ["seed"] + ui_color = "#F58D7E" + + template_fields: Sequence[str] = ("full_refresh",) + + def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None: + self.full_refresh = full_refresh + super().__init__(**kwargs) + + def add_cmd_flags(self) -> list[str]: + flags = [] + if self.full_refresh is True: + flags.append("--full-refresh") + + return flags + + +class DbtSnapshotMixin: + """Mixin for a dbt snapshot command.""" + + base_cmd = ["snapshot"] + ui_color = "#964B00" + + +class DbtRunMixin: + """ + Mixin for dbt run command. 
+ + :param full_refresh: whether to add the flag --full-refresh to the dbt run command + """ + + base_cmd = ["run"] + ui_color = "#7352BA" + ui_fgcolor = "#F4F2FC" + + template_fields: Sequence[str] = ("full_refresh",) + + def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None: + self.full_refresh = full_refresh + super().__init__(**kwargs) + + def add_cmd_flags(self) -> list[str]: + flags = [] + if self.full_refresh is True: + flags.append("--full-refresh") + + return flags + + +class DbtTestMixin: + """Mixin for dbt test command.""" + + base_cmd = ["test"] + ui_color = "#8194E0" + + +class DbtRunOperationMixin: + """ + Mixin for dbt run operation command. + + :param macro_name: name of macro to execute + :param args: Supply arguments to the macro. This dictionary will be mapped to the keyword arguments defined in the + selected macro. + """ + + ui_color = "#8194E0" + template_fields: Sequence[str] = ("args",) + + def __init__(self, macro_name: str, args: dict[str, Any] | None = None, **kwargs: Any) -> None: + self.macro_name = macro_name + self.args = args + super().__init__(**kwargs) + + @property + def base_cmd(self) -> list[str]: + return ["run-operation", self.macro_name] + + def add_cmd_flags(self) -> list[str]: + flags = [] + if self.args is not None: + flags.append("--args") + flags.append(yaml.dump(self.args)) + return flags diff --git a/cosmos/operators/docker.py b/cosmos/operators/docker.py index fb2e1c90c..dfe6b955d 100644 --- a/cosmos/operators/docker.py +++ b/cosmos/operators/docker.py @@ -2,11 +2,18 @@ from typing import Any, Callable, Sequence -import yaml from airflow.utils.context import Context from cosmos.log import get_logger -from cosmos.operators.base import DbtBaseOperator +from cosmos.operators.base import ( + AbstractDbtBaseOperator, + DbtRunMixin, + DbtSeedMixin, + DbtSnapshotMixin, + DbtTestMixin, + DbtLSMixin, + DbtRunOperationMixin, +) logger = get_logger(__name__) @@ -20,13 +27,15 @@ ) -class 
DbtDockerBaseOperator(DockerOperator, DbtBaseOperator): # type: ignore +class DbtDockerBaseOperator(DockerOperator, AbstractDbtBaseOperator): # type: ignore """ Executes a dbt core cli command in a Docker container. """ - template_fields: Sequence[str] = tuple(list(DbtBaseOperator.template_fields) + list(DockerOperator.template_fields)) + template_fields: Sequence[str] = tuple( + list(AbstractDbtBaseOperator.template_fields) + list(DockerOperator.template_fields) + ) intercept_flag = False @@ -57,85 +66,48 @@ def execute(self, context: Context) -> None: - self.build_and_run_cmd(context=context) + self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags()) -class DbtLSDockerOperator(DbtDockerBaseOperator): +class DbtLSDockerOperator(DbtLSMixin, DbtDockerBaseOperator): """ Executes a dbt core ls command. """ - ui_color = "#DBCDF6" - - def __init__(self, **kwargs: str) -> None: - super().__init__(**kwargs) - self.base_cmd = ["ls"] - -class DbtSeedDockerOperator(DbtDockerBaseOperator): +class DbtSeedDockerOperator(DbtSeedMixin, DbtDockerBaseOperator): """ Executes a dbt core seed command. :param full_refresh: dbt optional arg - dbt will treat incremental models as table models """ - ui_color = "#F58D7E" - - def __init__(self, full_refresh: bool = False, **kwargs: str) -> None: - self.full_refresh = full_refresh - super().__init__(**kwargs) - self.base_cmd = ["seed"] + template_fields: Sequence[str] = DbtDockerBaseOperator.template_fields + DbtSeedMixin.template_fields # type: ignore[operator] - def add_cmd_flags(self) -> list[str]: - flags = [] - if self.full_refresh is True: - flags.append("--full-refresh") - return flags - - def execute(self, context: Context) -> None: - cmd_flags = self.add_cmd_flags() - self.build_and_run_cmd(context=context, cmd_flags=cmd_flags) - - -class DbtSnapshotDockerOperator(DbtDockerBaseOperator): +class DbtSnapshotDockerOperator(DbtSnapshotMixin, DbtDockerBaseOperator): """ Executes a dbt core snapshot command. 
- """ - ui_color = "#964B00" - def __init__(self, **kwargs: str) -> None: - super().__init__(**kwargs) - self.base_cmd = ["snapshot"] - - -class DbtRunDockerOperator(DbtDockerBaseOperator): +class DbtRunDockerOperator(DbtRunMixin, DbtDockerBaseOperator): """ Executes a dbt core run command. """ - ui_color = "#7352BA" - ui_fgcolor = "#F4F2FC" - - def __init__(self, **kwargs: str) -> None: - super().__init__(**kwargs) - self.base_cmd = ["run"] + template_fields: Sequence[str] = DbtDockerBaseOperator.template_fields + DbtRunMixin.template_fields # type: ignore[operator] -class DbtTestDockerOperator(DbtDockerBaseOperator): +class DbtTestDockerOperator(DbtTestMixin, DbtDockerBaseOperator): """ Executes a dbt core test command. """ - ui_color = "#8194E0" - def __init__(self, on_warning_callback: Callable[..., Any] | None = None, **kwargs: str) -> None: super().__init__(**kwargs) - self.base_cmd = ["test"] # as of now, on_warning_callback in docker executor does nothing self.on_warning_callback = on_warning_callback -class DbtRunOperationDockerOperator(DbtDockerBaseOperator): +class DbtRunOperationDockerOperator(DbtRunOperationMixin, DbtDockerBaseOperator): """ Executes a dbt core run-operation command. @@ -144,22 +116,4 @@ class DbtRunOperationDockerOperator(DbtDockerBaseOperator): selected macro. 
""" - ui_color = "#8194E0" - template_fields: Sequence[str] = ("args",) - - def __init__(self, macro_name: str, args: dict[str, Any] | None = None, **kwargs: str) -> None: - self.macro_name = macro_name - self.args = args - super().__init__(**kwargs) - self.base_cmd = ["run-operation", macro_name] - - def add_cmd_flags(self) -> list[str]: - flags = [] - if self.args is not None: - flags.append("--args") - flags.append(yaml.dump(self.args)) - return flags - - def execute(self, context: Context) -> None: - cmd_flags = self.add_cmd_flags() - self.build_and_run_cmd(context=context, cmd_flags=cmd_flags) + template_fields: Sequence[str] = DbtDockerBaseOperator.template_fields + DbtRunOperationMixin.template_fields # type: ignore[operator] diff --git a/cosmos/operators/kubernetes.py b/cosmos/operators/kubernetes.py index b844716de..353d9c534 100644 --- a/cosmos/operators/kubernetes.py +++ b/cosmos/operators/kubernetes.py @@ -3,12 +3,19 @@ from os import PathLike from typing import Any, Callable, Sequence -import yaml from airflow.utils.context import Context, context_merge from cosmos.log import get_logger from cosmos.config import ProfileConfig -from cosmos.operators.base import DbtBaseOperator +from cosmos.operators.base import ( + AbstractDbtBaseOperator, + DbtRunMixin, + DbtSeedMixin, + DbtSnapshotMixin, + DbtTestMixin, + DbtLSMixin, + DbtRunOperationMixin, +) from airflow.models import TaskInstance from cosmos.dbt.parser.output import extract_log_issues @@ -37,14 +44,14 @@ ) -class DbtKubernetesBaseOperator(KubernetesPodOperator, DbtBaseOperator): # type: ignore +class DbtKubernetesBaseOperator(KubernetesPodOperator, AbstractDbtBaseOperator): # type: ignore """ Executes a dbt core cli command in a Kubernetes Pod. 
""" template_fields: Sequence[str] = tuple( - list(DbtBaseOperator.template_fields) + list(KubernetesPodOperator.template_fields) + list(AbstractDbtBaseOperator.template_fields) + list(KubernetesPodOperator.template_fields) ) intercept_flag = False @@ -94,77 +101,39 @@ def execute(self, context: Context) -> None: - self.build_and_run_cmd(context=context) + self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags()) -class DbtLSKubernetesOperator(DbtKubernetesBaseOperator): +class DbtLSKubernetesOperator(DbtLSMixin, DbtKubernetesBaseOperator): """ Executes a dbt core ls command. """ - ui_color = "#DBCDF6" - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.base_cmd = ["ls"] - - -class DbtSeedKubernetesOperator(DbtKubernetesBaseOperator): +class DbtSeedKubernetesOperator(DbtSeedMixin, DbtKubernetesBaseOperator): """ Executes a dbt core seed command. - - :param full_refresh: dbt optional arg - dbt will treat incremental models as table models """ - ui_color = "#F58D7E" - - def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None: - self.full_refresh = full_refresh - super().__init__(**kwargs) - self.base_cmd = ["seed"] - - def add_cmd_flags(self) -> list[str]: - flags = [] - if self.full_refresh is True: - flags.append("--full-refresh") - - return flags - - def execute(self, context: Context) -> None: - cmd_flags = self.add_cmd_flags() - self.build_and_run_cmd(context=context, cmd_flags=cmd_flags) + template_fields: Sequence[str] = DbtKubernetesBaseOperator.template_fields + DbtSeedMixin.template_fields # type: ignore[operator] -class DbtSnapshotKubernetesOperator(DbtKubernetesBaseOperator): +class DbtSnapshotKubernetesOperator(DbtSnapshotMixin, DbtKubernetesBaseOperator): """ Executes a dbt core snapshot command. 
- """ - ui_color = "#964B00" - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.base_cmd = ["snapshot"] - -class DbtRunKubernetesOperator(DbtKubernetesBaseOperator): +class DbtRunKubernetesOperator(DbtRunMixin, DbtKubernetesBaseOperator): """ Executes a dbt core run command. """ - ui_color = "#7352BA" - ui_fgcolor = "#F4F2FC" - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.base_cmd = ["run"] + template_fields: Sequence[str] = DbtKubernetesBaseOperator.template_fields + DbtRunMixin.template_fields # type: ignore[operator] -class DbtTestKubernetesOperator(DbtKubernetesBaseOperator): +class DbtTestKubernetesOperator(DbtTestMixin, DbtKubernetesBaseOperator): """ Executes a dbt core test command. """ - ui_color = "#8194E0" - def __init__(self, on_warning_callback: Callable[..., Any] | None = None, **kwargs: Any) -> None: if not on_warning_callback: super().__init__(**kwargs) @@ -203,8 +172,6 @@ def __init__(self, on_warning_callback: Callable[..., Any] | None = None, **kwar super().__init__(**kwargs) - self.base_cmd = ["test"] - def _handle_warnings(self, context: Context) -> None: """ Handles warnings by extracting log issues, creating additional context, and calling the @@ -258,31 +225,9 @@ def _cleanup_pod(self, context: Context) -> None: task.cleanup(pod=task.pod, remote_pod=task.remote_pod) -class DbtRunOperationKubernetesOperator(DbtKubernetesBaseOperator): +class DbtRunOperationKubernetesOperator(DbtRunOperationMixin, DbtKubernetesBaseOperator): """ Executes a dbt core run-operation command. - - :param macro_name: name of macro to execute - :param args: Supply arguments to the macro. This dictionary will be mapped to the keyword arguments defined in the - selected macro. 
""" - ui_color = "#8194E0" - template_fields: Sequence[str] = ("args",) - - def __init__(self, macro_name: str, args: dict[str, Any] | None = None, **kwargs: Any) -> None: - self.macro_name = macro_name - self.args = args - super().__init__(**kwargs) - self.base_cmd = ["run-operation", macro_name] - - def add_cmd_flags(self) -> list[str]: - flags = [] - if self.args is not None: - flags.append("--args") - flags.append(yaml.dump(self.args)) - return flags - - def execute(self, context: Context) -> None: - cmd_flags = self.add_cmd_flags() - self.build_and_run_cmd(context=context, cmd_flags=cmd_flags) + template_fields: Sequence[str] = DbtKubernetesBaseOperator.template_fields + DbtRunOperationMixin.template_fields # type: ignore[operator] diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index b0b572430..cc0e1f30b 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -11,7 +11,6 @@ import airflow import jinja2 -import yaml from airflow import DAG from airflow.compat.functools import cached_property from airflow.configuration import conf @@ -38,7 +37,15 @@ from cosmos.constants import DEFAULT_OPENLINEAGE_NAMESPACE, OPENLINEAGE_PRODUCER from cosmos.config import ProfileConfig from cosmos.log import get_logger -from cosmos.operators.base import DbtBaseOperator +from cosmos.operators.base import ( + AbstractDbtBaseOperator, + DbtRunMixin, + DbtSeedMixin, + DbtSnapshotMixin, + DbtTestMixin, + DbtLSMixin, + DbtRunOperationMixin, +) from cosmos.hooks.subprocess import ( FullOutputSubprocessHook, FullOutputSubprocessResult, @@ -80,7 +87,7 @@ class OperatorLineage: # type: ignore LINEAGE_NAMESPACE = os.getenv("OPENLINEAGE_NAMESPACE", DEFAULT_OPENLINEAGE_NAMESPACE) -class DbtLocalBaseOperator(DbtBaseOperator): +class DbtLocalBaseOperator(AbstractDbtBaseOperator): """ Executes a dbt core cli command locally. 
@@ -96,7 +103,7 @@ class DbtLocalBaseOperator(DbtBaseOperator): :param should_store_compiled_sql: If true, store the compiled SQL in the compiled_sql rendered template. """ - template_fields: Sequence[str] = DbtBaseOperator.template_fields + ("compiled_sql",) # type: ignore[operator] + template_fields: Sequence[str] = AbstractDbtBaseOperator.template_fields + ("compiled_sql",) # type: ignore[operator] template_fields_renderers = { "compiled_sql": "sql", } @@ -366,7 +373,7 @@ def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None return result def execute(self, context: Context) -> None: - self.build_and_run_cmd(context=context) + self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags()) def on_kill(self) -> None: if self.cancel_query_on_kill: @@ -377,100 +384,47 @@ def on_kill(self) -> None: self.subprocess_hook.send_sigterm() -class DbtLSLocalOperator(DbtLocalBaseOperator): +class DbtLSLocalOperator(DbtLSMixin, DbtLocalBaseOperator): """ Executes a dbt core ls command. """ - ui_color = "#DBCDF6" - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.base_cmd = ["ls"] - -class DbtSeedLocalOperator(DbtLocalBaseOperator): +class DbtSeedLocalOperator(DbtSeedMixin, DbtLocalBaseOperator): """ Executes a dbt core seed command. 
- - :param full_refresh: dbt optional arg - dbt will treat incremental models as table models """ - ui_color = "#F58D7E" - - template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields + ("full_refresh",) # type: ignore[operator] - - def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None: - self.full_refresh = full_refresh - super().__init__(**kwargs) - self.base_cmd = ["seed"] + template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields + DbtSeedMixin.template_fields # type: ignore[operator] - def add_cmd_flags(self) -> list[str]: - flags = [] - if self.full_refresh is True: - flags.append("--full-refresh") - return flags - - def execute(self, context: Context) -> None: - cmd_flags = self.add_cmd_flags() - self.build_and_run_cmd(context=context, cmd_flags=cmd_flags) - - -class DbtSnapshotLocalOperator(DbtLocalBaseOperator): +class DbtSnapshotLocalOperator(DbtSnapshotMixin, DbtLocalBaseOperator): """ Executes a dbt core snapshot command. - """ - ui_color = "#964B00" - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.base_cmd = ["snapshot"] - -class DbtRunLocalOperator(DbtLocalBaseOperator): +class DbtRunLocalOperator(DbtRunMixin, DbtLocalBaseOperator): """ Executes a dbt core run command. 
""" - ui_color = "#7352BA" - ui_fgcolor = "#F4F2FC" - template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields + ("full_refresh",) # type: ignore[operator] - - def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None: - self.full_refresh = full_refresh - super().__init__(**kwargs) - self.base_cmd = ["run"] - - def add_cmd_flags(self) -> list[str]: - flags = [] - if self.full_refresh is True: - flags.append("--full-refresh") - return flags - - def execute(self, context: Context) -> None: - cmd_flags = self.add_cmd_flags() - self.build_and_run_cmd(context=context, cmd_flags=cmd_flags) + template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields + DbtRunMixin.template_fields # type: ignore[operator] -class DbtTestLocalOperator(DbtLocalBaseOperator): +class DbtTestLocalOperator(DbtTestMixin, DbtLocalBaseOperator): """ Executes a dbt core test command. :param on_warning_callback: A callback function called on warnings with additional Context variables "test_names" and "test_results" of type `List`. Each index in "test_names" corresponds to the same index in "test_results". """ - ui_color = "#8194E0" - def __init__( self, on_warning_callback: Callable[..., Any] | None = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) - self.base_cmd = ["test"] self.on_warning_callback = on_warning_callback def _handle_warnings(self, result: FullOutputSubprocessResult, context: Context) -> None: @@ -504,7 +458,7 @@ def execute(self, context: Context) -> None: self._handle_warnings(result, context) -class DbtRunOperationLocalOperator(DbtLocalBaseOperator): +class DbtRunOperationLocalOperator(DbtRunOperationMixin, DbtLocalBaseOperator): """ Executes a dbt core run-operation command. @@ -513,25 +467,7 @@ class DbtRunOperationLocalOperator(DbtLocalBaseOperator): selected macro. 
""" - ui_color = "#8194E0" - template_fields: Sequence[str] = ("args",) - - def __init__(self, macro_name: str, args: dict[str, Any] | None = None, **kwargs: Any) -> None: - self.macro_name = macro_name - self.args = args - super().__init__(**kwargs) - self.base_cmd = ["run-operation", macro_name] - - def add_cmd_flags(self) -> list[str]: - flags = [] - if self.args is not None: - flags.append("--args") - flags.append(yaml.dump(self.args)) - return flags - - def execute(self, context: Context) -> None: - cmd_flags = self.add_cmd_flags() - self.build_and_run_cmd(context=context, cmd_flags=cmd_flags) + template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields + DbtRunOperationMixin.template_fields # type: ignore[operator] class DbtDocsLocalOperator(DbtLocalBaseOperator): @@ -541,13 +477,11 @@ class DbtDocsLocalOperator(DbtLocalBaseOperator): """ ui_color = "#8194E0" - required_files = ["index.html", "manifest.json", "graph.gpickle", "catalog.json"] + base_cmd = ["docs", "generate"] def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - self.base_cmd = ["docs", "generate"] - self.check_static_flag() def check_static_flag(self) -> None: diff --git a/cosmos/operators/virtualenv.py b/cosmos/operators/virtualenv.py index 4d6338e09..16033ea20 100644 --- a/cosmos/operators/virtualenv.py +++ b/cosmos/operators/virtualenv.py @@ -105,7 +105,7 @@ class DbtLSVirtualenvOperator(DbtVirtualenvBaseOperator, DbtLSLocalOperator): """ -class DbtSeedVirtualenvOperator(DbtVirtualenvBaseOperator, DbtSeedLocalOperator): +class DbtSeedVirtualenvOperator(DbtVirtualenvBaseOperator, DbtSeedLocalOperator): # type: ignore[misc] """ Executes a dbt core seed command within a Python Virtual Environment, that is created before running the dbt command and deleted just after. 
@@ -119,7 +119,7 @@ class DbtSnapshotVirtualenvOperator(DbtVirtualenvBaseOperator, DbtSnapshotLocalO """ -class DbtRunVirtualenvOperator(DbtVirtualenvBaseOperator, DbtRunLocalOperator): +class DbtRunVirtualenvOperator(DbtVirtualenvBaseOperator, DbtRunLocalOperator): # type: ignore[misc] """ Executes a dbt core run command within a Python Virtual Environment, that is created before running the dbt command and deleted just after. @@ -133,7 +133,7 @@ class DbtTestVirtualenvOperator(DbtVirtualenvBaseOperator, DbtTestLocalOperator) """ -class DbtRunOperationVirtualenvOperator(DbtVirtualenvBaseOperator, DbtRunOperationLocalOperator): +class DbtRunOperationVirtualenvOperator(DbtVirtualenvBaseOperator, DbtRunOperationLocalOperator): # type: ignore[misc] """ Executes a dbt core run-operation command within a Python Virtual Environment, that is created before running the dbt command and deleted just after. diff --git a/tests/operators/test_base.py b/tests/operators/test_base.py new file mode 100644 index 000000000..a46b51f7f --- /dev/null +++ b/tests/operators/test_base.py @@ -0,0 +1,50 @@ +import pytest + +from cosmos.operators.base import ( + AbstractDbtBaseOperator, + DbtLSMixin, + DbtSeedMixin, + DbtRunOperationMixin, + DbtTestMixin, + DbtSnapshotMixin, + DbtRunMixin, +) + + +def test_dbt_base_operator_is_abstract(): + """Tests that the abstract base operator cannot be instantiated since the base_cmd is not defined.""" + expected_error = "Can't instantiate abstract class AbstractDbtBaseOperator with abstract methods? 
base_cmd" + with pytest.raises(TypeError, match=expected_error): + AbstractDbtBaseOperator() + + +@pytest.mark.parametrize( + "dbt_command, dbt_operator_class", + [ + ("test", DbtTestMixin), + ("snapshot", DbtSnapshotMixin), + ("ls", DbtLSMixin), + ("seed", DbtSeedMixin), + ("run", DbtRunMixin), + ], +) +def test_dbt_mixin_base_cmd(dbt_command, dbt_operator_class): + assert [dbt_command] == dbt_operator_class.base_cmd + + +@pytest.mark.parametrize("dbt_operator_class", [DbtSeedMixin, DbtRunMixin]) +@pytest.mark.parametrize("full_refresh, expected_flags", [(True, ["--full-refresh"]), (False, [])]) +def test_dbt_mixin_add_cmd_flags_full_refresh(full_refresh, expected_flags, dbt_operator_class): + dbt_mixin = dbt_operator_class(full_refresh=full_refresh) + flags = dbt_mixin.add_cmd_flags() + assert flags == expected_flags + + +@pytest.mark.parametrize("args, expected_flags", [(None, []), ({"arg1": "val1"}, ["--args", "arg1: val1\n"])]) +def test_dbt_mixin_add_cmd_flags_run_operator(args, expected_flags): + macro_name = "some_macro" + run_operation = DbtRunOperationMixin(macro_name=macro_name, args=args) + assert run_operation.base_cmd == ["run-operation", "some_macro"] + + flags = run_operation.add_cmd_flags() + assert flags == expected_flags diff --git a/tests/operators/test_docker.py b/tests/operators/test_docker.py index 234878e07..520511b23 100644 --- a/tests/operators/test_docker.py +++ b/tests/operators/test_docker.py @@ -13,8 +13,12 @@ ) +class ConcreteDbtDockerBaseOperator(DbtDockerBaseOperator): + base_cmd = ["cmd"] + + def test_dbt_docker_operator_add_global_flags() -> None: - dbt_base_operator = DbtDockerBaseOperator( + dbt_base_operator = ConcreteDbtDockerBaseOperator( conn_id="my_airflow_connection", task_id="my-task", image="my_image", @@ -38,7 +42,7 @@ def test_dbt_docker_operator_get_env(p_context_to_airflow_vars: MagicMock) -> No """ If an end user passes in a """ - dbt_base_operator = DbtDockerBaseOperator( + dbt_base_operator = 
ConcreteDbtDockerBaseOperator( conn_id="my_airflow_connection", task_id="my-task", image="my_image", diff --git a/tests/operators/test_kubernetes.py b/tests/operators/test_kubernetes.py index 585b1ab32..638cff140 100644 --- a/tests/operators/test_kubernetes.py +++ b/tests/operators/test_kubernetes.py @@ -23,8 +23,12 @@ module_available = False +class ConcreteDbtKubernetesBaseOperator(DbtKubernetesBaseOperator): + base_cmd = ["cmd"] + + def test_dbt_kubernetes_operator_add_global_flags() -> None: - dbt_kube_operator = DbtKubernetesBaseOperator( + dbt_kube_operator = ConcreteDbtKubernetesBaseOperator( conn_id="my_airflow_connection", task_id="my-task", image="my_image", @@ -48,7 +52,7 @@ def test_dbt_kubernetes_operator_get_env(p_context_to_airflow_vars: MagicMock) - """ If an end user passes in a """ - dbt_kube_operator = DbtKubernetesBaseOperator( + dbt_kube_operator = ConcreteDbtKubernetesBaseOperator( conn_id="my_airflow_connection", task_id="my-task", image="my_image", diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index dd7d34a6d..aa4cb741f 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -68,8 +68,12 @@ def failing_test_dbt_project(tmp_path): tmp_dir.cleanup() +class ConcreteDbtLocalBaseOperator(DbtLocalBaseOperator): + base_cmd = ["cmd"] + + def test_dbt_base_operator_add_global_flags() -> None: - dbt_base_operator = DbtLocalBaseOperator( + dbt_base_operator = ConcreteDbtLocalBaseOperator( profile_config=profile_config, task_id="my-task", project_dir="my/dir", @@ -88,27 +92,25 @@ def test_dbt_base_operator_add_global_flags() -> None: def test_dbt_base_operator_add_user_supplied_flags() -> None: - dbt_base_operator = DbtLocalBaseOperator( + dbt_base_operator = ConcreteDbtLocalBaseOperator( profile_config=profile_config, task_id="my-task", project_dir="my/dir", - base_cmd=["run"], dbt_cmd_flags=["--full-refresh"], ) cmd, _ = dbt_base_operator.build_cmd( Context(execution_date=datetime(2023, 2, 15, 12, 
30)), ) - assert cmd[-2] == "run" + assert cmd[-2] == "cmd" assert cmd[-1] == "--full-refresh" def test_dbt_base_operator_add_user_supplied_global_flags() -> None: - dbt_base_operator = DbtLocalBaseOperator( + dbt_base_operator = ConcreteDbtLocalBaseOperator( profile_config=profile_config, task_id="my-task", project_dir="my/dir", - base_cmd=["run"], dbt_cmd_global_flags=["--cache-selected-only"], ) @@ -116,7 +118,7 @@ def test_dbt_base_operator_add_user_supplied_global_flags() -> None: Context(execution_date=datetime(2023, 2, 15, 12, 30)), ) assert cmd[-2] == "--cache-selected-only" - assert cmd[-1] == "run" + assert cmd[-1] == "cmd" @pytest.mark.parametrize( @@ -124,11 +126,10 @@ def test_dbt_base_operator_add_user_supplied_global_flags() -> None: [None, "cautious", "buildable", "empty"], ) def test_dbt_base_operator_use_indirect_selection(indirect_selection_type) -> None: - dbt_base_operator = DbtLocalBaseOperator( + dbt_base_operator = ConcreteDbtLocalBaseOperator( profile_config=profile_config, task_id="my-task", project_dir="my/dir", - base_cmd=["run"], indirect_selection=indirect_selection_type, ) @@ -140,7 +141,7 @@ def test_dbt_base_operator_use_indirect_selection(indirect_selection_type) -> No assert cmd[-1] == indirect_selection_type else: assert cmd[0].endswith("dbt") - assert cmd[1] == "run" + assert cmd[1] == "cmd" @pytest.mark.parametrize( @@ -157,7 +158,7 @@ def test_dbt_base_operator_use_indirect_selection(indirect_selection_type) -> No ], ) def test_dbt_base_operator_exception_handling(skip_exception, exception_code_returned, expected_exception) -> None: - dbt_base_operator = DbtLocalBaseOperator( + dbt_base_operator = ConcreteDbtLocalBaseOperator( profile_config=profile_config, task_id="my-task", project_dir="my/dir", @@ -174,7 +175,7 @@ def test_dbt_base_operator_get_env(p_context_to_airflow_vars: MagicMock) -> None """ If an end user passes in a """ - dbt_base_operator = DbtLocalBaseOperator( + dbt_base_operator = ConcreteDbtLocalBaseOperator( 
profile_config=profile_config, task_id="my-task", project_dir="my/dir", @@ -292,7 +293,7 @@ class MockEvent: run = MockRun() job = MockJob() - dbt_base_operator = DbtLocalBaseOperator( + dbt_base_operator = ConcreteDbtLocalBaseOperator( profile_config=profile_config, task_id="my-task", project_dir="my/dir", @@ -307,7 +308,7 @@ class MockEvent: def test_run_operator_emits_events_without_openlineage_events_completes(caplog): - dbt_base_operator = DbtLocalBaseOperator( + dbt_base_operator = ConcreteDbtLocalBaseOperator( profile_config=profile_config, task_id="my-task", project_dir="my/dir", @@ -324,7 +325,7 @@ def test_run_operator_emits_events_without_openlineage_events_completes(caplog): def test_store_compiled_sql() -> None: - dbt_base_operator = DbtLocalBaseOperator( + dbt_base_operator = ConcreteDbtLocalBaseOperator( profile_config=profile_config, task_id="my-task", project_dir="my/dir", @@ -337,7 +338,7 @@ def test_store_compiled_sql() -> None: context=Context(execution_date=datetime(2023, 2, 15, 12, 30)), ) - dbt_base_operator = DbtLocalBaseOperator( + dbt_base_operator = ConcreteDbtLocalBaseOperator( profile_config=profile_config, task_id="my-task", project_dir="my/dir", @@ -399,7 +400,10 @@ def test_operator_execute_without_flags(mock_build_and_run_cmd, operator_class): **operator_class_kwargs.get(operator_class, {}), ) task.execute(context={}) - mock_build_and_run_cmd.assert_called_once_with(context={}) + if operator_class == DbtTestLocalOperator: + mock_build_and_run_cmd.assert_called_once_with(context={}) + else: + mock_build_and_run_cmd.assert_called_once_with(context={}, cmd_flags=[]) @patch("cosmos.operators.local.DbtLocalArtifactProcessor") @@ -407,7 +411,7 @@ def test_calculate_openlineage_events_completes_openlineage_errors(mock_processo instance = mock_processor.return_value instance.parse = MagicMock(side_effect=KeyError) caplog.set_level(logging.DEBUG) - dbt_base_operator = DbtLocalBaseOperator( + dbt_base_operator = ConcreteDbtLocalBaseOperator( 
profile_config=profile_config, task_id="my-task", project_dir=DBT_PROJ_DIR, diff --git a/tests/operators/test_virtualenv.py b/tests/operators/test_virtualenv.py index 13dba8f94..86796308b 100644 --- a/tests/operators/test_virtualenv.py +++ b/tests/operators/test_virtualenv.py @@ -18,6 +18,10 @@ ) +class ConcreteDbtVirtualenvBaseOperator(DbtVirtualenvBaseOperator): + base_cmd = ["cmd"] + + @patch("airflow.utils.python_virtualenv.execute_in_subprocess") @patch("cosmos.operators.virtualenv.DbtLocalBaseOperator.calculate_openlineage_events_completes") @patch("cosmos.operators.virtualenv.DbtLocalBaseOperator.store_compiled_sql") @@ -41,7 +45,7 @@ def test_run_command( password="fake_password", schema="fake_schema", ) - venv_operator = DbtVirtualenvBaseOperator( + venv_operator = ConcreteDbtVirtualenvBaseOperator( profile_config=profile_config, task_id="fake_task", install_deps=True,