From 66377f588a2b8d11b383b0f684edffef4ca36653 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Thu, 12 Dec 2024 11:49:10 -0800 Subject: [PATCH] Add improved test for the SQL optimization level in the request (#1566) This PR adds an improved test to show that the SQL optimization level in the request is handled correctly in the MF engine API. The previous test made a more general check, but a snapshot helps to check for specific features. Considering adding an assertion for the snapshot as well. --- .../source_scan/cm_branch_combiner.py | 2 +- .../source_scan/source_scan_optimizer.py | 9 +- .../sql/optimizer/optimization_levels.py | 8 ++ tests_metricflow/engine/test_explain.py | 46 +++++++ .../integration/test_mf_engine.py | 25 ---- tests_metricflow/snapshot_utils.py | 4 + .../str/test_optimization_level__result.txt | 125 ++++++++++++++++++ 7 files changed, 192 insertions(+), 27 deletions(-) create mode 100644 tests_metricflow/snapshots/test_explain.py/str/test_optimization_level__result.txt diff --git a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py index 233629e7a6..3209e34b8b 100644 --- a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py +++ b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py @@ -133,7 +133,7 @@ def __init__(self, left_branch_node: DataflowPlanNode) -> None: # noqa: D107 self._current_left_node: DataflowPlanNode = left_branch_node def _log_visit_node_type(self, node: DataflowPlanNode) -> None: - logger.debug(lambda: f"Visiting {node.node_id}") + logger.debug(LazyFormat(lambda: f"Visiting {node.node_id}")) def _log_combine_failure( self, diff --git a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py index c84035335e..95c0aeec32 100644 --- a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py +++ b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py @@ -280,7 +280,14 @@ def visit_combine_aggregated_outputs_node( # noqa: D102 for branch_combination_result in combination_results ] - logger.debug(lambda: f"Got {len(combined_parent_branches)} branches after combination") + logger.debug( + LazyFormat( + "Possible branches combined.", + count_of_branches_before_combination=len(optimized_parent_branches), + count_of_branches_after_combination=len(combined_parent_branches), + ) + ) + assert len(combined_parent_branches) > 0 # If we were able to reduce the parent branches of the CombineAggregatedOutputsNode into a single one, there's diff --git a/metricflow/sql/optimizer/optimization_levels.py b/metricflow/sql/optimizer/optimization_levels.py index 63cedafbcd..e2e39fd1de 100644 --- a/metricflow/sql/optimizer/optimization_levels.py +++ b/metricflow/sql/optimizer/optimization_levels.py @@ -1,5 +1,6 @@ from __future__ import annotations +import functools from dataclasses import dataclass from enum import Enum from typing import Tuple @@ -13,6 +14,7 @@ from metricflow.sql.optimizer.table_alias_simplifier import SqlTableAliasSimplifier +@functools.total_ordering class SqlQueryOptimizationLevel(Enum): """Defines the level of query optimization and the associated optimizers to apply.""" @@ -27,6 +29,12 @@ class SqlQueryOptimizationLevel(Enum): def default_level() -> SqlQueryOptimizationLevel: # noqa: D102 return SqlQueryOptimizationLevel.O5 + def __lt__(self, other: SqlQueryOptimizationLevel) -> bool: # noqa: D105 + if not isinstance(other, SqlQueryOptimizationLevel): 
+ return NotImplemented + + return self.name < other.name + @dataclass(frozen=True) class SqlGenerationOptionSet: diff --git a/tests_metricflow/engine/test_explain.py b/tests_metricflow/engine/test_explain.py index a456d2430b..f23f419981 100644 --- a/tests_metricflow/engine/test_explain.py +++ b/tests_metricflow/engine/test_explain.py @@ -1,10 +1,20 @@ from __future__ import annotations +import logging from concurrent.futures import Future, ThreadPoolExecutor from typing import Mapping, Sequence +import pytest +from _pytest.fixtures import FixtureRequest +from metricflow_semantics.mf_logging.pretty_print import mf_pformat_dict +from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration + from metricflow.engine.metricflow_engine import MetricFlowEngine, MetricFlowExplainResult, MetricFlowQueryRequest +from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel from tests_metricflow.fixtures.manifest_fixtures import MetricFlowEngineTestFixture, SemanticManifestSetup +from tests_metricflow.snapshot_utils import assert_str_snapshot_equal + +logger = logging.getLogger(__name__) def _explain_one_query(mf_engine: MetricFlowEngine) -> str: @@ -29,3 +39,39 @@ def test_concurrent_explain_consistency( results = [future.result() for future in futures] for result in results: assert result == results[0], "Expected only one unique result / results to be the same" + + +@pytest.mark.sql_engine_snapshot +@pytest.mark.duckdb_only +def test_optimization_level( + request: FixtureRequest, + mf_test_configuration: MetricFlowTestConfiguration, + mf_engine_test_fixture_mapping: Mapping[SemanticManifestSetup, MetricFlowEngineTestFixture], +) -> None: + """Tests that the results of explain reflect the SQL optimization level in the request.""" + mf_engine = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].metricflow_engine + + results = {} + for optimization_level in SqlQueryOptimizationLevel: + # Skip lower optimization levels as they are generally not used. 
+ if optimization_level <= SqlQueryOptimizationLevel.O3: + continue + + explain_result: MetricFlowExplainResult = mf_engine.explain( + MetricFlowQueryRequest.create_with_random_request_id( + metric_names=("bookings", "views"), + group_by_names=("metric_time", "listing__country_latest"), + sql_optimization_level=optimization_level, + ) + ) + results[optimization_level.value] = explain_result.rendered_sql_without_descriptions.sql_query + + assert_str_snapshot_equal( + request=request, + mf_test_configuration=mf_test_configuration, + snapshot_id="result", + snapshot_str=mf_pformat_dict( + description=None, obj_dict=results, preserve_raw_strings=True, pad_items_with_newlines=True + ), + expectation_description=f"The result for {SqlQueryOptimizationLevel.O5} should be SQL uses a CTE.", + ) diff --git a/tests_metricflow/integration/test_mf_engine.py b/tests_metricflow/integration/test_mf_engine.py index b933835aa8..9aa1373701 100644 --- a/tests_metricflow/integration/test_mf_engine.py +++ b/tests_metricflow/integration/test_mf_engine.py @@ -3,8 +3,6 @@ from _pytest.fixtures import FixtureRequest from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration -from metricflow.engine.metricflow_engine import MetricFlowExplainResult, MetricFlowQueryRequest -from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel from tests_metricflow.integration.conftest import IntegrationTestHelpers from tests_metricflow.snapshot_utils import assert_object_snapshot_equal @@ -18,26 +16,3 @@ def test_list_dimensions( # noqa: D103 obj_id="result0", obj=sorted([dim.qualified_name for dim in it_helpers.mf_engine.list_dimensions()]), ) - - -def test_sql_optimization_level(it_helpers: IntegrationTestHelpers) -> None: - """Check that different SQL optimization levels produce different SQL.""" - assert ( - SqlQueryOptimizationLevel.default_level() != SqlQueryOptimizationLevel.O0 - ), "The default optimization level should be different from the lowest level." 
- explain_result_at_default_level: MetricFlowExplainResult = it_helpers.mf_engine.explain( - MetricFlowQueryRequest.create_with_random_request_id( - metric_names=("bookings",), - group_by_names=("metric_time",), - sql_optimization_level=SqlQueryOptimizationLevel.default_level(), - ) - ) - explain_result_at_level_0: MetricFlowExplainResult = it_helpers.mf_engine.explain( - MetricFlowQueryRequest.create_with_random_request_id( - metric_names=("bookings",), - group_by_names=("metric_time",), - sql_optimization_level=SqlQueryOptimizationLevel.O0, - ) - ) - - assert explain_result_at_default_level.rendered_sql.sql_query != explain_result_at_level_0.rendered_sql.sql_query diff --git a/tests_metricflow/snapshot_utils.py b/tests_metricflow/snapshot_utils.py index 4f0f674e5d..5237fa8c94 100644 --- a/tests_metricflow/snapshot_utils.py +++ b/tests_metricflow/snapshot_utils.py @@ -141,4 +141,8 @@ def assert_str_snapshot_equal( # type: ignore[misc] snapshot_file_extension=".txt", additional_sub_directories_for_snapshots=(sql_engine.value,) if sql_engine is not None else (), expectation_description=expectation_description, + incomparable_strings_replacement_function=make_schema_replacement_function( + system_schema=mf_test_configuration.mf_system_schema, + source_schema=mf_test_configuration.mf_source_schema, + ), ) diff --git a/tests_metricflow/snapshots/test_explain.py/str/test_optimization_level__result.txt b/tests_metricflow/snapshots/test_explain.py/str/test_optimization_level__result.txt new file mode 100644 index 0000000000..5b0b49b34b --- /dev/null +++ b/tests_metricflow/snapshots/test_explain.py/str/test_optimization_level__result.txt @@ -0,0 +1,125 @@ +test_name: test_optimization_level +test_filename: test_explain.py +docstring: + Tests that the results of explain reflect the SQL optimization level in the request. +expectation_description: + The result for SqlQueryOptimizationLevel.O5 should be SQL uses a CTE. 
+--- +O4: + SELECT + COALESCE(subq_8.metric_time__day, subq_17.metric_time__day) AS metric_time__day + , COALESCE(subq_8.listing__country_latest, subq_17.listing__country_latest) AS listing__country_latest + , MAX(subq_8.bookings) AS bookings + , MAX(subq_17.views) AS views + FROM ( + SELECT + subq_1.metric_time__day AS metric_time__day + , listings_latest_src_10000.country AS listing__country_latest + , SUM(subq_1.bookings) AS bookings + FROM ( + SELECT + DATE_TRUNC('day', ds) AS metric_time__day + , listing_id AS listing + , 1 AS bookings + FROM ***************************.fct_bookings bookings_source_src_10000 + ) subq_1 + LEFT OUTER JOIN + ***************************.dim_listings_latest listings_latest_src_10000 + ON + subq_1.listing = listings_latest_src_10000.listing_id + GROUP BY + subq_1.metric_time__day + , listings_latest_src_10000.country + ) subq_8 + FULL OUTER JOIN ( + SELECT + subq_10.metric_time__day AS metric_time__day + , listings_latest_src_10000.country AS listing__country_latest + , SUM(subq_10.views) AS views + FROM ( + SELECT + DATE_TRUNC('day', ds) AS metric_time__day + , listing_id AS listing + , 1 AS views + FROM ***************************.fct_views views_source_src_10000 + ) subq_10 + LEFT OUTER JOIN + ***************************.dim_listings_latest listings_latest_src_10000 + ON + subq_10.listing = listings_latest_src_10000.listing_id + GROUP BY + subq_10.metric_time__day + , listings_latest_src_10000.country + ) subq_17 + ON + ( + subq_8.listing__country_latest = subq_17.listing__country_latest + ) AND ( + subq_8.metric_time__day = subq_17.metric_time__day + ) + GROUP BY + COALESCE(subq_8.metric_time__day, subq_17.metric_time__day) + , COALESCE(subq_8.listing__country_latest, subq_17.listing__country_latest) + +O5: + WITH sma_10014_cte AS ( + SELECT + listing_id AS listing + , country AS country_latest + FROM ***************************.dim_listings_latest listings_latest_src_10000 + ) + + SELECT + COALESCE(subq_8.metric_time__day, subq_16.metric_time__day) AS metric_time__day + , COALESCE(subq_8.listing__country_latest, subq_16.listing__country_latest) AS listing__country_latest + , MAX(subq_8.bookings) AS bookings + , MAX(subq_16.views) AS views + FROM ( + SELECT + subq_1.metric_time__day AS metric_time__day + , sma_10014_cte.country_latest AS listing__country_latest + , SUM(subq_1.bookings) AS bookings + FROM ( + SELECT + DATE_TRUNC('day', ds) AS metric_time__day + , listing_id AS listing + , 1 AS bookings + FROM ***************************.fct_bookings bookings_source_src_10000 + ) subq_1 + LEFT OUTER JOIN + sma_10014_cte sma_10014_cte + ON + subq_1.listing = sma_10014_cte.listing + GROUP BY + subq_1.metric_time__day + , sma_10014_cte.country_latest + ) subq_8 + FULL OUTER JOIN ( + SELECT + subq_10.metric_time__day AS metric_time__day + , sma_10014_cte.country_latest AS listing__country_latest + , SUM(subq_10.views) AS views + FROM ( + SELECT + DATE_TRUNC('day', ds) AS metric_time__day + , listing_id AS listing + , 1 AS views + FROM ***************************.fct_views views_source_src_10000 + ) subq_10 + LEFT OUTER JOIN + sma_10014_cte sma_10014_cte + ON + subq_10.listing = sma_10014_cte.listing + GROUP BY + subq_10.metric_time__day + , sma_10014_cte.country_latest + ) subq_16 + ON + ( + subq_8.listing__country_latest = subq_16.listing__country_latest + ) AND ( + subq_8.metric_time__day = subq_16.metric_time__day + ) + GROUP BY + COALESCE(subq_8.metric_time__day, subq_16.metric_time__day) + , COALESCE(subq_8.listing__country_latest, 
subq_16.listing__country_latest)
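
Context for the skip condition in test_optimization_level (a minimal standalone sketch, not the actual metricflow module; the enum body below mirrors the patched optimization_levels.py but omits the optimizer wiring): functools.total_ordering derives <=, >, and >= from the __lt__ added in this patch, which is what makes `optimization_level <= SqlQueryOptimizationLevel.O3` valid in the test loop.

from __future__ import annotations

import functools
from enum import Enum


@functools.total_ordering
class SqlQueryOptimizationLevel(Enum):
    # Simplified mirror of the enum in optimization_levels.py; illustrative only.
    O0 = "O0"
    O1 = "O1"
    O2 = "O2"
    O3 = "O3"
    O4 = "O4"
    O5 = "O5"

    def __lt__(self, other: SqlQueryOptimizationLevel) -> bool:
        if not isinstance(other, SqlQueryOptimizationLevel):
            return NotImplemented
        # Member names are "O0".."O5", so lexicographic order matches numeric order.
        return self.name < other.name


# total_ordering fills in <= from __lt__ plus the default Enum equality, so the
# skip condition used in test_optimization_level works as written:
for level in SqlQueryOptimizationLevel:
    if level <= SqlQueryOptimizationLevel.O3:
        continue  # only O4 and O5 end up in the snapshot above
    print(level.value)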