diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index d03af5f1f..b1dcb8747 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -83,7 +83,7 @@ def create_test_task_metadata( def create_task_metadata( - node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any], use_name_as_task_id_prefix: bool = True + node: DbtNode, execution_mode: ExecutionMode, args: dict[str, Any], use_task_group: bool = False ) -> TaskMetadata | None: """ Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node. @@ -106,9 +106,9 @@ def create_task_metadata( if hasattr(node.resource_type, "value") and node.resource_type in dbt_resource_to_class: if node.resource_type == DbtResourceType.MODEL: - if use_name_as_task_id_prefix: - task_id = f"{node.name}_run" - else: + task_id = f"{node.name}_run" + + if use_task_group is True: task_id = "run" else: task_id = f"{node.name}_{node.resource_type.value}" @@ -167,14 +167,18 @@ def build_airflow_graph( # The exception are the test nodes, since it would be too slow to run test tasks individually. # If test_behaviour=="after_each", each model task will be bundled with a test task, using TaskGroup for node_id, node in nodes.items(): + use_task_group = ( + node.resource_type == DbtResourceType.MODEL + and test_behavior == TestBehavior.AFTER_EACH + and node.has_test is True + ) + task_meta = create_task_metadata( - node=node, - execution_mode=execution_mode, - args=task_args, - use_name_as_task_id_prefix=test_behavior != TestBehavior.AFTER_EACH, + node=node, execution_mode=execution_mode, args=task_args, use_task_group=use_task_group ) + if task_meta and node.resource_type != DbtResourceType.TEST: - if node.resource_type == DbtResourceType.MODEL and test_behavior == TestBehavior.AFTER_EACH: + if use_task_group is True: with TaskGroup(dag=dag, group_id=node.name, parent_group=task_group) as model_task_group: task = create_airflow_task(task_meta, dag, task_group=model_task_group) test_meta = create_test_task_metadata( diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index 07a55ad79..f83f490e2 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -50,6 +50,7 @@ class DbtNode: file_path: Path tags: list[str] = field(default_factory=lambda: []) config: dict[str, Any] = field(default_factory=lambda: {}) + has_test: bool = False class DbtGraph: @@ -262,6 +263,8 @@ def load_via_dbt_ls(self) -> None: self.nodes = nodes self.filtered_nodes = nodes + self.update_node_dependency() + logger.info("Total nodes: %i", len(self.nodes)) logger.info("Total filtered nodes: %i", len(self.nodes)) @@ -306,6 +309,8 @@ def load_via_custom_parser(self) -> None: project_dir=self.project.dir, nodes=nodes, select=self.select, exclude=self.exclude ) + self.update_node_dependency() + logger.info("Total nodes: %i", len(self.nodes)) logger.info("Total filtered nodes: %i", len(self.nodes)) @@ -335,11 +340,28 @@ def load_from_dbt_manifest(self) -> None: tags=node_dict["tags"], config=node_dict["config"], ) + nodes[node.unique_id] = node self.nodes = nodes self.filtered_nodes = select_nodes( project_dir=self.project.dir, nodes=nodes, select=self.select, exclude=self.exclude ) + + self.update_node_dependency() + logger.info("Total nodes: %i", len(self.nodes)) logger.info("Total filtered nodes: %i", len(self.nodes)) + + def update_node_dependency(self) -> None: + """ + This will update the property `has_text` if node has `dbt` test + + Updates in-place: + * self.filtered_nodes + """ + for _, node in self.filtered_nodes.items(): + if node.resource_type == DbtResourceType.TEST: + for node_id in node.depends_on: + if node_id in self.filtered_nodes: + self.filtered_nodes[node_id].has_test = True diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py index 7b539bb5b..bd3777209 100644 --- a/tests/airflow/test_graph.py +++ b/tests/airflow/test_graph.py @@ -36,6 +36,7 @@ file_path=SAMPLE_PROJ_PATH / "gen2/models/parent.sql", tags=["has_child"], config={"materialized": "view"}, + has_test=True, ) test_parent_node = DbtNode( name="test_parent", unique_id="test_parent", resource_type=DbtResourceType.TEST, depends_on=["parent"], file_path="" @@ -49,15 +50,8 @@ tags=["nightly"], config={"materialized": "table"}, ) -test_child_node = DbtNode( - name="test_child", - unique_id="test_child", - resource_type=DbtResourceType.TEST, - depends_on=["child"], - file_path="", -) -sample_nodes_list = [parent_seed, parent_node, test_parent_node, child_node, test_child_node] +sample_nodes_list = [parent_seed, parent_node, test_parent_node, child_node] sample_nodes = {node.unique_id: node for node in sample_nodes_list} @@ -93,21 +87,18 @@ def test_build_airflow_graph_with_after_each(): "seed_parent_seed", "parent.run", "parent.test", - "child.run", - "child.test", + "child_run", ] + assert topological_sort == expected_sort task_groups = dag.task_group_dict - assert len(task_groups) == 2 + assert len(task_groups) == 1 assert task_groups["parent"].upstream_task_ids == {"seed_parent_seed"} assert list(task_groups["parent"].children.keys()) == ["parent.run", "parent.test"] - assert task_groups["child"].upstream_task_ids == {"parent.test"} - assert list(task_groups["child"].children.keys()) == ["child.run", "child.test"] - assert len(dag.leaves) == 1 - assert dag.leaves[0].task_id == "child.test" + assert dag.leaves[0].task_id == "child_run" @pytest.mark.skipif( @@ -231,7 +222,7 @@ def test_create_task_metadata_model(caplog): assert metadata.arguments == {"models": "my_model"} -def test_create_task_metadata_model_use_name_as_task_id_prefix(caplog): +def test_create_task_metadata_model_use_task_group(caplog): child_node = DbtNode( name="my_model", unique_id="my_folder.my_model", @@ -241,14 +232,12 @@ def test_create_task_metadata_model_use_name_as_task_id_prefix(caplog): tags=[], config={}, ) - metadata = create_task_metadata( - child_node, execution_mode=ExecutionMode.LOCAL, args={}, use_name_as_task_id_prefix=False - ) + metadata = create_task_metadata(child_node, execution_mode=ExecutionMode.LOCAL, args={}, use_task_group=True) assert metadata.id == "run" -@pytest.mark.parametrize("use_name_as_task_id_prefix", (None, True, False)) -def test_create_task_metadata_seed(caplog, use_name_as_task_id_prefix): +@pytest.mark.parametrize("use_task_group", (None, True, False)) +def test_create_task_metadata_seed(caplog, use_task_group): sample_node = DbtNode( name="my_seed", unique_id="my_folder.my_seed", @@ -258,14 +247,14 @@ def test_create_task_metadata_seed(caplog, use_name_as_task_id_prefix): tags=[], config={}, ) - if use_name_as_task_id_prefix is None: + if use_task_group is None: metadata = create_task_metadata(sample_node, execution_mode=ExecutionMode.DOCKER, args={}) else: metadata = create_task_metadata( sample_node, execution_mode=ExecutionMode.DOCKER, args={}, - use_name_as_task_id_prefix=use_name_as_task_id_prefix, + use_task_group=use_task_group, ) assert metadata.id == "my_seed_seed" assert metadata.operator_class == "cosmos.operators.docker.DbtSeedDockerOperator" diff --git a/tests/dbt/test_graph.py b/tests/dbt/test_graph.py index 3a2a3eaa9..4dbdeb411 100644 --- a/tests/dbt/test_graph.py +++ b/tests/dbt/test_graph.py @@ -312,3 +312,32 @@ def test_load_via_load_via_custom_parser(pipeline_name): assert dbt_graph.nodes == dbt_graph.filtered_nodes # the custom parser does not add dbt test nodes assert len(dbt_graph.nodes) == 8 + + +@patch("cosmos.dbt.graph.DbtGraph.update_node_dependency", return_value=None) +def test_update_node_dependency_called(mock_update_node_dependency): + dbt_project = DbtProject(name="jaffle_shop", root_dir=DBT_PROJECTS_ROOT_DIR, manifest_path=SAMPLE_MANIFEST) + dbt_graph = DbtGraph(project=dbt_project) + dbt_graph.load() + + assert mock_update_node_dependency.called + + +def test_update_node_dependency_target_exist(): + dbt_project = DbtProject(name="jaffle_shop", root_dir=DBT_PROJECTS_ROOT_DIR, manifest_path=SAMPLE_MANIFEST) + dbt_graph = DbtGraph(project=dbt_project) + dbt_graph.load() + + for _, nodes in dbt_graph.nodes.items(): + if nodes.resource_type == DbtResourceType.TEST: + for node_id in nodes.depends_on: + assert dbt_graph.nodes[node_id].has_test is True + + +def test_update_node_dependency_test_not_exist(): + dbt_project = DbtProject(name="jaffle_shop", root_dir=DBT_PROJECTS_ROOT_DIR, manifest_path=SAMPLE_MANIFEST) + dbt_graph = DbtGraph(project=dbt_project, exclude=["config.materialized:test"]) + dbt_graph.load_from_dbt_manifest() + + for _, nodes in dbt_graph.filtered_nodes.items(): + assert nodes.has_test is False