diff --git a/.changes/unreleased/Under the Hood-20231102-161245.yaml b/.changes/unreleased/Under the Hood-20231102-161245.yaml new file mode 100644 index 0000000000..72f9286d85 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20231102-161245.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Add the ability to use distinct select in sql nodes +time: 2023-11-02T16:12:45.123252-04:00 +custom: + Author: WilliamDee + Issue: None diff --git a/metricflow/sql/optimizer/column_pruner.py b/metricflow/sql/optimizer/column_pruner.py index f6a3d14738..f7e2778d92 100644 --- a/metricflow/sql/optimizer/column_pruner.py +++ b/metricflow/sql/optimizer/column_pruner.py @@ -107,15 +107,19 @@ def _prune_columns_from_grandparents( order_bys=node.order_bys, where=node.where, limit=node.limit, + distinct=node.distinct, ) def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D # Remove columns that are not needed from this SELECT statement because the parent SELECT statement doesn't # need them. However, keep columns that are in group bys because that changes the meaning of the query. + # Similarly, if this node is a distinct select node, keep all columns as it may return a different result set. 
pruned_select_columns = tuple( select_column for select_column in node.select_columns - if select_column.column_alias in self._required_column_aliases or select_column in node.group_bys + if select_column.column_alias in self._required_column_aliases + or select_column in node.group_bys + or node.distinct ) if len(pruned_select_columns) == 0: @@ -183,6 +187,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP order_bys=node.order_bys, where=node.where, limit=node.limit, + distinct=node.distinct, ) def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D diff --git a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py index 0fcab8817f..c1be14554c 100644 --- a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py +++ b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py @@ -108,6 +108,7 @@ def _reduce_parents( order_bys=node.order_bys, where=node.where, limit=node.limit, + distinct=node.distinct, ) @staticmethod @@ -220,6 +221,10 @@ def _current_node_can_be_reduced(self, node: SqlSelectStatementNode) -> bool: # if SqlRewritingSubQueryReducerVisitor._statement_contains_difficult_expressions(node): return False + # Don't reduce distinct selects + if parent_select_node.distinct: + return False + # Skip this case for simplicity of reasoning. 
if len(node.order_bys) > 0 and len(parent_select_node.order_bys) > 0: return False @@ -523,6 +528,7 @@ def _rewrite_node_with_join(node: SqlSelectStatementNode) -> SqlSelectStatementN order_bys=tuple(clauses_to_rewrite.order_bys), where=clauses_to_rewrite.combine_wheres(additional_where_clauses), limit=node.limit, + distinct=node.distinct, ) def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D @@ -641,6 +647,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP parent_node_where=parent_select_node.where, ), limit=new_limit, + distinct=parent_select_node.distinct, ) def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D @@ -698,6 +705,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP order_bys=node.order_bys, where=node.where, limit=node.limit, + distinct=node.distinct, ) def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D diff --git a/metricflow/sql/optimizer/sub_query_reducer.py b/metricflow/sql/optimizer/sub_query_reducer.py index 9de72cecfa..c3ef22a6ff 100644 --- a/metricflow/sql/optimizer/sub_query_reducer.py +++ b/metricflow/sql/optimizer/sub_query_reducer.py @@ -44,6 +44,7 @@ def _reduce_parents( order_bys=node.order_bys, where=node.where, limit=node.limit, + distinct=node.distinct, ) def _reduce_is_possible(self, node: SqlSelectStatementNode) -> bool: # noqa: D @@ -71,6 +72,10 @@ def _reduce_is_possible(self, node: SqlSelectStatementNode) -> bool: # noqa: D # More conditions where we don't want to collapse. It's not impossible with these cases, but not reducing in # these cases for simplicity. + # Don't reduce distinct selects + if parent_select_node.distinct: + return False + # Reducing a where is tricky as it requires the expressions to be re-written. 
if node.where: return False @@ -179,6 +184,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP order_bys=tuple(new_order_by), where=parent_select_node.where, limit=new_limit, + distinct=parent_select_node.distinct, ) def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D diff --git a/metricflow/sql/optimizer/table_alias_simplifier.py b/metricflow/sql/optimizer/table_alias_simplifier.py index 0cbe7f8544..19a71e1936 100644 --- a/metricflow/sql/optimizer/table_alias_simplifier.py +++ b/metricflow/sql/optimizer/table_alias_simplifier.py @@ -44,6 +44,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP ), where=node.where.rewrite(should_render_table_alias=False) if node.where else None, limit=node.limit, + distinct=node.distinct, ) return SqlSelectStatementNode( @@ -64,6 +65,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP order_bys=node.order_bys, where=node.where, limit=node.limit, + distinct=node.distinct, ) def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D diff --git a/metricflow/sql/render/sql_plan_renderer.py b/metricflow/sql/render/sql_plan_renderer.py index 2fceea6edb..a502284d77 100644 --- a/metricflow/sql/render/sql_plan_renderer.py +++ b/metricflow/sql/render/sql_plan_renderer.py @@ -72,6 +72,7 @@ def _render_select_columns_section( self, select_columns: Sequence[SqlSelectColumn], num_parents: int, + distinct: bool, ) -> Tuple[str, SqlBindParameters]: """Convert the select columns into a "SELECT" section. @@ -84,7 +85,7 @@ def _render_select_columns_section( Returns a tuple of the "SELECT" section as a string and the associated execution parameters. 
""" params = SqlBindParameters() - select_section_lines = ["SELECT"] + select_section_lines = ["SELECT DISTINCT" if distinct else "SELECT"] first_column = True for select_column in select_columns: expr_rendered = self.EXPR_RENDERER.render_sql_expr(select_column.expr) @@ -217,7 +218,9 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanRe description_section = "\n".join([f"-- {x}" for x in node.description.split("\n")]) # Render "SELECT" column section - select_section, select_params = self._render_select_columns_section(node.select_columns, len(node.parent_nodes)) + select_section, select_params = self._render_select_columns_section( + node.select_columns, len(node.parent_nodes), node.distinct + ) combined_params = combined_params.combine(select_params) # Render "FROM" section diff --git a/metricflow/sql/sql_plan.py b/metricflow/sql/sql_plan.py index 2807800dad..a764f93094 100644 --- a/metricflow/sql/sql_plan.py +++ b/metricflow/sql/sql_plan.py @@ -134,6 +134,7 @@ def __init__( # noqa: D order_bys: Tuple[SqlOrderByDescription, ...], where: Optional[SqlExpressionNode] = None, limit: Optional[int] = None, + distinct: bool = False, ) -> None: self._description = description assert select_columns @@ -145,6 +146,7 @@ def __init__( # noqa: D self._group_bys = group_bys self._where = where self._order_bys = order_bys + self._distinct = distinct if limit is not None: assert limit >= 0 @@ -173,6 +175,7 @@ def displayed_properties(self) -> List[DisplayedProperty]: # noqa: D + [DisplayedProperty(f"group_by{i}", group_by) for i, group_by in enumerate(self._group_bys)] + [DisplayedProperty("where", self._where)] + [DisplayedProperty(f"order_by{i}", order_by) for i, order_by in enumerate(self._order_bys)] + + [DisplayedProperty("distinct", self._distinct)] ) @property @@ -218,6 +221,10 @@ def limit(self) -> Optional[int]: # noqa: D def as_select_node(self) -> Optional[SqlSelectStatementNode]: # noqa: D return self + @property + def distinct(self) -> 
bool: # noqa: D + return self._distinct + class SqlTableFromClauseNode(SqlQueryPlanNode): """An SQL table that can go in the FROM clause.""" diff --git a/metricflow/test/snapshots/test_column_pruner.py/SqlQueryPlan/test_prune_distinct_select__after_pruning.sql b/metricflow/test/snapshots/test_column_pruner.py/SqlQueryPlan/test_prune_distinct_select__after_pruning.sql new file mode 100644 index 0000000000..9acdadd458 --- /dev/null +++ b/metricflow/test/snapshots/test_column_pruner.py/SqlQueryPlan/test_prune_distinct_select__after_pruning.sql @@ -0,0 +1,10 @@ +-- test0 +SELECT + a.booking_value +FROM ( + -- test1 + SELECT DISTINCT + a.booking_value + , a.bookings + FROM demo.fct_bookings a +) b diff --git a/metricflow/test/snapshots/test_column_pruner.py/SqlQueryPlan/test_prune_distinct_select__before_pruning.sql b/metricflow/test/snapshots/test_column_pruner.py/SqlQueryPlan/test_prune_distinct_select__before_pruning.sql new file mode 100644 index 0000000000..9acdadd458 --- /dev/null +++ b/metricflow/test/snapshots/test_column_pruner.py/SqlQueryPlan/test_prune_distinct_select__before_pruning.sql @@ -0,0 +1,10 @@ +-- test0 +SELECT + a.booking_value +FROM ( + -- test1 + SELECT DISTINCT + a.booking_value + , a.bookings + FROM demo.fct_bookings a +) b diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node__plan0.xml index 76ef510cce..77ef9e9b24 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node__plan0.xml @@ -16,6 +16,7 @@ + @@ -41,6 +42,7 @@ + @@ -64,6 +66,7 @@ + @@ -79,6 +82,7 @@ + @@ -436,6 +440,7 @@ + @@ -458,6 +463,7 @@ + @@ -687,6 +693,7 @@ + diff --git 
a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_ratio_from_multiple_semantic_models__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_ratio_from_multiple_semantic_models__plan0.xml index bd7bff1c0a..a20f2ab36a 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_ratio_from_multiple_semantic_models__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_ratio_from_multiple_semantic_models__plan0.xml @@ -16,6 +16,7 @@ + @@ -43,6 +44,7 @@ + @@ -60,6 +62,7 @@ + @@ -85,6 +88,7 @@ + @@ -104,6 +108,7 @@ + @@ -131,6 +136,7 @@ + @@ -150,6 +156,7 @@ + @@ -547,6 +554,7 @@ + @@ -904,6 +912,7 @@ + @@ -927,6 +936,7 @@ + @@ -1200,6 +1210,7 @@ + @@ -1429,6 +1440,7 @@ + @@ -1458,6 +1470,7 @@ + @@ -1483,6 +1496,7 @@ + @@ -1502,6 +1516,7 @@ + @@ -1529,6 +1544,7 @@ + @@ -1548,6 +1564,7 @@ + @@ -1793,6 +1810,7 @@ + @@ -1994,6 +2012,7 @@ + @@ -2017,6 +2036,7 @@ + @@ -2290,6 +2310,7 @@ + @@ -2519,6 +2540,7 @@ + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_ratio_from_single_semantic_model__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_ratio_from_single_semantic_model__plan0.xml index 673a308906..b8ea0e4a87 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_ratio_from_single_semantic_model__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_ratio_from_single_semantic_model__plan0.xml @@ -16,6 +16,7 @@ + @@ -45,6 +46,7 @@ + @@ -72,6 +74,7 @@ + @@ -91,6 +94,7 @@ + @@ -448,6 +452,7 @@ + @@ -470,6 +475,7 @@ + @@ -699,6 +705,7 @@ + diff --git 
a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_simple_expr__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_simple_expr__plan0.xml index 26438e6ad2..d54260dadb 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_simple_expr__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_compute_metrics_node_simple_expr__plan0.xml @@ -16,6 +16,7 @@ + @@ -41,6 +42,7 @@ + @@ -64,6 +66,7 @@ + @@ -79,6 +82,7 @@ + @@ -436,6 +440,7 @@ + @@ -458,6 +463,7 @@ + @@ -687,6 +693,7 @@ + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_constrain_time_range_node__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_constrain_time_range_node__plan0.xml index f502fba68f..2b515cea56 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_constrain_time_range_node__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_constrain_time_range_node__plan0.xml @@ -17,6 +17,7 @@ + @@ -34,6 +35,7 @@ + @@ -49,6 +51,7 @@ + @@ -406,6 +409,7 @@ + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_filter_node__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_filter_node__plan0.xml index 507d162fbc..43d205cc89 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_filter_node__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_filter_node__plan0.xml @@ -10,6 +10,7 @@ + @@ -367,6 +368,7 @@ + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_filter_with_where_constraint_node__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_filter_with_where_constraint_node__plan0.xml index 
4ca58fa593..031d207292 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_filter_with_where_constraint_node__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_filter_with_where_constraint_node__plan0.xml @@ -13,6 +13,7 @@ + @@ -28,6 +29,7 @@ + @@ -385,6 +387,7 @@ + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_join_to_time_spine_node_with_offset_to_grain__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_join_to_time_spine_node_with_offset_to_grain__plan0.xml index da4b9f2c53..f15986b92a 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_join_to_time_spine_node_with_offset_to_grain__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_join_to_time_spine_node_with_offset_to_grain__plan0.xml @@ -22,6 +22,7 @@ + @@ -31,6 +32,7 @@ + @@ -54,6 +56,7 @@ + @@ -79,6 +82,7 @@ + @@ -98,6 +102,7 @@ + @@ -495,6 +500,7 @@ + @@ -852,6 +858,7 @@ + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_join_to_time_spine_node_with_offset_window__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_join_to_time_spine_node_with_offset_window__plan0.xml index da4b9f2c53..f15986b92a 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_join_to_time_spine_node_with_offset_window__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_join_to_time_spine_node_with_offset_window__plan0.xml @@ -22,6 +22,7 @@ + @@ -31,6 +32,7 @@ + @@ -54,6 +56,7 @@ + @@ -79,6 +82,7 @@ + @@ -98,6 +102,7 @@ + @@ -495,6 +500,7 @@ + @@ -852,6 +858,7 @@ + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_join_to_time_spine_node_without_offset__plan0.xml 
b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_join_to_time_spine_node_without_offset__plan0.xml index da4b9f2c53..f15986b92a 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_join_to_time_spine_node_without_offset__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_join_to_time_spine_node_without_offset__plan0.xml @@ -22,6 +22,7 @@ + @@ -31,6 +32,7 @@ + @@ -54,6 +56,7 @@ + @@ -79,6 +82,7 @@ + @@ -98,6 +102,7 @@ + @@ -495,6 +500,7 @@ + @@ -852,6 +858,7 @@ + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_measure_aggregation_node__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_measure_aggregation_node__plan0.xml index 9040ebdb0c..4d73ee2270 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_measure_aggregation_node__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_measure_aggregation_node__plan0.xml @@ -20,6 +20,7 @@ + @@ -43,6 +44,7 @@ + @@ -400,6 +402,7 @@ + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_multi_join_node__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_multi_join_node__plan0.xml index 86d176f9cb..3268ab93ba 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_multi_join_node__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_multi_join_node__plan0.xml @@ -32,6 +32,7 @@ + @@ -47,6 +48,7 @@ + @@ -404,6 +406,7 @@ + @@ -426,6 +429,7 @@ + @@ -655,6 +659,7 @@ + @@ -677,6 +682,7 @@ + @@ -906,6 +912,7 @@ + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_order_by_node__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_order_by_node__plan0.xml index 402e2ef39d..76e0ba86fd 100644 --- 
a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_order_by_node__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_order_by_node__plan0.xml @@ -24,6 +24,7 @@ + @@ -41,6 +42,7 @@ + @@ -66,6 +68,7 @@ + @@ -85,6 +88,7 @@ + @@ -442,6 +446,7 @@ + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node__plan0.xml index 3b31be5065..187f6f44d1 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node__plan0.xml @@ -126,6 +126,7 @@ + @@ -247,6 +248,7 @@ + @@ -262,6 +264,7 @@ + @@ -383,6 +386,7 @@ + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node_with_grouping__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node_with_grouping__plan0.xml index c4d1d06282..364d97e86d 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node_with_grouping__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node_with_grouping__plan0.xml @@ -126,6 +126,7 @@ + @@ -247,6 +248,7 @@ + @@ -270,6 +272,7 @@ + @@ -391,6 +394,7 @@ + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node_with_queried_group_by__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node_with_queried_group_by__plan0.xml index c65adb8572..c877290f10 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node_with_queried_group_by__plan0.xml +++ 
b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_semi_additive_join_node_with_queried_group_by__plan0.xml @@ -126,6 +126,7 @@ + @@ -247,6 +248,7 @@ + @@ -270,6 +272,7 @@ + @@ -391,6 +394,7 @@ + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_single_join_node__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_single_join_node__plan0.xml index 48afce360f..ab795d1fad 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_single_join_node__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_single_join_node__plan0.xml @@ -18,6 +18,7 @@ + @@ -33,6 +34,7 @@ + @@ -390,6 +392,7 @@ + @@ -412,6 +415,7 @@ + @@ -641,6 +645,7 @@ + diff --git a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_source_node__plan0.xml b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_source_node__plan0.xml index 89948cc666..13fda71ede 100644 --- a/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_source_node__plan0.xml +++ b/metricflow/test/snapshots/test_dataflow_to_sql_plan.py/SqlQueryPlan/test_source_node__plan0.xml @@ -356,6 +356,7 @@ + diff --git a/metricflow/test/snapshots/test_metric_time_dimension_to_sql.py/SqlQueryPlan/test_metric_time_dimension_transform_node_using_non_primary_time__plan0.xml b/metricflow/test/snapshots/test_metric_time_dimension_to_sql.py/SqlQueryPlan/test_metric_time_dimension_transform_node_using_non_primary_time__plan0.xml index 2e1a787a23..18da8dd77c 100644 --- a/metricflow/test/snapshots/test_metric_time_dimension_to_sql.py/SqlQueryPlan/test_metric_time_dimension_transform_node_using_non_primary_time__plan0.xml +++ b/metricflow/test/snapshots/test_metric_time_dimension_to_sql.py/SqlQueryPlan/test_metric_time_dimension_transform_node_using_non_primary_time__plan0.xml @@ -348,6 +348,7 @@ + @@ -705,6 +706,7 @@ + diff --git 
a/metricflow/test/snapshots/test_metric_time_dimension_to_sql.py/SqlQueryPlan/test_metric_time_dimension_transform_node_using_primary_time__plan0.xml b/metricflow/test/snapshots/test_metric_time_dimension_to_sql.py/SqlQueryPlan/test_metric_time_dimension_transform_node_using_primary_time__plan0.xml index db9f23a206..74cd78ad59 100644 --- a/metricflow/test/snapshots/test_metric_time_dimension_to_sql.py/SqlQueryPlan/test_metric_time_dimension_transform_node_using_primary_time__plan0.xml +++ b/metricflow/test/snapshots/test_metric_time_dimension_to_sql.py/SqlQueryPlan/test_metric_time_dimension_transform_node_using_primary_time__plan0.xml @@ -396,6 +396,7 @@ + @@ -753,6 +754,7 @@ + diff --git a/metricflow/test/snapshots/test_metric_time_dimension_to_sql.py/SqlQueryPlan/test_simple_query_with_metric_time_dimension__plan0.xml b/metricflow/test/snapshots/test_metric_time_dimension_to_sql.py/SqlQueryPlan/test_simple_query_with_metric_time_dimension__plan0.xml index e8cd14dee9..3ade70ad35 100644 --- a/metricflow/test/snapshots/test_metric_time_dimension_to_sql.py/SqlQueryPlan/test_simple_query_with_metric_time_dimension__plan0.xml +++ b/metricflow/test/snapshots/test_metric_time_dimension_to_sql.py/SqlQueryPlan/test_simple_query_with_metric_time_dimension__plan0.xml @@ -26,6 +26,7 @@ + @@ -39,6 +40,7 @@ + @@ -56,6 +58,7 @@ + @@ -71,6 +74,7 @@ + @@ -468,6 +472,7 @@ + @@ -825,6 +830,7 @@ + @@ -848,6 +854,7 @@ + @@ -865,6 +872,7 @@ + @@ -880,6 +888,7 @@ + @@ -1229,6 +1238,7 @@ + @@ -1586,6 +1596,7 @@ + diff --git a/metricflow/test/snapshots/test_rewriting_sub_query_reducer.py/SqlQueryPlan/test_rewriting_distinct_select_node_is_not_reduced__after_reducing.sql b/metricflow/test/snapshots/test_rewriting_sub_query_reducer.py/SqlQueryPlan/test_rewriting_distinct_select_node_is_not_reduced__after_reducing.sql new file mode 100644 index 0000000000..9acdadd458 --- /dev/null +++ 
b/metricflow/test/snapshots/test_rewriting_sub_query_reducer.py/SqlQueryPlan/test_rewriting_distinct_select_node_is_not_reduced__after_reducing.sql @@ -0,0 +1,10 @@ +-- test0 +SELECT + a.booking_value +FROM ( + -- test1 + SELECT DISTINCT + a.booking_value + , a.bookings + FROM demo.fct_bookings a +) b diff --git a/metricflow/test/snapshots/test_rewriting_sub_query_reducer.py/SqlQueryPlan/test_rewriting_distinct_select_node_is_not_reduced__before_reducing.sql b/metricflow/test/snapshots/test_rewriting_sub_query_reducer.py/SqlQueryPlan/test_rewriting_distinct_select_node_is_not_reduced__before_reducing.sql new file mode 100644 index 0000000000..9acdadd458 --- /dev/null +++ b/metricflow/test/snapshots/test_rewriting_sub_query_reducer.py/SqlQueryPlan/test_rewriting_distinct_select_node_is_not_reduced__before_reducing.sql @@ -0,0 +1,10 @@ +-- test0 +SELECT + a.booking_value +FROM ( + -- test1 + SELECT DISTINCT + a.booking_value + , a.bookings + FROM demo.fct_bookings a +) b diff --git a/metricflow/test/snapshots/test_sub_query_reducer.py/SqlQueryPlan/test_distinct_select_node_is_not_reduced__after_reducing.sql b/metricflow/test/snapshots/test_sub_query_reducer.py/SqlQueryPlan/test_distinct_select_node_is_not_reduced__after_reducing.sql new file mode 100644 index 0000000000..9acdadd458 --- /dev/null +++ b/metricflow/test/snapshots/test_sub_query_reducer.py/SqlQueryPlan/test_distinct_select_node_is_not_reduced__after_reducing.sql @@ -0,0 +1,10 @@ +-- test0 +SELECT + a.booking_value +FROM ( + -- test1 + SELECT DISTINCT + a.booking_value + , a.bookings + FROM demo.fct_bookings a +) b diff --git a/metricflow/test/snapshots/test_sub_query_reducer.py/SqlQueryPlan/test_distinct_select_node_is_not_reduced__before_reducing.sql b/metricflow/test/snapshots/test_sub_query_reducer.py/SqlQueryPlan/test_distinct_select_node_is_not_reduced__before_reducing.sql new file mode 100644 index 0000000000..9acdadd458 --- /dev/null +++ 
b/metricflow/test/snapshots/test_sub_query_reducer.py/SqlQueryPlan/test_distinct_select_node_is_not_reduced__before_reducing.sql
@@ -0,0 +1,10 @@
+-- test0
+SELECT
+  a.booking_value
+FROM (
+  -- test1
+  SELECT DISTINCT
+    a.booking_value
+    , a.bookings
+  FROM demo.fct_bookings a
+) b
diff --git a/metricflow/test/sql/optimizer/test_column_pruner.py b/metricflow/test/sql/optimizer/test_column_pruner.py
index 9786734c6e..7c10ac7743 100644
--- a/metricflow/test/sql/optimizer/test_column_pruner.py
+++ b/metricflow/test/sql/optimizer/test_column_pruner.py
@@ -895,3 +895,65 @@ def test_prune_grandparents_in_join_query(
         sql_plan_node=column_pruned_select_node,
         plan_id="after_pruning",
     )
