From ceb1dac807e6d2b5f6bd300ee8e22bd98249ac5e Mon Sep 17 00:00:00 2001 From: Justin Bandoro Date: Tue, 31 Oct 2023 13:04:39 -0700 Subject: [PATCH] fix complexity of select_nodes_ids_by_intersection --- cosmos/dbt/selector.py | 104 +++++++++++++++++++++++++---------------- 1 file changed, 65 insertions(+), 39 deletions(-) diff --git a/cosmos/dbt/selector.py b/cosmos/dbt/selector.py index f195f225b..c7316dc75 100644 --- a/cosmos/dbt/selector.py +++ b/cosmos/dbt/selector.py @@ -2,7 +2,7 @@ from pathlib import Path import copy -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from cosmos.constants import DbtResourceType from cosmos.exceptions import CosmosValueError @@ -84,72 +84,96 @@ def __repr__(self) -> str: return f"SelectorConfig(paths={self.paths}, tags={self.tags}, config={self.config}, other={self.other})" -def select_nodes_ids_by_intersection(nodes: dict[str, DbtNode], config: SelectorConfig) -> set[str]: +class NodeSelector: """ - Return a list of node ids which matches the configuration defined in config. + Class to select nodes based on a selector config. :param nodes: Dictionary mapping dbt nodes (node.unique_id to node) :param config: User-defined select statements - - References: - https://docs.getdbt.com/reference/node-selection/syntax - https://docs.getdbt.com/reference/node-selection/yaml-selectors """ - if config.is_empty: - return set(nodes.keys()) - selected_nodes = set() - visited_nodes = set() + def __init__(self, nodes: dict[str, DbtNode], config: SelectorConfig) -> None: + self.nodes = nodes + self.config = config + + def select_nodes_ids_by_intersection(self) -> set[str]: + """ + Return a list of node ids which matches the configuration defined in config. + + References: + https://docs.getdbt.com/reference/node-selection/syntax + https://docs.getdbt.com/reference/node-selection/yaml-selectors + """ + if self.config.is_empty: + return set(self.nodes.keys()) + + self.selected_nodes: set[str] = set() + self.visited_nodes: set[str] = set() + + for node_id, node in self.nodes.items(): + if self._should_include_node(node_id, node): + self.selected_nodes.add(node_id) + + return self.selected_nodes - def should_include_node(node_id: str, node: DbtNode) -> bool: + def _should_include_node(self, node_id: str, node: DbtNode) -> bool: "Checks if a single node should be included. Only runs once per node with caching." - if node_id in visited_nodes: - return node_id in selected_nodes + if node_id in self.visited_nodes: + return node_id in self.selected_nodes - visited_nodes.add(node_id) + self.visited_nodes.add(node_id) if node.resource_type == DbtResourceType.TEST: - node.tags = getattr(nodes.get(node.depends_on[0]), "tags", []) + node.tags = getattr(self.nodes.get(node.depends_on[0]), "tags", []) - if config.tags: - if not (set(config.tags) <= set(node.tags)): - return False + if not self._is_tags_subset(node): + return False node_config = {key: value for key, value in node.config.items() if key in SUPPORTED_CONFIG} - config_tags = config.config.get("tags") - if config_tags and config_tags not in node_config.get("tags", []): + + if not self._is_config_subset(node_config): return False # Remove 'tags' as they've already been filtered for - config_copy = copy.deepcopy(config.config) + config_copy = copy.deepcopy(self.config.config) config_copy.pop("tags", None) node_config.pop("tags", None) if not (config_copy.items() <= node_config.items()): return False - if config.paths: - for filter_path in config.paths: - if filter_path in node.file_path.parents or filter_path == node.file_path: - return True + if self.config.paths and not self._is_path_matching(node): + return False - # if it's a test coming from a schema.yml file, check the model's file_path - if node.resource_type == DbtResourceType.TEST and node.file_path.name == "schema.yml": - # try to get the corresponding model from node.depends_on - if len(node.depends_on) == 1: - model_node = nodes.get(node.depends_on[0]) - if model_node: - return should_include_node(node.depends_on[0], model_node) + return True + def _is_tags_subset(self, node: DbtNode) -> bool: + """Checks if the node's tags are a subset of the config's tags.""" + if not (set(self.config.tags) <= set(node.tags)): return False + return True + def _is_config_subset(self, node_config: dict[str, Any]) -> bool: + """Checks if the node's config is a subset of the config's config.""" + config_tags = self.config.config.get("tags") + if config_tags and config_tags not in node_config.get("tags", []): + return False return True - for node_id, node in nodes.items(): - if should_include_node(node_id, node): - selected_nodes.add(node_id) + def _is_path_matching(self, node: DbtNode) -> bool: + """Checks if the node's path is a subset of the config's paths.""" + for filter_path in self.config.paths: + if filter_path in node.file_path.parents or filter_path == node.file_path: + return True - return selected_nodes + # if it's a test coming from a schema.yml file, check the model's file_path + if node.resource_type == DbtResourceType.TEST and node.file_path.name == "schema.yml": + # try to get the corresponding model from node.depends_on + if len(node.depends_on) == 1: + model_node = self.nodes.get(node.depends_on[0]) + if model_node: + return self._should_include_node(node.depends_on[0], model_node) + return False def retrieve_by_label(statement_list: list[str], label: str) -> set[str]: @@ -204,7 +228,8 @@ def select_nodes( for statement in select: config = SelectorConfig(project_dir, statement) - select_ids = select_nodes_ids_by_intersection(nodes, config) + node_selector = NodeSelector(nodes, config) + select_ids = node_selector.select_nodes_ids_by_intersection() subset_ids = subset_ids.union(set(select_ids)) if select: @@ -215,7 +240,8 @@ def select_nodes( exclude_ids: set[str] = set() for statement in exclude: config = SelectorConfig(project_dir, statement) - exclude_ids = exclude_ids.union(set(select_nodes_ids_by_intersection(nodes, config))) + node_selector = NodeSelector(nodes, config) + exclude_ids = exclude_ids.union(set(node_selector.select_nodes_ids_by_intersection())) subset_ids = set(nodes_ids) - set(exclude_ids) return {id_: nodes[id_] for id_ in subset_ids}