diff --git a/metricflow/sql/optimizer/column_pruner.py b/metricflow/sql/optimizer/column_pruner.py index f6a3d14738..f7e2778d92 100644 --- a/metricflow/sql/optimizer/column_pruner.py +++ b/metricflow/sql/optimizer/column_pruner.py @@ -107,15 +107,19 @@ def _prune_columns_from_grandparents( order_bys=node.order_bys, where=node.where, limit=node.limit, + distinct=node.distinct, ) def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D # Remove columns that are not needed from this SELECT statement because the parent SELECT statement doesn't # need them. However, keep columns that are in group bys because that changes the meaning of the query. + # Similarly, if this node is a distinct select node, keep all columns as it may return a different result set. pruned_select_columns = tuple( select_column for select_column in node.select_columns - if select_column.column_alias in self._required_column_aliases or select_column in node.group_bys + if select_column.column_alias in self._required_column_aliases + or select_column in node.group_bys + or node.distinct ) if len(pruned_select_columns) == 0: @@ -183,6 +187,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP order_bys=node.order_bys, where=node.where, limit=node.limit, + distinct=node.distinct, ) def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D diff --git a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py index 0fcab8817f..c1be14554c 100644 --- a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py +++ b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py @@ -108,6 +108,7 @@ def _reduce_parents( order_bys=node.order_bys, where=node.where, limit=node.limit, + distinct=node.distinct, ) @staticmethod @@ -220,6 +221,10 @@ def _current_node_can_be_reduced(self, node: SqlSelectStatementNode) -> bool: # if SqlRewritingSubQueryReducerVisitor._statement_contains_difficult_expressions(node): return False + # Don't reduce distinct selects + if parent_select_node.distinct: + return False + # Skip this case for simplicity of reasoning. if len(node.order_bys) > 0 and len(parent_select_node.order_bys) > 0: return False @@ -523,6 +528,7 @@ def _rewrite_node_with_join(node: SqlSelectStatementNode) -> SqlSelectStatementN order_bys=tuple(clauses_to_rewrite.order_bys), where=clauses_to_rewrite.combine_wheres(additional_where_clauses), limit=node.limit, + distinct=node.distinct, ) def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D @@ -641,6 +647,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP parent_node_where=parent_select_node.where, ), limit=new_limit, + distinct=parent_select_node.distinct, ) def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D @@ -698,6 +705,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP order_bys=node.order_bys, where=node.where, limit=node.limit, + distinct=node.distinct, ) def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D diff --git a/metricflow/sql/optimizer/sub_query_reducer.py b/metricflow/sql/optimizer/sub_query_reducer.py index 9de72cecfa..c3ef22a6ff 100644 --- a/metricflow/sql/optimizer/sub_query_reducer.py +++ b/metricflow/sql/optimizer/sub_query_reducer.py @@ -44,6 +44,7 @@ def _reduce_parents( order_bys=node.order_bys, where=node.where, limit=node.limit, + distinct=node.distinct, ) def _reduce_is_possible(self, node: SqlSelectStatementNode) -> bool: # noqa: D @@ -71,6 +72,10 @@ def _reduce_is_possible(self, node: SqlSelectStatementNode) -> bool: # noqa: D # More conditions where we don't want to collapse. It's not impossible with these cases, but not reducing in # these cases for simplicity. + # Don't reduce distinct selects + if parent_select_node.distinct: + return False + # Reducing a where is tricky as it requires the expressions to be re-written. if node.where: return False @@ -179,6 +184,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP order_bys=tuple(new_order_by), where=parent_select_node.where, limit=new_limit, + distinct=parent_select_node.distinct, ) def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D diff --git a/metricflow/sql/optimizer/table_alias_simplifier.py b/metricflow/sql/optimizer/table_alias_simplifier.py index 0cbe7f8544..19a71e1936 100644 --- a/metricflow/sql/optimizer/table_alias_simplifier.py +++ b/metricflow/sql/optimizer/table_alias_simplifier.py @@ -44,6 +44,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP ), where=node.where.rewrite(should_render_table_alias=False) if node.where else None, limit=node.limit, + distinct=node.distinct, ) return SqlSelectStatementNode( @@ -64,6 +65,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP order_bys=node.order_bys, where=node.where, limit=node.limit, + distinct=node.distinct, ) def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D