+
+
+def test_prune_distinct_select(
+    request: FixtureRequest,
+    mf_test_session_state: MetricFlowTestSessionState,
+    column_pruner: SqlColumnPrunerOptimizer,
+) -> None:
+    """Test that the columns of a distinct select node are not pruned."""
+    select_node = SqlSelectStatementNode(
+        description="test0",
+        select_columns=(
+            SqlSelectColumn(
+                expr=SqlColumnReferenceExpression(
+                    col_ref=SqlColumnReference(table_alias="a", column_name="booking_value")
+                ),
+                column_alias="booking_value",
+            ),
+        ),
+        from_source=SqlSelectStatementNode(
+            description="test1",
+            select_columns=(
+                SqlSelectColumn(
+                    expr=SqlColumnReferenceExpression(
+                        col_ref=SqlColumnReference(table_alias="a", column_name="booking_value")
+                    ),
+                    column_alias="booking_value",
+                ),
+                SqlSelectColumn(
+                    expr=SqlColumnReferenceExpression(
+                        col_ref=SqlColumnReference(table_alias="a", column_name="bookings")
+                    ),
+                    column_alias="bookings",
+                ),
+            ),
+            from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")),
+            from_source_alias="a",
+            joins_descs=(),
+            where=None,
+            group_bys=(),
+            order_bys=(),
+            distinct=True,
+        ),
+        from_source_alias="b",
+        joins_descs=(),
+        where=None,
+        group_bys=(),
+        order_bys=(),
+    )
+    assert_default_rendered_sql_equal(
+        request=request,
+        mf_test_session_state=mf_test_session_state,
+        sql_plan_node=select_node,
+        plan_id="before_pruning",
+    )
+
+    column_pruned_select_node = column_pruner.optimize(select_node)
+    assert_default_rendered_sql_equal(
+        request=request,
+        mf_test_session_state=mf_test_session_state,
+        sql_plan_node=column_pruned_select_node,
+        plan_id="after_pruning",
+    )
diff --git a/metricflow/test/sql/optimizer/test_rewriting_sub_query_reducer.py b/metricflow/test/sql/optimizer/test_rewriting_sub_query_reducer.py
index e05c044725..3eaabcfcd9 100644
--- a/metricflow/test/sql/optimizer/test_rewriting_sub_query_reducer.py
+++ b/metricflow/test/sql/optimizer/test_rewriting_sub_query_reducer.py
@@ -1056,3 +1056,65 @@ def test_reducing_join_left_node_statement(
         sql_plan_node=sub_query_reducer.optimize(reducing_join_left_node_statement),
         plan_id="after_reducing",
     )
