Skip to content

Commit

Permalink
Add ability to use distinct select in nodes (#834)
Browse files Browse the repository at this point in the history
* added ability to select distinct

* updated optimizers to handle distinct select

* added tests

* changelog

* updated snapshots
  • Loading branch information
WilliamDee authored Nov 3, 2023
1 parent b0b4ccf commit 151975b
Show file tree
Hide file tree
Showing 37 changed files with 407 additions and 3 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20231102-161245.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Add the ability to use distinct select in sql nodes
time: 2023-11-02T16:12:45.123252-04:00
custom:
Author: WilliamDee
Issue: None
7 changes: 6 additions & 1 deletion metricflow/sql/optimizer/column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,19 @@ def _prune_columns_from_grandparents(
order_bys=node.order_bys,
where=node.where,
limit=node.limit,
distinct=node.distinct,
)

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D
# Remove columns that are not needed from this SELECT statement because the parent SELECT statement doesn't
# need them. However, keep columns that are in group bys because that changes the meaning of the query.
# Similarly, if this node is a distinct select node, keep all columns as it may return a different result set.
pruned_select_columns = tuple(
select_column
for select_column in node.select_columns
if select_column.column_alias in self._required_column_aliases or select_column in node.group_bys
if select_column.column_alias in self._required_column_aliases
or select_column in node.group_bys
or node.distinct
)

if len(pruned_select_columns) == 0:
Expand Down Expand Up @@ -183,6 +187,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
order_bys=node.order_bys,
where=node.where,
limit=node.limit,
distinct=node.distinct,
)

def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D
Expand Down
8 changes: 8 additions & 0 deletions metricflow/sql/optimizer/rewriting_sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def _reduce_parents(
order_bys=node.order_bys,
where=node.where,
limit=node.limit,
distinct=node.distinct,
)

@staticmethod
Expand Down Expand Up @@ -220,6 +221,10 @@ def _current_node_can_be_reduced(self, node: SqlSelectStatementNode) -> bool: #
if SqlRewritingSubQueryReducerVisitor._statement_contains_difficult_expressions(node):
return False

# Don't reduce distinct selects
if parent_select_node.distinct:
return False

# Skip this case for simplicity of reasoning.
if len(node.order_bys) > 0 and len(parent_select_node.order_bys) > 0:
return False
Expand Down Expand Up @@ -523,6 +528,7 @@ def _rewrite_node_with_join(node: SqlSelectStatementNode) -> SqlSelectStatementN
order_bys=tuple(clauses_to_rewrite.order_bys),
where=clauses_to_rewrite.combine_wheres(additional_where_clauses),
limit=node.limit,
distinct=node.distinct,
)

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D
Expand Down Expand Up @@ -641,6 +647,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
parent_node_where=parent_select_node.where,
),
limit=new_limit,
distinct=parent_select_node.distinct,
)

def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D
Expand Down Expand Up @@ -698,6 +705,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
order_bys=node.order_bys,
where=node.where,
limit=node.limit,
distinct=node.distinct,
)

def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D
Expand Down
6 changes: 6 additions & 0 deletions metricflow/sql/optimizer/sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def _reduce_parents(
order_bys=node.order_bys,
where=node.where,
limit=node.limit,
distinct=node.distinct,
)

def _reduce_is_possible(self, node: SqlSelectStatementNode) -> bool: # noqa: D
Expand Down Expand Up @@ -71,6 +72,10 @@ def _reduce_is_possible(self, node: SqlSelectStatementNode) -> bool: # noqa: D
# More conditions where we don't want to collapse. It's not impossible with these cases, but not reducing in
# these cases for simplicity.

# Don't reduce distinct selects
if parent_select_node.distinct:
return False

# Reducing a where is tricky as it requires the expressions to be re-written.
if node.where:
return False
Expand Down Expand Up @@ -179,6 +184,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
order_bys=tuple(new_order_by),
where=parent_select_node.where,
limit=new_limit,
distinct=parent_select_node.distinct,
)

def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D
Expand Down
2 changes: 2 additions & 0 deletions metricflow/sql/optimizer/table_alias_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
),
where=node.where.rewrite(should_render_table_alias=False) if node.where else None,
limit=node.limit,
distinct=node.distinct,
)

return SqlSelectStatementNode(
Expand All @@ -64,6 +65,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
order_bys=node.order_bys,
where=node.where,
limit=node.limit,
distinct=node.distinct,
)

def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D
Expand Down
7 changes: 5 additions & 2 deletions metricflow/sql/render/sql_plan_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def _render_select_columns_section(
self,
select_columns: Sequence[SqlSelectColumn],
num_parents: int,
distinct: bool,
) -> Tuple[str, SqlBindParameters]:
"""Convert the select columns into a "SELECT" section.
Expand All @@ -84,7 +85,7 @@ def _render_select_columns_section(
Returns a tuple of the "SELECT" section as a string and the associated execution parameters.
"""
params = SqlBindParameters()
select_section_lines = ["SELECT"]
select_section_lines = ["SELECT DISTINCT" if distinct else "SELECT"]
first_column = True
for select_column in select_columns:
expr_rendered = self.EXPR_RENDERER.render_sql_expr(select_column.expr)
Expand Down Expand Up @@ -217,7 +218,9 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanRe
description_section = "\n".join([f"-- {x}" for x in node.description.split("\n")])

