Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix running dbt tests that depend on multiple models (support --indirect-selection buildable) #613

Merged
merged 7 commits into from
Oct 23, 2023
22 changes: 20 additions & 2 deletions cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@
from airflow.models.dag import DAG
from airflow.utils.task_group import TaskGroup

from cosmos.constants import DbtResourceType, TestBehavior, ExecutionMode, TESTABLE_DBT_RESOURCES, DEFAULT_DBT_RESOURCES
from cosmos.constants import (
DbtResourceType,
TestBehavior,
TestIndirectSelection,
ExecutionMode,
TESTABLE_DBT_RESOURCES,
DEFAULT_DBT_RESOURCES,
)
from cosmos.core.airflow import get_airflow_task as create_airflow_task
from cosmos.core.graph.entities import Task as TaskMetadata
from cosmos.dbt.graph import DbtNode
Expand Down Expand Up @@ -54,6 +61,7 @@ def calculate_leaves(tasks_ids: list[str], nodes: dict[str, DbtNode]) -> list[st
def create_test_task_metadata(
test_task_name: str,
execution_mode: ExecutionMode,
test_indirect_selection: TestIndirectSelection,
task_args: dict[str, Any],
on_warning_callback: Callable[..., Any] | None = None,
node: DbtNode | None = None,
Expand All @@ -71,6 +79,8 @@ def create_test_task_metadata(
"""
task_args = dict(task_args)
task_args["on_warning_callback"] = on_warning_callback
if test_indirect_selection != TestIndirectSelection.EAGER:
task_args["indirect_selection"] = test_indirect_selection.value
if node is not None:
if node.resource_type == DbtResourceType.MODEL:
task_args["models"] = node.name
Expand Down Expand Up @@ -144,6 +154,7 @@ def generate_task_or_group(
execution_mode: ExecutionMode,
task_args: dict[str, Any],
test_behavior: TestBehavior,
test_indirect_selection: TestIndirectSelection,
on_warning_callback: Callable[..., Any] | None,
**kwargs: Any,
) -> BaseOperator | TaskGroup | None:
Expand All @@ -169,6 +180,7 @@ def generate_task_or_group(
test_meta = create_test_task_metadata(
"test",
execution_mode,
test_indirect_selection,
task_args=task_args,
node=node,
on_warning_callback=on_warning_callback,
Expand All @@ -187,6 +199,7 @@ def build_airflow_graph(
execution_mode: ExecutionMode, # Cosmos-specific - decide which class to use
task_args: dict[str, Any], # Cosmos/DBT - used to instantiate tasks
test_behavior: TestBehavior, # Cosmos-specific: how to inject tests to Airflow DAG
test_indirect_selection: TestIndirectSelection, # Cosmos/DBT - used to set test indirect selection mode
dbt_project_name: str, # DBT / Cosmos - used to name test task if mode is after_all,
task_group: TaskGroup | None = None,
on_warning_callback: Callable[..., Any] | None = None, # argument specific to the DBT test command
Expand Down Expand Up @@ -235,6 +248,7 @@ def build_airflow_graph(
execution_mode=execution_mode,
task_args=task_args,
test_behavior=test_behavior,
test_indirect_selection=test_indirect_selection,
on_warning_callback=on_warning_callback,
node=node,
)
Expand All @@ -246,7 +260,11 @@ def build_airflow_graph(
# The end of a DAG is defined by the DAG leaf tasks (tasks which do not have downstream tasks)
if test_behavior == TestBehavior.AFTER_ALL:
test_meta = create_test_task_metadata(
f"{dbt_project_name}_test", execution_mode, task_args=task_args, on_warning_callback=on_warning_callback
f"{dbt_project_name}_test",
execution_mode,
test_indirect_selection,
task_args=task_args,
on_warning_callback=on_warning_callback,
)
test_task = create_airflow_task(test_meta, dag, task_group=task_group)
leaves_ids = calculate_leaves(tasks_ids=list(tasks_map.keys()), nodes=nodes)
Expand Down
4 changes: 3 additions & 1 deletion cosmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pathlib import Path
from typing import Any, Iterator, Callable

from cosmos.constants import DbtResourceType, TestBehavior, ExecutionMode, LoadMode
from cosmos.constants import DbtResourceType, TestBehavior, ExecutionMode, LoadMode, TestIndirectSelection
from cosmos.dbt.executable import get_system_dbt
from cosmos.exceptions import CosmosValueError
from cosmos.log import get_logger
Expand Down Expand Up @@ -205,9 +205,11 @@ class ExecutionConfig:
Contains configuration about how to execute dbt.

:param execution_mode: The execution mode for dbt. Defaults to local
:param test_indirect_selection: The mode to configure the test behavior when performing indirect selection.
:param dbt_executable_path: The path to the dbt executable. Defaults to dbt if
available on the path.
"""

execution_mode: ExecutionMode = ExecutionMode.LOCAL
test_indirect_selection: TestIndirectSelection = TestIndirectSelection.EAGER
dbt_executable_path: str | Path = get_system_dbt()
11 changes: 11 additions & 0 deletions cosmos/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,17 @@ class ExecutionMode(Enum):
VIRTUALENV = "virtualenv"


class TestIndirectSelection(Enum):
    """
    Modes to configure the test behavior when performing indirect selection.

    Each member's value is the lowercase name of the mode as a string.

    NOTE(review): the four values appear to correspond to the choices accepted
    by dbt's ``--indirect-selection`` option -- confirm against the dbt
    documentation before adding or renaming members.
    """

    EAGER = "eager"
    CAUTIOUS = "cautious"
    BUILDABLE = "buildable"
    EMPTY = "empty"


class DbtResourceType(aenum.Enum): # type: ignore
"""
Type of dbt node.
Expand Down
2 changes: 2 additions & 0 deletions cosmos/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def __init__(
exclude = render_config.exclude
dbt_deps = render_config.dbt_deps
execution_mode = execution_config.execution_mode
test_indirect_selection = execution_config.test_indirect_selection
load_mode = render_config.load_method
manifest_path = project_config.parsed_manifest_path
dbt_executable_path = execution_config.dbt_executable_path
Expand Down Expand Up @@ -167,6 +168,7 @@ def __init__(
execution_mode=execution_mode,
task_args=task_args,
test_behavior=test_behavior,
test_indirect_selection=test_indirect_selection,
dbt_project_name=dbt_project.name,
on_warning_callback=on_warning_callback,
node_converters=node_converters,
Expand Down
5 changes: 5 additions & 0 deletions cosmos/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
vars: dict[str, str] | None = None,
models: str | None = None,
emit_datasets: bool = True,
indirect_selection: str | None = None,
cache_selected_only: bool = False,
no_version_check: bool = False,
fail_fast: bool = False,
Expand All @@ -115,6 +116,7 @@ def __init__(
self.vars = vars
self.models = models
self.emit_datasets = emit_datasets
self.indirect_selection = indirect_selection
self.cache_selected_only = cache_selected_only
self.no_version_check = no_version_check
self.fail_fast = fail_fast
Expand Down Expand Up @@ -213,6 +215,9 @@ def build_cmd(
if self.base_cmd:
dbt_cmd.extend(self.base_cmd)

if self.indirect_selection:
dbt_cmd += ["--indirect-selection", self.indirect_selection]

dbt_cmd.extend(self.add_global_flags())

# add command specific flags
Expand Down
41 changes: 33 additions & 8 deletions tests/airflow/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
generate_task_or_group,
)
from cosmos.config import ProfileConfig
from cosmos.constants import DbtResourceType, ExecutionMode, TestBehavior
from cosmos.constants import DbtResourceType, ExecutionMode, TestBehavior, TestIndirectSelection
from cosmos.dbt.graph import DbtNode
from cosmos.profiles import PostgresUserPasswordProfileMapping

Expand Down Expand Up @@ -80,6 +80,7 @@ def test_build_airflow_graph_with_after_each():
nodes=sample_nodes,
dag=dag,
execution_mode=ExecutionMode.LOCAL,
test_indirect_selection=TestIndirectSelection.EAGER,
task_args=task_args,
test_behavior=TestBehavior.AFTER_EACH,
dbt_project_name="astro_shop",
Expand Down Expand Up @@ -129,6 +130,7 @@ def test_create_task_group_for_after_each_supported_nodes(node_type, task_suffix
task_group=None,
node=node,
execution_mode=ExecutionMode.LOCAL,
test_indirect_selection=TestIndirectSelection.EAGER,
task_args={
"project_dir": SAMPLE_PROJ_PATH,
"profile_config": ProfileConfig(
Expand Down Expand Up @@ -170,6 +172,7 @@ def test_build_airflow_graph_with_after_all():
nodes=sample_nodes,
dag=dag,
execution_mode=ExecutionMode.LOCAL,
test_indirect_selection=TestIndirectSelection.EAGER,
task_args=task_args,
test_behavior=TestBehavior.AFTER_ALL,
dbt_project_name="astro_shop",
Expand Down Expand Up @@ -332,15 +335,30 @@ def test_create_task_metadata_snapshot(caplog):


@pytest.mark.parametrize(
"node_type,node_unique_id,selector_key,selector_value",
"node_type,node_unique_id,test_indirect_selection,additional_arguments",
[
(DbtResourceType.MODEL, "node_name", "models", "node_name"),
(DbtResourceType.SEED, "node_name", "select", "node_name"),
(DbtResourceType.SOURCE, "source.node_name", "select", "source:node_name"),
(DbtResourceType.SNAPSHOT, "node_name", "select", "node_name"),
(DbtResourceType.MODEL, "node_name", TestIndirectSelection.EAGER, {"models": "node_name"}),
(
DbtResourceType.SEED,
"node_name",
TestIndirectSelection.CAUTIOUS,
{"select": "node_name", "indirect_selection": "cautious"},
),
(
DbtResourceType.SOURCE,
"source.node_name",
TestIndirectSelection.BUILDABLE,
{"select": "source:node_name", "indirect_selection": "buildable"},
),
(
DbtResourceType.SNAPSHOT,
"node_name",
TestIndirectSelection.EMPTY,
{"select": "node_name", "indirect_selection": "empty"},
),
],
)
def test_create_test_task_metadata(node_type, node_unique_id, selector_key, selector_value):
def test_create_test_task_metadata(node_type, node_unique_id, test_indirect_selection, additional_arguments):
sample_node = DbtNode(
name="node_name",
unique_id=node_unique_id,
Expand All @@ -353,10 +371,17 @@ def test_create_test_task_metadata(node_type, node_unique_id, selector_key, sele
metadata = create_test_task_metadata(
test_task_name="test_no_nulls",
execution_mode=ExecutionMode.LOCAL,
test_indirect_selection=test_indirect_selection,
task_args={"task_arg": "value"},
on_warning_callback=True,
node=sample_node,
)
assert metadata.id == "test_no_nulls"
assert metadata.operator_class == "cosmos.operators.local.DbtTestLocalOperator"
assert metadata.arguments == {"task_arg": "value", "on_warning_callback": True, selector_key: selector_value}
assert metadata.arguments == {
**{
"task_arg": "value",
"on_warning_callback": True,
},
**additional_arguments,
}
23 changes: 23 additions & 0 deletions tests/operators/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,29 @@ def test_dbt_base_operator_add_user_supplied_global_flags() -> None:
assert cmd[-1] == "run"


@pytest.mark.parametrize(
    "indirect_selection_type",
    [None, "cautious", "buildable", "empty"],
)
def test_dbt_base_operator_use_indirect_selection(indirect_selection_type) -> None:
    """
    Check how ``build_cmd`` renders the ``indirect_selection`` operator argument.

    When a mode is given, the command must end with ``--indirect-selection <mode>``;
    when it is ``None``, the command must be exactly ``["dbt", "run"]``.
    """
    operator = DbtLocalBaseOperator(
        profile_config=profile_config,
        task_id="my-task",
        project_dir="my/dir",
        base_cmd=["run"],
        indirect_selection=indirect_selection_type,
    )
    context = Context(execution_date=datetime(2023, 2, 15, 12, 30))

    cmd, _ = operator.build_cmd(context)

    # No mode requested: the flag must be absent and the command untouched.
    if not indirect_selection_type:
        assert cmd == ["dbt", "run"]
        return

    # Mode requested: the flag and its value are the last two command items.
    assert cmd[-2] == "--indirect-selection"
    assert cmd[-1] == indirect_selection_type


@pytest.mark.parametrize(
["skip_exception", "exception_code_returned", "expected_exception"],
[
Expand Down
Loading