diff --git a/cosmos/dbt/selector.py b/cosmos/dbt/selector.py index 1e7b42667..ba8ab6e3f 100644 --- a/cosmos/dbt/selector.py +++ b/cosmos/dbt/selector.py @@ -20,7 +20,8 @@ TAG_SELECTOR = "tag:" CONFIG_SELECTOR = "config." PLUS_SELECTOR = "+" -GRAPH_SELECTOR_REGEX = r"^([0-9]*\+)?([^\+]+)(\+[0-9]*)?$|" +AT_SELECTOR = "@" +GRAPH_SELECTOR_REGEX = r"^(@|[0-9]*\+)?([^\+]+)(\+[0-9]*)?$|" logger = get_logger(__name__) @@ -35,6 +36,7 @@ class GraphSelector: +model_d+ 2+model_e model_f+3 + @model_g +/path/to/model_g+ path:/path/to/model_h+ +tag:nightly @@ -46,6 +48,7 @@ class GraphSelector: node_name: str precursors: str | None descendants: str | None + at_operator: bool = False @property def precursors_depth(self) -> int: @@ -56,6 +59,8 @@ def precursors_depth(self) -> int: 0: if it shouldn't return any precursors >0: upperbound number of parent generations """ + if self.at_operator: + return -1 if not self.precursors: return 0 if self.precursors == "+": @@ -90,7 +95,13 @@ def parse(text: str) -> GraphSelector | None: precursors, node_name, descendants = regex_match.groups() if "/" in node_name and not node_name.startswith(PATH_SELECTOR): node_name = f"{PATH_SELECTOR}{node_name}" - return GraphSelector(node_name, precursors, descendants) + + at_operator = precursors == AT_SELECTOR + if at_operator: + precursors = None + descendants = "+" # @ implies all descendants + + return GraphSelector(node_name, precursors, descendants, at_operator) return None def select_node_precursors(self, nodes: dict[str, DbtNode], root_id: str, selected_nodes: set[str]) -> None: @@ -101,7 +112,7 @@ def select_node_precursors(self, nodes: dict[str, DbtNode], root_id: str, select :param root_id: Unique identifier of self.node_name :param selected_nodes: Set where precursor nodes will be added to. """ - if self.precursors: + if self.precursors or self.at_operator: depth = self.precursors_depth previous_generation = {root_id} processed_nodes = set() @@ -203,16 +214,39 @@ def filter_nodes(self, nodes: dict[str, DbtNode]) -> set[str]: root_id = node_by_name[self.node_name] root_nodes.add(root_id) else: - logger.warn(f"Selector {self.node_name} not found.") + logger.warning(f"Selector {self.node_name} not found.") return selected_nodes selected_nodes.update(root_nodes) - for root_id in root_nodes: - self.select_node_precursors(nodes, root_id, selected_nodes) - self.select_node_descendants(nodes, root_id, selected_nodes) + self._select_nodes(nodes, root_nodes, selected_nodes) + return selected_nodes + def _select_nodes(self, nodes: dict[str, DbtNode], root_nodes: set[str], selected_nodes: set[str]) -> None: + """ + Handle selection of nodes based on the graph selector configuration. + + :param nodes: dbt project nodes + :param root_nodes: Set of root node ids + :param selected_nodes: Set where selected nodes will be added to. + """ + if self.at_operator: + descendants: set[str] = set() + # First get all descendants + for root_id in root_nodes: + self.select_node_descendants(nodes, root_id, descendants) + selected_nodes.update(descendants) + + # Get ancestors for root nodes and all descendants + for node_id in root_nodes | descendants: + self.select_node_precursors(nodes, node_id, selected_nodes) + else: + # Normal selection + for root_id in root_nodes: + self.select_node_precursors(nodes, root_id, selected_nodes) + self.select_node_descendants(nodes, root_id, selected_nodes) + class SelectorConfig: """ diff --git a/docs/configuration/selecting-excluding.rst b/docs/configuration/selecting-excluding.rst index 01ee536b0..9ee778e51 100644 --- a/docs/configuration/selecting-excluding.rst +++ b/docs/configuration/selecting-excluding.rst @@ -17,6 +17,7 @@ The ``select`` and ``exclude`` parameters are lists, with values like the follow - ``config.materialized:table``: include/exclude models with the config ``materialized: table`` - ``path:analytics/tables``: include/exclude models in the ``analytics/tables`` directory - ``+node_name+1`` (graph operators): include/exclude the node with name ``node_name``, all its parents, and its first generation of children (`dbt graph selector docs `_) +- ``@node_name`` (@ operator): include/exclude the node with name ``node_name``, all its descendants, and all ancestors of those descendants. This is useful in CI environments where you want to build a model and all its descendants, but you need the ancestors of those descendants to exist first. - ``tag:my_tag,+node_name`` (intersection): include/exclude ``node_name`` and its parents if they have the tag ``my_tag`` (`dbt set operator docs `_) - ``['tag:first_tag', 'tag:second_tag']`` (union): include/exclude nodes that have either ``tag:first_tag`` or ``tag:second_tag`` @@ -91,6 +92,17 @@ Examples: ) ) +.. code-block:: python + + from cosmos import DbtDag, RenderConfig + + jaffle_shop = DbtDag( + render_config=RenderConfig( + select=["@my_model"], # selects my_model, all its descendants, + # and all ancestors needed to build those descendants + ) + ) + Using ``selector`` -------------------------------- .. note:: diff --git a/tests/dbt/test_selector.py b/tests/dbt/test_selector.py index 56f65dad0..4574bd255 100644 --- a/tests/dbt/test_selector.py +++ b/tests/dbt/test_selector.py @@ -508,3 +508,83 @@ def test_should_include_node_without_depends_on(selector_config): def test_select_using_graph_operators(select_statement, expected): selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=select_statement) assert sorted(selected.keys()) == expected + + +def test_select_nodes_by_at_operator(): + """Test basic @ operator selecting node, descendants and ancestors of all""" + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["@parent"]) + expected = [ + "model.dbt-proj.another_grandparent_node", + "model.dbt-proj.child", + "model.dbt-proj.grandparent", + "model.dbt-proj.parent", + "model.dbt-proj.sibling1", + "model.dbt-proj.sibling2", + ] + assert sorted(selected.keys()) == expected + + +def test_select_nodes_by_at_operator_leaf_node(): + """Test @ operator on a leaf node (no descendants)""" + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["@child"]) + expected = [ + "model.dbt-proj.another_grandparent_node", + "model.dbt-proj.child", + "model.dbt-proj.grandparent", + "model.dbt-proj.parent", + ] + assert sorted(selected.keys()) == expected + + +def test_select_nodes_by_at_operator_root_node(): + """Test @ operator on a root node (no ancestors)""" + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["@grandparent"]) + expected = [ + "model.dbt-proj.another_grandparent_node", + "model.dbt-proj.child", + "model.dbt-proj.grandparent", + "model.dbt-proj.parent", + "model.dbt-proj.sibling1", + "model.dbt-proj.sibling2", + ] + assert sorted(selected.keys()) == expected + + +def test_select_nodes_by_at_operator_union(): + """Test @ operator union with another selector""" + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["@child", "tag:has_child"]) + expected = [ + "model.dbt-proj.another_grandparent_node", + "model.dbt-proj.child", + "model.dbt-proj.grandparent", + "model.dbt-proj.parent", + ] + assert sorted(selected.keys()) == expected + + +def test_select_nodes_by_at_operator_with_path(): + """Test @ operator with a path""" + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["@gen2/models"]) + expected = [ + "model.dbt-proj.another_grandparent_node", + "model.dbt-proj.child", + "model.dbt-proj.grandparent", + "model.dbt-proj.parent", + "model.dbt-proj.sibling1", + "model.dbt-proj.sibling2", + ] + assert sorted(selected.keys()) == expected + + +def test_select_nodes_by_at_operator_nonexistent_node(): + """Test @ operator with a node that doesn't exist""" + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["@nonexistent"]) + expected = [] + assert sorted(selected.keys()) == expected + + +def test_exclude_with_at_operator(): + """Test excluding nodes selected by @ operator""" + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, exclude=["@parent"]) + expected = ["model.dbt-proj.orphaned"] + assert sorted(selected.keys()) == expected