# Render "SELECT" column section
select_section, select_params = self._render_select_columns_section(node.select_columns, len(node.parent_nodes))
select_section, select_params = self._render_select_columns_section(
node.select_columns, len(node.parent_nodes), node.distinct
)
combined_params = combined_params.combine(select_params)

# Render "FROM" section
Expand Down
7 changes: 7 additions & 0 deletions metricflow/sql/sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def __init__( # noqa: D
order_bys: Tuple[SqlOrderByDescription, ...],
where: Optional[SqlExpressionNode] = None,
limit: Optional[int] = None,
distinct: bool = False,
) -> None:
self._description = description
assert select_columns
Expand All @@ -145,6 +146,7 @@ def __init__( # noqa: D
self._group_bys = group_bys
self._where = where
self._order_bys = order_bys
self._distinct = distinct

if limit is not None:
assert limit >= 0
Expand Down Expand Up @@ -173,6 +175,7 @@ def displayed_properties(self) -> List[DisplayedProperty]: # noqa: D
+ [DisplayedProperty(f"group_by{i}", group_by) for i, group_by in enumerate(self._group_bys)]
+ [DisplayedProperty("where", self._where)]
+ [DisplayedProperty(f"order_by{i}", order_by) for i, order_by in enumerate(self._order_bys)]
+ [DisplayedProperty("distinct", self._distinct)]
)

@property
Expand Down Expand Up @@ -218,6 +221,10 @@ def limit(self) -> Optional[int]: # noqa: D
def as_select_node(self) -> Optional[SqlSelectStatementNode]: # noqa: D
return self

@property
def distinct(self) -> bool: # noqa: D
return self._distinct


class SqlTableFromClauseNode(SqlQueryPlanNode):
"""An SQL table that can go in the FROM clause."""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
-- test0
SELECT
a.booking_value
FROM (
-- test1
SELECT DISTINCT
a.booking_value
, a.bookings
FROM demo.fct_bookings a
) b
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
-- test0
SELECT
a.booking_value
FROM (
-- test1
SELECT DISTINCT
a.booking_value
, a.bookings
FROM demo.fct_bookings a
) b
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
<!-- 'column_alias': 'bookings'} -->
<!-- from_source = SqlSelectStatementNode(node_id=ss_3) -->
<!-- where = None -->
<!-- distinct = False -->
<SqlSelectStatementNode>
<!-- description = Aggregate Measures -->
<!-- node_id = ss_3 -->
Expand All @@ -41,6 +42,7 @@
<!-- 'expr': SqlColumnReferenceExpression(node_id=cr_10), -->
<!-- 'column_alias': 'listing__country_latest'} -->
<!-- where = None -->
<!-- distinct = False -->
<SqlSelectStatementNode>
<!-- description = Join Standard Outputs -->
<!-- node_id = ss_2 -->
Expand All @@ -64,6 +66,7 @@
<!-- 'join_type': SqlJoinType.LEFT_OUTER, -->
<!-- 'on_condition': SqlComparisonExpression(node_id=cmp_0)} -->
<!-- where = None -->
<!-- distinct = False -->
<SqlSelectStatementNode>
<!-- description = -->
<!-- Pass Only Elements: -->
Expand All @@ -79,6 +82,7 @@
<!-- 'column_alias': 'bookings'} -->
<!-- from_source = SqlSelectStatementNode(node_id=ss_10001) -->
<!-- where = None -->
<!-- distinct = False -->
<SqlSelectStatementNode>
<!-- description = Read Elements From Semantic Model 'bookings_source' -->
<!-- node_id = ss_10001 -->
Expand Down Expand Up @@ -436,6 +440,7 @@
<!-- 'column_alias': 'booking__host'} -->
<!-- from_source = SqlTableFromClauseNode(node_id=tfc_10001) -->
<!-- where = None -->
<!-- distinct = False -->
<SqlTableFromClauseNode>
<!-- description = Read from ***************************.fct_bookings -->
<!-- node_id = tfc_10001 -->
Expand All @@ -458,6 +463,7 @@
<!-- 'column_alias': 'country_latest'} -->
<!-- from_source = SqlSelectStatementNode(node_id=ss_10004) -->
<!-- where = None -->
<!-- distinct = False -->
<SqlSelectStatementNode>
<!-- description = Read Elements From Semantic Model 'listings_latest' -->
<!-- node_id = ss_10004 -->
Expand Down Expand Up @@ -687,6 +693,7 @@
<!-- 'column_alias': 'listing__user'} -->
<!-- from_source = SqlTableFromClauseNode(node_id=tfc_10004) -->
<!-- where = None -->
<!-- distinct = False -->
<SqlTableFromClauseNode>
<!-- description = Read from ***************************.dim_listings_latest -->
<!-- node_id = tfc_10004 -->
Expand Down
Loading

0 comments on commit 151975b

Please sign in to comment.