+
+
+def test_rewriting_distinct_select_node_is_not_reduced(
+    request: FixtureRequest,
+    mf_test_session_state: MetricFlowTestSessionState,
+) -> None:
+    """Test that a distinct sub-query is not collapsed into the outer select."""
+    select_node = SqlSelectStatementNode(
+        description="test0",
+        select_columns=(
+            SqlSelectColumn(
+                expr=SqlColumnReferenceExpression(
+                    col_ref=SqlColumnReference(table_alias="a", column_name="booking_value")
+                ),
+                column_alias="booking_value",
+            ),
+        ),
+        from_source=SqlSelectStatementNode(
+            description="test1",
+            select_columns=(
+                SqlSelectColumn(
+                    expr=SqlColumnReferenceExpression(
+                        col_ref=SqlColumnReference(table_alias="a", column_name="booking_value")
+                    ),
+                    column_alias="booking_value",
+                ),
+                SqlSelectColumn(
+                    expr=SqlColumnReferenceExpression(
+                        col_ref=SqlColumnReference(table_alias="a", column_name="bookings")
+                    ),
+                    column_alias="bookings",
+                ),
+            ),
+            from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")),
+            from_source_alias="a",
+            joins_descs=(),
+            where=None,
+            group_bys=(),
+            order_bys=(),
+            distinct=True,
+        ),
+        from_source_alias="b",
+        joins_descs=(),
+        where=None,
+        group_bys=(),
+        order_bys=(),
+    )
+    assert_default_rendered_sql_equal(
+        request=request,
+        mf_test_session_state=mf_test_session_state,
+        sql_plan_node=select_node,
+        plan_id="before_reducing",
+    )
+
+    sub_query_reducer = SqlRewritingSubQueryReducer()
+
+    assert_default_rendered_sql_equal(
+        request=request,
+        mf_test_session_state=mf_test_session_state,
+        sql_plan_node=sub_query_reducer.optimize(select_node),
+        plan_id="after_reducing",
+    )
diff --git a/metricflow/test/sql/optimizer/test_sub_query_reducer.py b/metricflow/test/sql/optimizer/test_sub_query_reducer.py
index 5c3b42388f..39041de30b 100644
--- a/metricflow/test/sql/optimizer/test_sub_query_reducer.py
+++ b/metricflow/test/sql/optimizer/test_sub_query_reducer.py
@@ -266,3 +266,65 @@ def test_rewrite_order_by_with_a_join_in_parent(
         sql_plan_node=sub_query_reducer.optimize(rewrite_order_by_statement),
         plan_id="after_reducing",
     )
