Skip to content

Commit

Permalink
updated optimizers to handle distinct select
Browse files Browse the repository at this point in the history
  • Loading branch information
WilliamDee committed Nov 3, 2023
1 parent 923ffb3 commit a57e608
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 1 deletion.
7 changes: 6 additions & 1 deletion metricflow/sql/optimizer/column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,19 @@ def _prune_columns_from_grandparents(
order_bys=node.order_bys,
where=node.where,
limit=node.limit,
distinct=node.distinct,
)

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D
# Remove columns that are not needed from this SELECT statement because the parent SELECT statement doesn't
# need them. However, keep columns that are in group bys because that changes the meaning of the query.
# Similarly, if this node is a distinct select node, keep all columns as it may return a different result set.
pruned_select_columns = tuple(
select_column
for select_column in node.select_columns
if select_column.column_alias in self._required_column_aliases or select_column in node.group_bys
if select_column.column_alias in self._required_column_aliases
or select_column in node.group_bys
or node.distinct
)

if len(pruned_select_columns) == 0:
Expand Down Expand Up @@ -183,6 +187,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
order_bys=node.order_bys,
where=node.where,
limit=node.limit,
distinct=node.distinct,
)

def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D
Expand Down
8 changes: 8 additions & 0 deletions metricflow/sql/optimizer/rewriting_sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def _reduce_parents(
order_bys=node.order_bys,
where=node.where,
limit=node.limit,
distinct=node.distinct,
)

@staticmethod
Expand Down Expand Up @@ -220,6 +221,10 @@ def _current_node_can_be_reduced(self, node: SqlSelectStatementNode) -> bool: #
if SqlRewritingSubQueryReducerVisitor._statement_contains_difficult_expressions(node):
return False

# Don't reduce distinct selects
if parent_select_node.distinct:
return False

# Skip this case for simplicity of reasoning.
if len(node.order_bys) > 0 and len(parent_select_node.order_bys) > 0:
return False
Expand Down Expand Up @@ -523,6 +528,7 @@ def _rewrite_node_with_join(node: SqlSelectStatementNode) -> SqlSelectStatementN
order_bys=tuple(clauses_to_rewrite.order_bys),
where=clauses_to_rewrite.combine_wheres(additional_where_clauses),
limit=node.limit,
distinct=node.distinct,
)

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D
Expand Down Expand Up @@ -641,6 +647,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
parent_node_where=parent_select_node.where,
),
limit=new_limit,
distinct=parent_select_node.distinct,
)

def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D
Expand Down Expand Up @@ -698,6 +705,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
order_bys=node.order_bys,
where=node.where,
limit=node.limit,
distinct=node.distinct,
)

def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D
Expand Down
6 changes: 6 additions & 0 deletions metricflow/sql/optimizer/sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def _reduce_parents(
order_bys=node.order_bys,
where=node.where,
limit=node.limit,
distinct=node.distinct,
)

def _reduce_is_possible(self, node: SqlSelectStatementNode) -> bool: # noqa: D
Expand Down Expand Up @@ -71,6 +72,10 @@ def _reduce_is_possible(self, node: SqlSelectStatementNode) -> bool: # noqa: D
# More conditions where we don't want to collapse. It's not impossible with these cases, but not reducing in
# these cases for simplicity.

# Don't reduce distinct selects
if parent_select_node.distinct:
return False

# Reducing a where is tricky as it requires the expressions to be re-written.
if node.where:
return False
Expand Down Expand Up @@ -179,6 +184,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
order_bys=tuple(new_order_by),
where=parent_select_node.where,
limit=new_limit,
distinct=parent_select_node.distinct,
)

def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D
Expand Down
2 changes: 2 additions & 0 deletions metricflow/sql/optimizer/table_alias_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
),
where=node.where.rewrite(should_render_table_alias=False) if node.where else None,
limit=node.limit,
distinct=node.distinct,
)

return SqlSelectStatementNode(
Expand All @@ -64,6 +65,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
order_bys=node.order_bys,
where=node.where,
limit=node.limit,
distinct=node.distinct,
)

def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D
Expand Down

0 comments on commit a57e608

Please sign in to comment.