From 49833dc6e67b355372e19886591642b410b928b1 Mon Sep 17 00:00:00 2001 From: Will Deng Date: Thu, 2 Nov 2023 15:10:39 -0400 Subject: [PATCH] added tests --- .../test/sql/optimizer/test_column_pruner.py | 62 +++++++++++++++++++ .../test_rewriting_sub_query_reducer.py | 62 +++++++++++++++++++ .../sql/optimizer/test_sub_query_reducer.py | 62 +++++++++++++++++++ 3 files changed, 186 insertions(+) diff --git a/metricflow/test/sql/optimizer/test_column_pruner.py b/metricflow/test/sql/optimizer/test_column_pruner.py index 9786734c6e..7c10ac7743 100644 --- a/metricflow/test/sql/optimizer/test_column_pruner.py +++ b/metricflow/test/sql/optimizer/test_column_pruner.py @@ -895,3 +895,65 @@ def test_prune_grandparents_in_join_query( sql_plan_node=column_pruned_select_node, plan_id="after_pruning", ) + + +def test_prune_distinct_select( + request: FixtureRequest, + mf_test_session_state: MetricFlowTestSessionState, + column_pruner: SqlColumnPrunerOptimizer, +) -> None: + """Test that distinct select node shouldn't be pruned.""" + select_node = SqlSelectStatementNode( + description="test0", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression( + col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") + ), + column_alias="booking_value", + ), + ), + from_source=SqlSelectStatementNode( + description="test1", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression( + col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") + ), + column_alias="booking_value", + ), + SqlSelectColumn( + expr=SqlColumnReferenceExpression( + col_ref=SqlColumnReference(table_alias="a", column_name="bookings") + ), + column_alias="bookings", + ), + ), + from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source_alias="a", + joins_descs=(), + where=None, + group_bys=(), + order_bys=(), + distinct=True, + ), + from_source_alias="b", + joins_descs=(), + where=None, + group_bys=(), + order_bys=(), + ) + assert_default_rendered_sql_equal( + request=request, + mf_test_session_state=mf_test_session_state, + sql_plan_node=select_node, + plan_id="before_pruning", + ) + + column_pruner.optimize(select_node) + assert_default_rendered_sql_equal( + request=request, + mf_test_session_state=mf_test_session_state, + sql_plan_node=select_node, + plan_id="after_pruning", + ) diff --git a/metricflow/test/sql/optimizer/test_rewriting_sub_query_reducer.py b/metricflow/test/sql/optimizer/test_rewriting_sub_query_reducer.py index e05c044725..3eaabcfcd9 100644 --- a/metricflow/test/sql/optimizer/test_rewriting_sub_query_reducer.py +++ b/metricflow/test/sql/optimizer/test_rewriting_sub_query_reducer.py @@ -1056,3 +1056,65 @@ def test_reducing_join_left_node_statement( sql_plan_node=sub_query_reducer.optimize(reducing_join_left_node_statement), plan_id="after_reducing", ) + + +def test_rewriting_distinct_select_node_is_not_reduced( + request: FixtureRequest, + mf_test_session_state: MetricFlowTestSessionState, +) -> None: + """Tests to ensure distinct select node doesn't get overwritten.""" + select_node = SqlSelectStatementNode( + description="test0", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression( + col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") + ), + column_alias="booking_value", + ), + ), + from_source=SqlSelectStatementNode( + description="test1", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression( + col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") + ), + column_alias="booking_value", + ), + SqlSelectColumn( + expr=SqlColumnReferenceExpression( + col_ref=SqlColumnReference(table_alias="a", column_name="bookings") + ), + column_alias="bookings", + ), + ), + from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source_alias="a", + joins_descs=(), + where=None, + group_bys=(), + order_bys=(), + distinct=True, + ), + from_source_alias="b", + joins_descs=(), + where=None, + group_bys=(), + order_bys=(), + ) + assert_default_rendered_sql_equal( + request=request, + mf_test_session_state=mf_test_session_state, + sql_plan_node=select_node, + plan_id="before_reducing", + ) + + sub_query_reducer = SqlRewritingSubQueryReducer() + + assert_default_rendered_sql_equal( + request=request, + mf_test_session_state=mf_test_session_state, + sql_plan_node=sub_query_reducer.optimize(select_node), + plan_id="after_reducing", + ) diff --git a/metricflow/test/sql/optimizer/test_sub_query_reducer.py b/metricflow/test/sql/optimizer/test_sub_query_reducer.py index 5c3b42388f..39041de30b 100644 --- a/metricflow/test/sql/optimizer/test_sub_query_reducer.py +++ b/metricflow/test/sql/optimizer/test_sub_query_reducer.py @@ -266,3 +266,65 @@ def test_rewrite_order_by_with_a_join_in_parent( sql_plan_node=sub_query_reducer.optimize(rewrite_order_by_statement), plan_id="after_reducing", ) + + +def test_distinct_select_node_is_not_reduced( + request: FixtureRequest, + mf_test_session_state: MetricFlowTestSessionState, +) -> None: + """Tests to ensure distinct select node doesn't get overwritten.""" + select_node = SqlSelectStatementNode( + description="test0", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression( + col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") + ), + column_alias="booking_value", + ), + ), + from_source=SqlSelectStatementNode( + description="test1", + select_columns=( + SqlSelectColumn( + expr=SqlColumnReferenceExpression( + col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") + ), + column_alias="booking_value", + ), + SqlSelectColumn( + expr=SqlColumnReferenceExpression( + col_ref=SqlColumnReference(table_alias="a", column_name="bookings") + ), + column_alias="bookings", + ), + ), + from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source_alias="a", + joins_descs=(), + where=None, + group_bys=(), + order_bys=(), + distinct=True, + ), + from_source_alias="b", + joins_descs=(), + where=None, + group_bys=(), + order_bys=(), + ) + assert_default_rendered_sql_equal( + request=request, + mf_test_session_state=mf_test_session_state, + sql_plan_node=select_node, + plan_id="before_reducing", + ) + + sub_query_reducer = SqlSubQueryReducer() + + assert_default_rendered_sql_equal( + request=request, + mf_test_session_state=mf_test_session_state, + sql_plan_node=sub_query_reducer.optimize(select_node), + plan_id="after_reducing", + )