+
+
+def test_distinct_select_node_is_not_reduced(
+    request: FixtureRequest,
+    mf_test_session_state: MetricFlowTestSessionState,
+) -> None:
+    """Test that a distinct sub-query is not collapsed into the outer select."""
+    select_node = SqlSelectStatementNode(
+        description="test0",
+        select_columns=(
+            SqlSelectColumn(
+                expr=SqlColumnReferenceExpression(
+                    col_ref=SqlColumnReference(table_alias="a", column_name="booking_value")
+                ),
+                column_alias="booking_value",
+            ),
+        ),
+        from_source=SqlSelectStatementNode(
+            description="test1",
+            select_columns=(
+                SqlSelectColumn(
+                    expr=SqlColumnReferenceExpression(
+                        col_ref=SqlColumnReference(table_alias="a", column_name="booking_value")
+                    ),
+                    column_alias="booking_value",
+                ),
+                SqlSelectColumn(
+                    expr=SqlColumnReferenceExpression(
+                        col_ref=SqlColumnReference(table_alias="a", column_name="bookings")
+                    ),
+                    column_alias="bookings",
+                ),
+            ),
+            from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")),
+            from_source_alias="a",
+            joins_descs=(),
+            where=None,
+            group_bys=(),
+            order_bys=(),
+            distinct=True,
+        ),
+        from_source_alias="b",
+        joins_descs=(),
+        where=None,
+        group_bys=(),
+        order_bys=(),
+    )
+    assert_default_rendered_sql_equal(
+        request=request,
+        mf_test_session_state=mf_test_session_state,
+        sql_plan_node=select_node,
+        plan_id="before_reducing",
+    )
+
+    sub_query_reducer = SqlSubQueryReducer()
+
+    assert_default_rendered_sql_equal(
+        request=request,
+        mf_test_session_state=mf_test_session_state,
+        sql_plan_node=sub_query_reducer.optimize(select_node),
+        plan_id="after_reducing",
+    )