Skip to content

Commit

Permalink
fix complexity of select_nodes_ids_by_intersection
Browse files Browse the repository at this point in the history
  • Loading branch information
jbandoro committed Oct 31, 2023
1 parent 9d853e2 commit ceb1dac
Showing 1 changed file with 65 additions and 39 deletions.
104 changes: 65 additions & 39 deletions cosmos/dbt/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pathlib import Path
import copy

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from cosmos.constants import DbtResourceType
from cosmos.exceptions import CosmosValueError
Expand Down Expand Up @@ -84,72 +84,96 @@ def __repr__(self) -> str:
return f"SelectorConfig(paths={self.paths}, tags={self.tags}, config={self.config}, other={self.other})"


def select_nodes_ids_by_intersection(nodes: dict[str, DbtNode], config: SelectorConfig) -> set[str]:
class NodeSelector:
"""
Return a list of node ids which matches the configuration defined in config.
Class to select nodes based on a selector config.
:param nodes: Dictionary mapping dbt nodes (node.unique_id to node)
:param config: User-defined select statements
References:
https://docs.getdbt.com/reference/node-selection/syntax
https://docs.getdbt.com/reference/node-selection/yaml-selectors
"""
if config.is_empty:
return set(nodes.keys())

selected_nodes = set()
visited_nodes = set()
def __init__(self, nodes: dict[str, DbtNode], config: SelectorConfig) -> None:
self.nodes = nodes
self.config = config

def select_nodes_ids_by_intersection(self) -> set[str]:
"""
Return a list of node ids which matches the configuration defined in config.
References:
https://docs.getdbt.com/reference/node-selection/syntax
https://docs.getdbt.com/reference/node-selection/yaml-selectors
"""
if self.config.is_empty:
return set(self.nodes.keys())

self.selected_nodes: set[str] = set()
self.visited_nodes: set[str] = set()

for node_id, node in self.nodes.items():
if self._should_include_node(node_id, node):
self.selected_nodes.add(node_id)

return self.selected_nodes

def should_include_node(node_id: str, node: DbtNode) -> bool:
def _should_include_node(self, node_id: str, node: DbtNode) -> bool:
"Checks if a single node should be included. Only runs once per node with caching."
if node_id in visited_nodes:
return node_id in selected_nodes
if node_id in self.visited_nodes:
return node_id in self.selected_nodes

visited_nodes.add(node_id)
self.visited_nodes.add(node_id)

if node.resource_type == DbtResourceType.TEST:
node.tags = getattr(nodes.get(node.depends_on[0]), "tags", [])
node.tags = getattr(self.nodes.get(node.depends_on[0]), "tags", [])

if config.tags:
if not (set(config.tags) <= set(node.tags)):
return False
if not self._is_tags_subset(node):
return False

node_config = {key: value for key, value in node.config.items() if key in SUPPORTED_CONFIG}
config_tags = config.config.get("tags")
if config_tags and config_tags not in node_config.get("tags", []):

if not self._is_config_subset(node_config):
return False

# Remove 'tags' as they've already been filtered for
config_copy = copy.deepcopy(config.config)
config_copy = copy.deepcopy(self.config.config)
config_copy.pop("tags", None)
node_config.pop("tags", None)

if not (config_copy.items() <= node_config.items()):
return False

if config.paths:
for filter_path in config.paths:
if filter_path in node.file_path.parents or filter_path == node.file_path:
return True
if self.config.paths and not self._is_path_matching(node):
return False

# if it's a test coming from a schema.yml file, check the model's file_path
if node.resource_type == DbtResourceType.TEST and node.file_path.name == "schema.yml":
# try to get the corresponding model from node.depends_on
if len(node.depends_on) == 1:
model_node = nodes.get(node.depends_on[0])
if model_node:
return should_include_node(node.depends_on[0], model_node)
return True

def _is_tags_subset(self, node: DbtNode) -> bool:
"""Checks if the node's tags are a subset of the config's tags."""
if not (set(self.config.tags) <= set(node.tags)):
return False
return True

def _is_config_subset(self, node_config: dict[str, Any]) -> bool:
"""Checks if the node's config is a subset of the config's config."""
config_tags = self.config.config.get("tags")
if config_tags and config_tags not in node_config.get("tags", []):
return False
return True

for node_id, node in nodes.items():
if should_include_node(node_id, node):
selected_nodes.add(node_id)
def _is_path_matching(self, node: DbtNode) -> bool:
"""Checks if the node's path is a subset of the config's paths."""
for filter_path in self.config.paths:
if filter_path in node.file_path.parents or filter_path == node.file_path:
return True

return selected_nodes
# if it's a test coming from a schema.yml file, check the model's file_path
if node.resource_type == DbtResourceType.TEST and node.file_path.name == "schema.yml":
# try to get the corresponding model from node.depends_on
if len(node.depends_on) == 1:
model_node = self.nodes.get(node.depends_on[0])
if model_node:
return self._should_include_node(node.depends_on[0], model_node)
return False


def retrieve_by_label(statement_list: list[str], label: str) -> set[str]:
Expand Down Expand Up @@ -204,7 +228,8 @@ def select_nodes(

for statement in select:
config = SelectorConfig(project_dir, statement)
select_ids = select_nodes_ids_by_intersection(nodes, config)
node_selector = NodeSelector(nodes, config)
select_ids = node_selector.select_nodes_ids_by_intersection()
subset_ids = subset_ids.union(set(select_ids))

if select:
Expand All @@ -215,7 +240,8 @@ def select_nodes(
exclude_ids: set[str] = set()
for statement in exclude:
config = SelectorConfig(project_dir, statement)
exclude_ids = exclude_ids.union(set(select_nodes_ids_by_intersection(nodes, config)))
node_selector = NodeSelector(nodes, config)
exclude_ids = exclude_ids.union(set(node_selector.select_nodes_ids_by_intersection()))
subset_ids = set(nodes_ids) - set(exclude_ids)

return {id_: nodes[id_] for id_ in subset_ids}

0 comments on commit ceb1dac

Please sign in to comment.