Skip to content

Commit

Permalink
Add improved test for the SQL optimization level in the request (#1566)
Browse files Browse the repository at this point in the history
This PR adds an improved test to show that the SQL optimization level in
the request is handled correctly in the MF engine API. The previous test
made a more general check, but a snapshot makes it possible to check for
specific features. I am considering adding an assertion on the snapshot as well.
  • Loading branch information
plypaul authored Dec 12, 2024
1 parent 91d444a commit 66377f5
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(self, left_branch_node: DataflowPlanNode) -> None: # noqa: D107
self._current_left_node: DataflowPlanNode = left_branch_node

def _log_visit_node_type(self, node: DataflowPlanNode) -> None:
logger.debug(lambda: f"Visiting {node.node_id}")
logger.debug(LazyFormat(lambda: f"Visiting {node.node_id}"))

def _log_combine_failure(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,14 @@ def visit_combine_aggregated_outputs_node( # noqa: D102
for branch_combination_result in combination_results
]

logger.debug(lambda: f"Got {len(combined_parent_branches)} branches after combination")
logger.debug(
LazyFormat(
"Possible branches combined.",
count_of_branches_before_combination=len(optimized_parent_branches),
count_of_branches_after_combination=len(combined_parent_branches),
)
)

assert len(combined_parent_branches) > 0

# If we were able to reduce the parent branches of the CombineAggregatedOutputsNode into a single one, there's
Expand Down
8 changes: 8 additions & 0 deletions metricflow/sql/optimizer/optimization_levels.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
from dataclasses import dataclass
from enum import Enum
from typing import Tuple
Expand All @@ -13,6 +14,7 @@
from metricflow.sql.optimizer.table_alias_simplifier import SqlTableAliasSimplifier


@functools.total_ordering
class SqlQueryOptimizationLevel(Enum):
"""Defines the level of query optimization and the associated optimizers to apply."""

Expand All @@ -27,6 +29,12 @@ class SqlQueryOptimizationLevel(Enum):
def default_level() -> SqlQueryOptimizationLevel: # noqa: D102
return SqlQueryOptimizationLevel.O5

def __lt__(self, other: SqlQueryOptimizationLevel) -> bool:  # noqa: D105
    # Levels are ordered by comparing their member names as strings. The class is decorated with
    # `functools.total_ordering`, which derives the remaining rich comparisons from this method.
    if isinstance(other, SqlQueryOptimizationLevel):
        return self.name < other.name

    # Not comparable with other types; returning `NotImplemented` lets Python try the reflected operation.
    return NotImplemented


@dataclass(frozen=True)
class SqlGenerationOptionSet:
Expand Down
46 changes: 46 additions & 0 deletions tests_metricflow/engine/test_explain.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
from __future__ import annotations

import logging
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Mapping, Sequence

import pytest
from _pytest.fixtures import FixtureRequest
from metricflow_semantics.mf_logging.pretty_print import mf_pformat_dict
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration

from metricflow.engine.metricflow_engine import MetricFlowEngine, MetricFlowExplainResult, MetricFlowQueryRequest
from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel
from tests_metricflow.fixtures.manifest_fixtures import MetricFlowEngineTestFixture, SemanticManifestSetup
from tests_metricflow.snapshot_utils import assert_str_snapshot_equal

logger = logging.getLogger(__name__)


def _explain_one_query(mf_engine: MetricFlowEngine) -> str:
Expand All @@ -29,3 +39,39 @@ def test_concurrent_explain_consistency(
results = [future.result() for future in futures]
for result in results:
assert result == results[0], "Expected only one unique result / results to be the same"


@pytest.mark.sql_engine_snapshot
@pytest.mark.duckdb_only
def test_optimization_level(
    request: FixtureRequest,
    mf_test_configuration: MetricFlowTestConfiguration,
    mf_engine_test_fixture_mapping: Mapping[SemanticManifestSetup, MetricFlowEngineTestFixture],
) -> None:
    """Tests that the results of explain reflect the SQL optimization level in the request."""
    # NOTE: the docstring above and `expectation_description` below are echoed into the snapshot header, so
    # changing either text means the stored snapshot has to be regenerated.
    engine_fixture = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST]
    mf_engine = engine_fixture.metricflow_engine

    # Lower optimization levels are generally not used, so only levels above O3 are exercised.
    levels_to_check = [level for level in SqlQueryOptimizationLevel if not level <= SqlQueryOptimizationLevel.O3]

    # Rendered SQL (without descriptions) keyed by the optimization level's value.
    sql_by_level = {}
    for level in levels_to_check:
        query_request = MetricFlowQueryRequest.create_with_random_request_id(
            metric_names=("bookings", "views"),
            group_by_names=("metric_time", "listing__country_latest"),
            sql_optimization_level=level,
        )
        explain_result: MetricFlowExplainResult = mf_engine.explain(query_request)
        sql_by_level[level.value] = explain_result.rendered_sql_without_descriptions.sql_query

    # A single snapshot holds the SQL for every checked level so that they can be compared side by side.
    formatted_results = mf_pformat_dict(
        description=None,
        obj_dict=sql_by_level,
        preserve_raw_strings=True,
        pad_items_with_newlines=True,
    )
    assert_str_snapshot_equal(
        request=request,
        mf_test_configuration=mf_test_configuration,
        snapshot_id="result",
        snapshot_str=formatted_results,
        expectation_description=f"The result for {SqlQueryOptimizationLevel.O5} should be SQL uses a CTE.",
    )
25 changes: 0 additions & 25 deletions tests_metricflow/integration/test_mf_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from _pytest.fixtures import FixtureRequest
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration

from metricflow.engine.metricflow_engine import MetricFlowExplainResult, MetricFlowQueryRequest
from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel
from tests_metricflow.integration.conftest import IntegrationTestHelpers
from tests_metricflow.snapshot_utils import assert_object_snapshot_equal

Expand All @@ -18,26 +16,3 @@ def test_list_dimensions( # noqa: D103
obj_id="result0",
obj=sorted([dim.qualified_name for dim in it_helpers.mf_engine.list_dimensions()]),
)


def test_sql_optimization_level(it_helpers: IntegrationTestHelpers) -> None:
    """Verify that the default optimization level and `O0` render different SQL for the same query."""
    default_level = SqlQueryOptimizationLevel.default_level()
    lowest_level = SqlQueryOptimizationLevel.O0
    # The comparison at the end is only meaningful if the two levels are actually distinct.
    assert default_level != lowest_level, "The default optimization level should be different from the lowest level."

    # Explain the same query once per level: the default level first, then O0.
    explain_results = []
    for level in (default_level, lowest_level):
        query_request = MetricFlowQueryRequest.create_with_random_request_id(
            metric_names=("bookings",),
            group_by_names=("metric_time",),
            sql_optimization_level=level,
        )
        explain_result: MetricFlowExplainResult = it_helpers.mf_engine.explain(query_request)
        explain_results.append(explain_result)

    sql_at_default_level, sql_at_level_0 = (result.rendered_sql.sql_query for result in explain_results)
    assert sql_at_default_level != sql_at_level_0
4 changes: 4 additions & 0 deletions tests_metricflow/snapshot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,8 @@ def assert_str_snapshot_equal( # type: ignore[misc]
snapshot_file_extension=".txt",
additional_sub_directories_for_snapshots=(sql_engine.value,) if sql_engine is not None else (),
expectation_description=expectation_description,
incomparable_strings_replacement_function=make_schema_replacement_function(
system_schema=mf_test_configuration.mf_system_schema,
source_schema=mf_test_configuration.mf_source_schema,
),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
test_name: test_optimization_level
test_filename: test_explain.py
docstring:
Tests that the results of explain reflect the SQL optimization level in the request.
expectation_description:
The result for SqlQueryOptimizationLevel.O5 should be SQL uses a CTE.
---
O4:
SELECT
COALESCE(subq_8.metric_time__day, subq_17.metric_time__day) AS metric_time__day
, COALESCE(subq_8.listing__country_latest, subq_17.listing__country_latest) AS listing__country_latest
, MAX(subq_8.bookings) AS bookings
, MAX(subq_17.views) AS views
FROM (
SELECT
subq_1.metric_time__day AS metric_time__day
, listings_latest_src_10000.country AS listing__country_latest
, SUM(subq_1.bookings) AS bookings
FROM (
SELECT
DATE_TRUNC('day', ds) AS metric_time__day
, listing_id AS listing
, 1 AS bookings
FROM ***************************.fct_bookings bookings_source_src_10000
) subq_1
LEFT OUTER JOIN
***************************.dim_listings_latest listings_latest_src_10000
ON
subq_1.listing = listings_latest_src_10000.listing_id
GROUP BY
subq_1.metric_time__day
, listings_latest_src_10000.country
) subq_8
FULL OUTER JOIN (
SELECT
subq_10.metric_time__day AS metric_time__day
, listings_latest_src_10000.country AS listing__country_latest
, SUM(subq_10.views) AS views
FROM (
SELECT
DATE_TRUNC('day', ds) AS metric_time__day
, listing_id AS listing
, 1 AS views
FROM ***************************.fct_views views_source_src_10000
) subq_10
LEFT OUTER JOIN
***************************.dim_listings_latest listings_latest_src_10000
ON
subq_10.listing = listings_latest_src_10000.listing_id
GROUP BY
subq_10.metric_time__day
, listings_latest_src_10000.country
) subq_17
ON
(
subq_8.listing__country_latest = subq_17.listing__country_latest
) AND (
subq_8.metric_time__day = subq_17.metric_time__day
)
GROUP BY
COALESCE(subq_8.metric_time__day, subq_17.metric_time__day)
, COALESCE(subq_8.listing__country_latest, subq_17.listing__country_latest)

O5:
WITH sma_10014_cte AS (
SELECT
listing_id AS listing
, country AS country_latest
FROM ***************************.dim_listings_latest listings_latest_src_10000
)

SELECT
COALESCE(subq_8.metric_time__day, subq_16.metric_time__day) AS metric_time__day
, COALESCE(subq_8.listing__country_latest, subq_16.listing__country_latest) AS listing__country_latest
, MAX(subq_8.bookings) AS bookings
, MAX(subq_16.views) AS views
FROM (
SELECT
subq_1.metric_time__day AS metric_time__day
, sma_10014_cte.country_latest AS listing__country_latest
, SUM(subq_1.bookings) AS bookings
FROM (
SELECT
DATE_TRUNC('day', ds) AS metric_time__day
, listing_id AS listing
, 1 AS bookings
FROM ***************************.fct_bookings bookings_source_src_10000
) subq_1
LEFT OUTER JOIN
sma_10014_cte sma_10014_cte
ON
subq_1.listing = sma_10014_cte.listing
GROUP BY
subq_1.metric_time__day
, sma_10014_cte.country_latest
) subq_8
FULL OUTER JOIN (
SELECT
subq_10.metric_time__day AS metric_time__day
, sma_10014_cte.country_latest AS listing__country_latest
, SUM(subq_10.views) AS views
FROM (
SELECT
DATE_TRUNC('day', ds) AS metric_time__day
, listing_id AS listing
, 1 AS views
FROM ***************************.fct_views views_source_src_10000
) subq_10
LEFT OUTER JOIN
sma_10014_cte sma_10014_cte
ON
subq_10.listing = sma_10014_cte.listing
GROUP BY
subq_10.metric_time__day
, sma_10014_cte.country_latest
) subq_16
ON
(
subq_8.listing__country_latest = subq_16.listing__country_latest
) AND (
subq_8.metric_time__day = subq_16.metric_time__day
)
GROUP BY
COALESCE(subq_8.metric_time__day, subq_16.metric_time__day)
, COALESCE(subq_8.listing__country_latest, subq_16.listing__country_latest)

0 comments on commit 66377f5

Please sign in to comment.