Skip to content

Commit

Permalink
Refactor common executor constructors with test coverage (#774)
Browse files Browse the repository at this point in the history
I noticed in #771 that there was a lot of repeated class constructors in
order to add a new execution mode that is common among `local`, `docker`
and `kubernetes` and there is no test coverage for the constructors and
methods in some of the operators.

This PR attempts to make it easier to add new execution operators in the
future.

## Breaking Change?
None

There may be task UI color differences with the kuberentes/docker
operators, since now all of LS/Seed/Run etc. operators across execution
modes have the same task colors.
  • Loading branch information
jbandoro authored Jan 8, 2024
1 parent 9924d5d commit 380ac52
Show file tree
Hide file tree
Showing 10 changed files with 269 additions and 263 deletions.
119 changes: 113 additions & 6 deletions cosmos/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
from typing import Any, Sequence, Tuple
from abc import ABCMeta, abstractmethod

import yaml
from airflow.models.baseoperator import BaseOperator
Expand All @@ -15,14 +16,13 @@
logger = get_logger(__name__)


class DbtBaseOperator(BaseOperator):
class AbstractDbtBaseOperator(BaseOperator, metaclass=ABCMeta):
"""
Executes a dbt core cli command.
:param project_dir: Which directory to look in for the dbt_project.yml file. Default is the current working
directory and its parents.
:param conn_id: The airflow connection to use as the target
:param base_cmd: dbt sub-command to run (i.e ls, seed, run, test, etc.)
:param select: dbt optional argument that specifies which nodes to include.
:param exclude: dbt optional argument that specifies which models to exclude.
:param selector: dbt optional argument - the selector name to use, as defined in selectors.yml
Expand Down Expand Up @@ -78,11 +78,15 @@ class DbtBaseOperator(BaseOperator):

intercept_flag = True

@property
@abstractmethod
def base_cmd(self) -> list[str]:
"""Override this property to set the dbt sub-command (i.e ls, seed, run, test, etc.) for the operator"""

def __init__(
self,
project_dir: str,
conn_id: str | None = None,
base_cmd: list[str] | None = None,
select: str | None = None,
exclude: str | None = None,
selector: str | None = None,
Expand All @@ -109,7 +113,6 @@ def __init__(
) -> None:
self.project_dir = project_dir
self.conn_id = conn_id
self.base_cmd = base_cmd
self.select = select
self.exclude = exclude
self.selector = selector
Expand Down Expand Up @@ -203,6 +206,10 @@ def add_global_flags(self) -> list[str]:
flags.append(f"--{global_boolean_flag.replace('_', '-')}")
return flags

def add_cmd_flags(self) -> list[str]:
"""Allows subclasses to override to add flags for their dbt command"""
return []

def build_cmd(
self,
context: Context,
Expand All @@ -212,8 +219,7 @@ def build_cmd(

dbt_cmd.extend(self.dbt_cmd_global_flags)

if self.base_cmd:
dbt_cmd.extend(self.base_cmd)
dbt_cmd.extend(self.base_cmd)

if self.indirect_selection:
dbt_cmd += ["--indirect-selection", self.indirect_selection]
Expand All @@ -231,3 +237,104 @@ def build_cmd(
env = self.get_env(context)

return dbt_cmd, env


class DbtLSMixin:
"""
Executes a dbt core ls command.
"""

base_cmd = ["ls"]
ui_color = "#DBCDF6"


class DbtSeedMixin:
"""
Mixin for dbt seed operation command.
:param full_refresh: whether to add the flag --full-refresh to the dbt seed command
"""

base_cmd = ["seed"]
ui_color = "#F58D7E"

template_fields: Sequence[str] = ("full_refresh",)

def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None:
self.full_refresh = full_refresh
super().__init__(**kwargs)

def add_cmd_flags(self) -> list[str]:
flags = []
if self.full_refresh is True:
flags.append("--full-refresh")

return flags


class DbtSnapshotMixin:
"""Mixin for a dbt snapshot command."""

base_cmd = ["snapshot"]
ui_color = "#964B00"


class DbtRunMixin:
"""
Mixin for dbt run command.
:param full_refresh: whether to add the flag --full-refresh to the dbt seed command
"""

base_cmd = ["run"]
ui_color = "#7352BA"
ui_fgcolor = "#F4F2FC"

template_fields: Sequence[str] = ("full_refresh",)

def __init__(self, full_refresh: bool = False, **kwargs: Any) -> None:
self.full_refresh = full_refresh
super().__init__(**kwargs)

def add_cmd_flags(self) -> list[str]:
flags = []
if self.full_refresh is True:
flags.append("--full-refresh")

return flags


class DbtTestMixin:
"""Mixin for dbt test command."""

base_cmd = ["test"]
ui_color = "#8194E0"


class DbtRunOperationMixin:
"""
Mixin for dbt run operation command.
:param macro_name: name of macro to execute
:param args: Supply arguments to the macro. This dictionary will be mapped to the keyword arguments defined in the
selected macro.
"""

ui_color = "#8194E0"
template_fields: Sequence[str] = ("args",)

def __init__(self, macro_name: str, args: dict[str, Any] | None = None, **kwargs: Any) -> None:
self.macro_name = macro_name
self.args = args
super().__init__(**kwargs)

@property
def base_cmd(self) -> list[str]:
return ["run-operation", self.macro_name]

def add_cmd_flags(self) -> list[str]:
flags = []
if self.args is not None:
flags.append("--args")
flags.append(yaml.dump(self.args))
return flags
90 changes: 22 additions & 68 deletions cosmos/operators/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,18 @@

from typing import Any, Callable, Sequence

import yaml
from airflow.utils.context import Context

from cosmos.log import get_logger
from cosmos.operators.base import DbtBaseOperator
from cosmos.operators.base import (
AbstractDbtBaseOperator,
DbtRunMixin,
DbtSeedMixin,
DbtSnapshotMixin,
DbtTestMixin,
DbtLSMixin,
DbtRunOperationMixin,
)

logger = get_logger(__name__)

Expand All @@ -20,13 +27,15 @@
)


class DbtDockerBaseOperator(DockerOperator, DbtBaseOperator): # type: ignore
class DbtDockerBaseOperator(DockerOperator, AbstractDbtBaseOperator): # type: ignore
"""
Executes a dbt core cli command in a Docker container.
"""

template_fields: Sequence[str] = tuple(list(DbtBaseOperator.template_fields) + list(DockerOperator.template_fields))
template_fields: Sequence[str] = tuple(
list(AbstractDbtBaseOperator.template_fields) + list(DockerOperator.template_fields)
)

intercept_flag = False

Expand Down Expand Up @@ -57,85 +66,48 @@ def execute(self, context: Context) -> None:
self.build_and_run_cmd(context=context)


class DbtLSDockerOperator(DbtDockerBaseOperator):
class DbtLSDockerOperator(DbtLSMixin, DbtDockerBaseOperator):
"""
Executes a dbt core ls command.
"""

ui_color = "#DBCDF6"

def __init__(self, **kwargs: str) -> None:
super().__init__(**kwargs)
self.base_cmd = ["ls"]


class DbtSeedDockerOperator(DbtDockerBaseOperator):
class DbtSeedDockerOperator(DbtSeedMixin, DbtDockerBaseOperator):
"""
Executes a dbt core seed command.
:param full_refresh: dbt optional arg - dbt will treat incremental models as table models
"""

ui_color = "#F58D7E"

def __init__(self, full_refresh: bool = False, **kwargs: str) -> None:
self.full_refresh = full_refresh
super().__init__(**kwargs)
self.base_cmd = ["seed"]
template_fields: Sequence[str] = DbtDockerBaseOperator.template_fields + DbtSeedMixin.template_fields # type: ignore[operator]

def add_cmd_flags(self) -> list[str]:
flags = []
if self.full_refresh is True:
flags.append("--full-refresh")

return flags

def execute(self, context: Context) -> None:
cmd_flags = self.add_cmd_flags()
self.build_and_run_cmd(context=context, cmd_flags=cmd_flags)


class DbtSnapshotDockerOperator(DbtDockerBaseOperator):
class DbtSnapshotDockerOperator(DbtSnapshotMixin, DbtDockerBaseOperator):
"""
Executes a dbt core snapshot command.
"""

ui_color = "#964B00"

def __init__(self, **kwargs: str) -> None:
super().__init__(**kwargs)
self.base_cmd = ["snapshot"]


class DbtRunDockerOperator(DbtDockerBaseOperator):
class DbtRunDockerOperator(DbtRunMixin, DbtDockerBaseOperator):
"""
Executes a dbt core run command.
"""

ui_color = "#7352BA"
ui_fgcolor = "#F4F2FC"

def __init__(self, **kwargs: str) -> None:
super().__init__(**kwargs)
self.base_cmd = ["run"]
template_fields: Sequence[str] = DbtDockerBaseOperator.template_fields + DbtRunMixin.template_fields # type: ignore[operator]


class DbtTestDockerOperator(DbtDockerBaseOperator):
class DbtTestDockerOperator(DbtTestMixin, DbtDockerBaseOperator):
"""
Executes a dbt core test command.
"""

ui_color = "#8194E0"

def __init__(self, on_warning_callback: Callable[..., Any] | None = None, **kwargs: str) -> None:
super().__init__(**kwargs)
self.base_cmd = ["test"]
# as of now, on_warning_callback in docker executor does nothing
self.on_warning_callback = on_warning_callback


class DbtRunOperationDockerOperator(DbtDockerBaseOperator):
class DbtRunOperationDockerOperator(DbtRunOperationMixin, DbtDockerBaseOperator):
"""
Executes a dbt core run-operation command.
Expand All @@ -144,22 +116,4 @@ class DbtRunOperationDockerOperator(DbtDockerBaseOperator):
selected macro.
"""

ui_color = "#8194E0"
template_fields: Sequence[str] = ("args",)

def __init__(self, macro_name: str, args: dict[str, Any] | None = None, **kwargs: str) -> None:
self.macro_name = macro_name
self.args = args
super().__init__(**kwargs)
self.base_cmd = ["run-operation", macro_name]

def add_cmd_flags(self) -> list[str]:
flags = []
if self.args is not None:
flags.append("--args")
flags.append(yaml.dump(self.args))
return flags

def execute(self, context: Context) -> None:
cmd_flags = self.add_cmd_flags()
self.build_and_run_cmd(context=context, cmd_flags=cmd_flags)
template_fields: Sequence[str] = DbtDockerBaseOperator.template_fields + DbtRunOperationMixin.template_fields # type: ignore[operator]
Loading

0 comments on commit 380ac52

Please sign in to comment.