From 2e1cabbdc073ca2546b985af0c3d50d1b80c840d Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Wed, 11 Dec 2024 15:28:35 -0800 Subject: [PATCH 1/4] Move sql_exprs file into metricflow-semantics --- .../metricflow_semantics}/sql/sql_exprs.py | 0 metricflow/dataset/convert_semantic_model.py | 14 ++-- metricflow/plan_conversion/dataflow_to_sql.py | 44 ++++++------- .../plan_conversion/instance_converters.py | 10 +-- .../sql_expression_builders.py | 3 +- .../plan_conversion/sql_join_builder.py | 18 ++--- .../sql/optimizer/required_column_aliases.py | 2 +- .../optimizer/rewriting_sub_query_reducer.py | 8 +-- metricflow/sql/optimizer/sub_query_reducer.py | 2 +- metricflow/sql/render/big_query.py | 20 +++--- metricflow/sql/render/databricks.py | 2 +- metricflow/sql/render/duckdb_renderer.py | 14 ++-- metricflow/sql/render/expr_renderer.py | 8 +-- metricflow/sql/render/postgres.py | 14 ++-- metricflow/sql/render/redshift.py | 12 ++-- metricflow/sql/render/snowflake.py | 10 +-- metricflow/sql/render/sql_plan_renderer.py | 2 +- metricflow/sql/render/trino.py | 16 ++--- metricflow/sql/sql_plan.py | 3 +- .../dataflow/builder/test_node_data_set.py | 2 +- .../integration/test_configured_cases.py | 66 ++++++++++--------- .../mf_logging/test_dag_to_text.py | 6 +- ...select_columns_with_measures_aggregated.py | 10 +-- .../sql/optimizer/test_column_pruner.py | 14 ++-- .../sql/optimizer/test_cte_column_pruner.py | 12 ++-- .../test_cte_rewriting_sub_query_reducer.py | 12 ++-- .../test_cte_table_alias_simplifier.py | 12 ++-- .../test_rewriting_sub_query_reducer.py | 12 ++-- .../sql/optimizer/test_sub_query_reducer.py | 12 ++-- .../optimizer/test_table_alias_simplifier.py | 12 ++-- .../sql/test_engine_specific_rendering.py | 10 +-- tests_metricflow/sql/test_render_cte.py | 10 +-- tests_metricflow/sql/test_sql_expr_render.py | 8 +-- tests_metricflow/sql/test_sql_plan_render.py | 12 ++-- .../sql_clients/test_date_time_operations.py | 8 +-- 35 files changed, 211 insertions(+), 209 deletions(-) rename {metricflow => metricflow-semantics/metricflow_semantics}/sql/sql_exprs.py (100%) diff --git a/metricflow/sql/sql_exprs.py b/metricflow-semantics/metricflow_semantics/sql/sql_exprs.py similarity index 100% rename from metricflow/sql/sql_exprs.py rename to metricflow-semantics/metricflow_semantics/sql/sql_exprs.py diff --git a/metricflow/dataset/convert_semantic_model.py b/metricflow/dataset/convert_semantic_model.py index 56ee4279e6..050b13edb6 100644 --- a/metricflow/dataset/convert_semantic_model.py +++ b/metricflow/dataset/convert_semantic_model.py @@ -32,13 +32,7 @@ from metricflow_semantics.specs.dimension_spec import DimensionSpec from metricflow_semantics.specs.entity_spec import EntitySpec from metricflow_semantics.specs.time_dimension_spec import DEFAULT_TIME_GRANULARITY, TimeDimensionSpec -from metricflow_semantics.sql.sql_table import SqlTable -from metricflow_semantics.time.granularity import ExpandedTimeGranularity -from metricflow_semantics.time.time_spine_source import TimeSpineSource - -from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet -from metricflow.dataset.sql_dataset import SqlDataSet -from metricflow.sql.sql_exprs import ( +from metricflow_semantics.sql.sql_exprs import ( SqlColumnReference, SqlColumnReferenceExpression, SqlDateTruncExpression, @@ -46,6 +40,12 @@ SqlExtractExpression, SqlStringExpression, ) +from metricflow_semantics.sql.sql_table import SqlTable +from metricflow_semantics.time.granularity import ExpandedTimeGranularity +from metricflow_semantics.time.time_spine_source import TimeSpineSource + +from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet +from metricflow.dataset.sql_dataset import SqlDataSet from metricflow.sql.sql_plan import ( SqlSelectColumn, SqlSelectStatementNode, diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index 90918cb8f4..d6c52dfbb0 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -37,6 +37,28 @@ from metricflow_semantics.specs.metric_spec import MetricSpec from metricflow_semantics.specs.spec_set import InstanceSpecSet from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec +from metricflow_semantics.sql.sql_exprs import ( + SqlAggregateFunctionExpression, + SqlBetweenExpression, + SqlColumnReference, + SqlColumnReferenceExpression, + SqlComparison, + SqlComparisonExpression, + SqlDateTruncExpression, + SqlExpressionNode, + SqlExtractExpression, + SqlFunction, + SqlFunctionExpression, + SqlGenerateUuidExpression, + SqlLogicalExpression, + SqlLogicalOperator, + SqlRatioComputationExpression, + SqlStringExpression, + SqlStringLiteralExpression, + SqlWindowFunction, + SqlWindowFunctionExpression, + SqlWindowOrderByArgument, +) from metricflow_semantics.sql.sql_join_type import SqlJoinType from metricflow_semantics.sql.sql_table import SqlTable from metricflow_semantics.time.granularity import ExpandedTimeGranularity @@ -112,28 +134,6 @@ SqlQueryOptimizationLevel, ) from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer -from metricflow.sql.sql_exprs import ( - SqlAggregateFunctionExpression, - SqlBetweenExpression, - SqlColumnReference, - SqlColumnReferenceExpression, - SqlComparison, - SqlComparisonExpression, - SqlDateTruncExpression, - SqlExpressionNode, - SqlExtractExpression, - SqlFunction, - SqlFunctionExpression, - SqlGenerateUuidExpression, - SqlLogicalExpression, - SqlLogicalOperator, - SqlRatioComputationExpression, - SqlStringExpression, - SqlStringLiteralExpression, - SqlWindowFunction, - SqlWindowFunctionExpression, - SqlWindowOrderByArgument, -) from metricflow.sql.sql_plan import ( SqlCreateTableAsNode, SqlCteNode, diff --git a/metricflow/plan_conversion/instance_converters.py b/metricflow/plan_conversion/instance_converters.py index b801d958d3..cb292a48eb 100644 --- a/metricflow/plan_conversion/instance_converters.py +++ b/metricflow/plan_conversion/instance_converters.py @@ -32,11 +32,7 @@ from metricflow_semantics.specs.instance_spec import InstanceSpec, LinkableInstanceSpec from metricflow_semantics.specs.measure_spec import MeasureSpec, MetricInputMeasureSpec from metricflow_semantics.specs.spec_set import InstanceSpecSet -from more_itertools import bucket - -from metricflow.dataflow.nodes.join_to_base import ValidityWindowJoinDescription -from metricflow.plan_conversion.select_column_gen import SelectColumnSet -from metricflow.sql.sql_exprs import ( +from metricflow_semantics.sql.sql_exprs import ( SqlAggregateFunctionExpression, SqlColumnReference, SqlColumnReferenceExpression, @@ -45,6 +41,10 @@ SqlFunctionExpression, SqlStringExpression, ) +from more_itertools import bucket + +from metricflow.dataflow.nodes.join_to_base import ValidityWindowJoinDescription +from metricflow.plan_conversion.select_column_gen import SelectColumnSet from metricflow.sql.sql_plan import ( SqlSelectColumn, ) diff --git a/metricflow/plan_conversion/sql_expression_builders.py b/metricflow/plan_conversion/sql_expression_builders.py index e5ed18d463..26029788c0 100644 --- a/metricflow/plan_conversion/sql_expression_builders.py +++ b/metricflow/plan_conversion/sql_expression_builders.py @@ -1,9 +1,10 @@ """Utility module for building sql expressions from inputs derived from dataflow plan or other nodes.""" + from __future__ import annotations from typing import List, Sequence -from metricflow.sql.sql_exprs import ( +from metricflow_semantics.sql.sql_exprs import ( SqlAggregateFunctionExpression, SqlColumnReference, SqlColumnReferenceExpression, diff --git a/metricflow/plan_conversion/sql_join_builder.py b/metricflow/plan_conversion/sql_join_builder.py index 284e203609..f80cdf2287 100644 --- a/metricflow/plan_conversion/sql_join_builder.py +++ b/metricflow/plan_conversion/sql_join_builder.py @@ -7,15 +7,7 @@ from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow_semantics.assert_one_arg import assert_exactly_one_arg_set from metricflow_semantics.errors.custom_grain_not_supported import error_if_not_standard_grain -from metricflow_semantics.sql.sql_join_type import SqlJoinType - -from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode -from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode -from metricflow.dataflow.nodes.join_to_base import JoinDescription -from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode -from metricflow.dataset.sql_dataset import AnnotatedSqlDataSet -from metricflow.plan_conversion.sql_expression_builders import make_coalesced_expr -from metricflow.sql.sql_exprs import ( +from metricflow_semantics.sql.sql_exprs import ( SqlColumnReference, SqlColumnReferenceExpression, SqlComparison, @@ -26,6 +18,14 @@ SqlLogicalOperator, SqlSubtractTimeIntervalExpression, ) +from metricflow_semantics.sql.sql_join_type import SqlJoinType + +from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode +from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode +from metricflow.dataflow.nodes.join_to_base import JoinDescription +from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode +from metricflow.dataset.sql_dataset import AnnotatedSqlDataSet +from metricflow.plan_conversion.sql_expression_builders import make_coalesced_expr from metricflow.sql.sql_plan import SqlExpressionNode, SqlJoinDescription, SqlSelectStatementNode diff --git a/metricflow/sql/optimizer/required_column_aliases.py b/metricflow/sql/optimizer/required_column_aliases.py index 32dfacd32d..08022a6faa 100644 --- a/metricflow/sql/optimizer/required_column_aliases.py +++ b/metricflow/sql/optimizer/required_column_aliases.py @@ -5,10 +5,10 @@ from typing import Dict, FrozenSet, List, Set, Tuple from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat +from metricflow_semantics.sql.sql_exprs import SqlExpressionTreeLineage from typing_extensions import override from metricflow.sql.optimizer.tag_column_aliases import NodeToColumnAliasMapping -from metricflow.sql.sql_exprs import SqlExpressionTreeLineage from metricflow.sql.sql_plan import ( SqlCreateTableAsNode, SqlCteNode, diff --git a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py index bd4fbb87fb..82efa5dcd6 100644 --- a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py +++ b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py @@ -6,10 +6,7 @@ from metricflow_semantics.mf_logging.formatting import indent from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat -from typing_extensions import override - -from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer -from metricflow.sql.sql_exprs import ( +from metricflow_semantics.sql.sql_exprs import ( SqlColumnAliasReferenceExpression, SqlColumnReference, SqlColumnReplacements, @@ -18,6 +15,9 @@ SqlLogicalExpression, SqlLogicalOperator, ) +from typing_extensions import override + +from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer from metricflow.sql.sql_plan import ( SqlCreateTableAsNode, SqlCteNode, diff --git a/metricflow/sql/optimizer/sub_query_reducer.py b/metricflow/sql/optimizer/sub_query_reducer.py index c223d5d3f7..9a930b99b2 100644 --- a/metricflow/sql/optimizer/sub_query_reducer.py +++ b/metricflow/sql/optimizer/sub_query_reducer.py @@ -3,10 +3,10 @@ import logging from typing import List, Optional +from metricflow_semantics.sql.sql_exprs import SqlColumnReference, SqlColumnReferenceExpression from typing_extensions import override from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer -from metricflow.sql.sql_exprs import SqlColumnReference, SqlColumnReferenceExpression from metricflow.sql.sql_plan import ( SqlCreateTableAsNode, SqlCteNode, diff --git a/metricflow/sql/render/big_query.py b/metricflow/sql/render/big_query.py index a63b2d06c6..5745ab2f8e 100644 --- a/metricflow/sql/render/big_query.py +++ b/metricflow/sql/render/big_query.py @@ -8,16 +8,7 @@ from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow_semantics.errors.error_classes import UnsupportedEngineFeatureError from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet -from typing_extensions import override - -from metricflow.protocols.sql_client import SqlEngine -from metricflow.sql.render.expr_renderer import ( - DefaultSqlExpressionRenderer, - SqlExpressionRenderer, - SqlExpressionRenderResult, -) -from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer -from metricflow.sql.sql_exprs import ( +from metricflow_semantics.sql.sql_exprs import ( SqlAddTimeExpression, SqlCastToTimestampExpression, SqlDateTruncExpression, @@ -27,6 +18,15 @@ SqlPercentileFunctionType, SqlSubtractTimeIntervalExpression, ) +from typing_extensions import override + +from metricflow.protocols.sql_client import SqlEngine +from metricflow.sql.render.expr_renderer import ( + DefaultSqlExpressionRenderer, + SqlExpressionRenderer, + SqlExpressionRenderResult, +) +from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer from metricflow.sql.sql_plan import SqlSelectColumn diff --git a/metricflow/sql/render/databricks.py b/metricflow/sql/render/databricks.py index 5aa98000a7..2b7dbe6366 100644 --- a/metricflow/sql/render/databricks.py +++ b/metricflow/sql/render/databricks.py @@ -5,6 +5,7 @@ from dbt_semantic_interfaces.enum_extension import assert_values_exhausted from dbt_semantic_interfaces.type_enums.date_part import DatePart from metricflow_semantics.errors.error_classes import UnsupportedEngineFeatureError +from metricflow_semantics.sql.sql_exprs import SqlPercentileExpression, SqlPercentileFunctionType from typing_extensions import override from metricflow.protocols.sql_client import SqlEngine @@ -14,7 +15,6 @@ SqlExpressionRenderResult, ) from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer -from metricflow.sql.sql_exprs import SqlPercentileExpression, SqlPercentileFunctionType class DatabricksSqlExpressionRenderer(DefaultSqlExpressionRenderer): diff --git a/metricflow/sql/render/duckdb_renderer.py b/metricflow/sql/render/duckdb_renderer.py index 3e03b7eca5..ecfca54f52 100644 --- a/metricflow/sql/render/duckdb_renderer.py +++ b/metricflow/sql/render/duckdb_renderer.py @@ -5,6 +5,13 @@ from dbt_semantic_interfaces.enum_extension import assert_values_exhausted from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet +from metricflow_semantics.sql.sql_exprs import ( + SqlAddTimeExpression, + SqlGenerateUuidExpression, + SqlPercentileExpression, + SqlPercentileFunctionType, + SqlSubtractTimeIntervalExpression, +) from typing_extensions import override from metricflow.protocols.sql_client import SqlEngine @@ -14,13 +21,6 @@ SqlExpressionRenderResult, ) from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer -from metricflow.sql.sql_exprs import ( - SqlAddTimeExpression, - SqlGenerateUuidExpression, - SqlPercentileExpression, - SqlPercentileFunctionType, - SqlSubtractTimeIntervalExpression, -) class DuckDbSqlExpressionRenderer(DefaultSqlExpressionRenderer): diff --git a/metricflow/sql/render/expr_renderer.py b/metricflow/sql/render/expr_renderer.py index a387a2da0d..10e3d748b9 100644 --- a/metricflow/sql/render/expr_renderer.py +++ b/metricflow/sql/render/expr_renderer.py @@ -12,10 +12,7 @@ from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow_semantics.mf_logging.formatting import indent from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet -from typing_extensions import override - -from metricflow.sql.render.rendering_constants import SqlRenderingConstants -from metricflow.sql.sql_exprs import ( +from metricflow_semantics.sql.sql_exprs import ( SqlAddTimeExpression, SqlAggregateFunctionExpression, SqlBetweenExpression, @@ -40,6 +37,9 @@ SqlSubtractTimeIntervalExpression, SqlWindowFunctionExpression, ) +from typing_extensions import override + +from metricflow.sql.render.rendering_constants import SqlRenderingConstants from metricflow.sql.sql_plan import SqlSelectColumn if TYPE_CHECKING: diff --git a/metricflow/sql/render/postgres.py b/metricflow/sql/render/postgres.py index 8ced430207..2509dfc243 100644 --- a/metricflow/sql/render/postgres.py +++ b/metricflow/sql/render/postgres.py @@ -6,6 +6,13 @@ from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow_semantics.errors.error_classes import UnsupportedEngineFeatureError from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet +from metricflow_semantics.sql.sql_exprs import ( + SqlAddTimeExpression, + SqlGenerateUuidExpression, + SqlPercentileExpression, + SqlPercentileFunctionType, + SqlSubtractTimeIntervalExpression, +) from typing_extensions import override from metricflow.protocols.sql_client import SqlEngine @@ -15,13 +22,6 @@ SqlExpressionRenderResult, ) from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer -from metricflow.sql.sql_exprs import ( - SqlAddTimeExpression, - SqlGenerateUuidExpression, - SqlPercentileExpression, - SqlPercentileFunctionType, - SqlSubtractTimeIntervalExpression, -) class PostgresSqlExpressionRenderer(DefaultSqlExpressionRenderer): diff --git a/metricflow/sql/render/redshift.py b/metricflow/sql/render/redshift.py index c916cd3a1c..46c3385984 100644 --- a/metricflow/sql/render/redshift.py +++ b/metricflow/sql/render/redshift.py @@ -6,6 +6,12 @@ from dbt_semantic_interfaces.type_enums.date_part import DatePart from metricflow_semantics.errors.error_classes import UnsupportedEngineFeatureError from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet +from metricflow_semantics.sql.sql_exprs import ( + SqlExtractExpression, + SqlGenerateUuidExpression, + SqlPercentileExpression, + SqlPercentileFunctionType, +) from typing_extensions import override from metricflow.protocols.sql_client import SqlEngine @@ -15,12 +21,6 @@ SqlExpressionRenderResult, ) from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer -from metricflow.sql.sql_exprs import ( - SqlExtractExpression, - SqlGenerateUuidExpression, - SqlPercentileExpression, - SqlPercentileFunctionType, -) class RedshiftSqlExpressionRenderer(DefaultSqlExpressionRenderer): diff --git a/metricflow/sql/render/snowflake.py b/metricflow/sql/render/snowflake.py index bc125087e8..04aabeee48 100644 --- a/metricflow/sql/render/snowflake.py +++ b/metricflow/sql/render/snowflake.py @@ -6,6 +6,11 @@ from dbt_semantic_interfaces.type_enums.date_part import DatePart from metricflow_semantics.errors.error_classes import UnsupportedEngineFeatureError from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet +from metricflow_semantics.sql.sql_exprs import ( + SqlGenerateUuidExpression, + SqlPercentileExpression, + SqlPercentileFunctionType, +) from typing_extensions import override from metricflow.protocols.sql_client import SqlEngine @@ -15,11 +20,6 @@ SqlExpressionRenderResult, ) from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer -from metricflow.sql.sql_exprs import ( - SqlGenerateUuidExpression, - SqlPercentileExpression, - SqlPercentileFunctionType, -) class SnowflakeSqlExpressionRenderer(DefaultSqlExpressionRenderer): diff --git a/metricflow/sql/render/sql_plan_renderer.py b/metricflow/sql/render/sql_plan_renderer.py index 9e0fb11f6e..59e51cca82 100644 --- a/metricflow/sql/render/sql_plan_renderer.py +++ b/metricflow/sql/render/sql_plan_renderer.py @@ -9,6 +9,7 @@ from metricflow_semantics.mf_logging.formatting import indent from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet +from metricflow_semantics.sql.sql_exprs import SqlExpressionNode from typing_extensions import override from metricflow.sql.render.expr_renderer import ( @@ -17,7 +18,6 @@ SqlExpressionRenderResult, ) from metricflow.sql.render.rendering_constants import SqlRenderingConstants -from metricflow.sql.sql_exprs import SqlExpressionNode from metricflow.sql.sql_plan import ( SqlCreateTableAsNode, SqlCteNode, diff --git a/metricflow/sql/render/trino.py b/metricflow/sql/render/trino.py index e6aff3150a..bd3a581597 100644 --- a/metricflow/sql/render/trino.py +++ b/metricflow/sql/render/trino.py @@ -7,6 +7,14 @@ from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet +from metricflow_semantics.sql.sql_exprs import ( + SqlAddTimeExpression, + SqlBetweenExpression, + SqlGenerateUuidExpression, + SqlPercentileExpression, + SqlPercentileFunctionType, + SqlSubtractTimeIntervalExpression, +) from typing_extensions import override from metricflow.protocols.sql_client import SqlEngine @@ -16,14 +24,6 @@ SqlExpressionRenderResult, ) from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer -from metricflow.sql.sql_exprs import ( - SqlAddTimeExpression, - SqlBetweenExpression, - SqlGenerateUuidExpression, - SqlPercentileExpression, - SqlPercentileFunctionType, - SqlSubtractTimeIntervalExpression, -) class TrinoSqlExpressionRenderer(DefaultSqlExpressionRenderer): diff --git a/metricflow/sql/sql_plan.py b/metricflow/sql/sql_plan.py index 792498cc2e..ff0b34c650 100644 --- a/metricflow/sql/sql_plan.py +++ b/metricflow/sql/sql_plan.py @@ -9,13 +9,12 @@ from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow_semantics.dag.mf_dag import DagId, DagNode, DisplayedProperty, MetricFlowDag +from metricflow_semantics.sql.sql_exprs import SqlExpressionNode from metricflow_semantics.sql.sql_join_type import SqlJoinType from metricflow_semantics.sql.sql_table import SqlTable from metricflow_semantics.visitor import VisitorOutputT from typing_extensions import override -from metricflow.sql.sql_exprs import SqlExpressionNode - logger = logging.getLogger(__name__) diff --git a/tests_metricflow/dataflow/builder/test_node_data_set.py b/tests_metricflow/dataflow/builder/test_node_data_set.py index c62c255e55..5e369b3386 100644 --- a/tests_metricflow/dataflow/builder/test_node_data_set.py +++ b/tests_metricflow/dataflow/builder/test_node_data_set.py @@ -16,6 +16,7 @@ from metricflow_semantics.specs.dunder_column_association_resolver import DunderColumnAssociationResolver from metricflow_semantics.specs.entity_spec import LinklessEntitySpec from metricflow_semantics.specs.measure_spec import MeasureSpec +from metricflow_semantics.sql.sql_exprs import SqlColumnReference, SqlColumnReferenceExpression from metricflow_semantics.sql.sql_join_type import SqlJoinType from metricflow_semantics.sql.sql_table import SqlTable from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration @@ -26,7 +27,6 @@ from metricflow.dataflow.nodes.join_to_base import JoinDescription, JoinOnEntitiesNode from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode from metricflow.dataset.sql_dataset import SqlDataSet -from metricflow.sql.sql_exprs import SqlColumnReference, SqlColumnReferenceExpression from metricflow.sql.sql_plan import ( SqlSelectColumn, SqlSelectStatementNode, diff --git a/tests_metricflow/integration/test_configured_cases.py b/tests_metricflow/integration/test_configured_cases.py index c7763bbb43..027834d059 100644 --- a/tests_metricflow/integration/test_configured_cases.py +++ b/tests_metricflow/integration/test_configured_cases.py @@ -15,13 +15,7 @@ from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat from metricflow_semantics.protocols.query_parameter import DimensionOrEntityQueryParameter from metricflow_semantics.specs.query_param_implementations import DimensionOrEntityParameter, TimeDimensionParameter -from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration -from metricflow_semantics.time.time_constants import ISO8601_PYTHON_FORMAT, ISO8601_PYTHON_TS_FORMAT -from metricflow_semantics.time.time_spine_source import TimeSpineSource - -from metricflow.engine.metricflow_engine import MetricFlowQueryRequest -from metricflow.protocols.sql_client import SqlClient -from metricflow.sql.sql_exprs import ( +from metricflow_semantics.sql.sql_exprs import ( SqlCastToTimestampExpression, SqlColumnReference, SqlColumnReferenceExpression, @@ -34,6 +28,12 @@ SqlStringExpression, SqlSubtractTimeIntervalExpression, ) +from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration +from metricflow_semantics.time.time_constants import ISO8601_PYTHON_FORMAT, ISO8601_PYTHON_TS_FORMAT +from metricflow_semantics.time.time_spine_source import TimeSpineSource + +from metricflow.engine.metricflow_engine import MetricFlowQueryRequest +from metricflow.protocols.sql_client import SqlClient from tests_metricflow.fixtures.manifest_fixtures import MetricFlowEngineTestFixture, SemanticManifestSetup from tests_metricflow.integration.configured_test_case import ( CONFIGURED_INTEGRATION_TESTS_REPOSITORY, @@ -280,31 +280,33 @@ def test_case( limit=case.limit, time_constraint_start=parser.parse(case.time_constraint[0]) if case.time_constraint else None, time_constraint_end=parser.parse(case.time_constraint[1]) if case.time_constraint else None, - where_constraints=[ - jinja2.Template( - case.where_filter, - undefined=jinja2.StrictUndefined, - ).render( - source_schema=mf_test_configuration.mf_source_schema, - render_time_constraint=check_query_helpers.render_time_constraint, - TimeGranularity=TimeGranularity, - DatePart=DatePart, - render_date_sub=check_query_helpers.render_date_sub, - render_date_trunc=check_query_helpers.render_date_trunc, - render_extract=check_query_helpers.render_extract, - render_percentile_expr=check_query_helpers.render_percentile_expr, - mf_time_spine_source=time_spine_source.spine_table.sql, - double_data_type_name=check_query_helpers.double_data_type_name, - render_dimension_template=check_query_helpers.render_dimension_template, - render_entity_template=check_query_helpers.render_entity_template, - render_metric_template=check_query_helpers.render_metric_template, - render_time_dimension_template=check_query_helpers.render_time_dimension_template, - generate_random_uuid=check_query_helpers.generate_random_uuid, - cast_to_ts=check_query_helpers.cast_to_ts, - ) - ] - if case.where_filter - else None, + where_constraints=( + [ + jinja2.Template( + case.where_filter, + undefined=jinja2.StrictUndefined, + ).render( + source_schema=mf_test_configuration.mf_source_schema, + render_time_constraint=check_query_helpers.render_time_constraint, + TimeGranularity=TimeGranularity, + DatePart=DatePart, + render_date_sub=check_query_helpers.render_date_sub, + render_date_trunc=check_query_helpers.render_date_trunc, + render_extract=check_query_helpers.render_extract, + render_percentile_expr=check_query_helpers.render_percentile_expr, + mf_time_spine_source=time_spine_source.spine_table.sql, + double_data_type_name=check_query_helpers.double_data_type_name, + render_dimension_template=check_query_helpers.render_dimension_template, + render_entity_template=check_query_helpers.render_entity_template, + render_metric_template=check_query_helpers.render_metric_template, + render_time_dimension_template=check_query_helpers.render_time_dimension_template, + generate_random_uuid=check_query_helpers.generate_random_uuid, + cast_to_ts=check_query_helpers.cast_to_ts, + ) + ] + if case.where_filter + else None + ), order_by_names=case.order_bys, min_max_only=case.min_max_only, ) diff --git a/tests_metricflow/mf_logging/test_dag_to_text.py b/tests_metricflow/mf_logging/test_dag_to_text.py index c7a8a3c88f..d69c2a12d4 100644 --- a/tests_metricflow/mf_logging/test_dag_to_text.py +++ b/tests_metricflow/mf_logging/test_dag_to_text.py @@ -10,11 +10,11 @@ from metricflow_semantics.dag.mf_dag import DagId from metricflow_semantics.mf_logging.formatting import indent from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat -from metricflow_semantics.sql.sql_table import SqlTable - -from metricflow.sql.sql_exprs import ( +from metricflow_semantics.sql.sql_exprs import ( SqlStringExpression, ) +from metricflow_semantics.sql.sql_table import SqlTable + from metricflow.sql.sql_plan import ( SqlPlan, SqlSelectColumn, diff --git a/tests_metricflow/plan_conversion/instance_converters/test_create_select_columns_with_measures_aggregated.py b/tests_metricflow/plan_conversion/instance_converters/test_create_select_columns_with_measures_aggregated.py index 68ca855c4e..1c188dffe2 100644 --- a/tests_metricflow/plan_conversion/instance_converters/test_create_select_columns_with_measures_aggregated.py +++ b/tests_metricflow/plan_conversion/instance_converters/test_create_select_columns_with_measures_aggregated.py @@ -7,17 +7,17 @@ from metricflow_semantics.specs.dunder_column_association_resolver import DunderColumnAssociationResolver from metricflow_semantics.specs.measure_spec import MeasureSpec, MetricInputMeasureSpec from metricflow_semantics.specs.spec_set import InstanceSpecSet +from metricflow_semantics.sql.sql_exprs import ( + SqlAggregateFunctionExpression, + SqlFunction, + SqlPercentileExpression, +) from metricflow.plan_conversion.instance_converters import ( CreateSelectColumnsWithMeasuresAggregated, FilterElements, ) from metricflow.plan_conversion.select_column_gen import SelectColumnSet -from metricflow.sql.sql_exprs import ( - SqlAggregateFunctionExpression, - SqlFunction, - SqlPercentileExpression, -) from tests_metricflow.fixtures.manifest_fixtures import MetricFlowEngineTestFixture, SemanticManifestSetup __SOURCE_TABLE_ALIAS = "a" diff --git a/tests_metricflow/sql/optimizer/test_column_pruner.py b/tests_metricflow/sql/optimizer/test_column_pruner.py index 01760be80a..dc66a2c85f 100644 --- a/tests_metricflow/sql/optimizer/test_column_pruner.py +++ b/tests_metricflow/sql/optimizer/test_column_pruner.py @@ -5,13 +5,7 @@ import pytest from _pytest.fixtures import FixtureRequest from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat -from metricflow_semantics.sql.sql_join_type import SqlJoinType -from metricflow_semantics.sql.sql_table import SqlTable -from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration - -from metricflow.sql.optimizer.column_pruner import SqlColumnPrunerOptimizer -from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer, SqlQueryPlanRenderer -from metricflow.sql.sql_exprs import ( +from metricflow_semantics.sql.sql_exprs import ( SqlColumnReference, SqlColumnReferenceExpression, SqlComparison, @@ -19,6 +13,12 @@ SqlIsNullExpression, SqlStringExpression, ) +from metricflow_semantics.sql.sql_join_type import SqlJoinType +from metricflow_semantics.sql.sql_table import SqlTable +from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration + +from metricflow.sql.optimizer.column_pruner import SqlColumnPrunerOptimizer +from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer, SqlQueryPlanRenderer from metricflow.sql.sql_plan import ( SqlJoinDescription, SqlQueryPlanNode, diff --git a/tests_metricflow/sql/optimizer/test_cte_column_pruner.py b/tests_metricflow/sql/optimizer/test_cte_column_pruner.py index 3300ef8216..e8af9117a0 100644 --- a/tests_metricflow/sql/optimizer/test_cte_column_pruner.py +++ b/tests_metricflow/sql/optimizer/test_cte_column_pruner.py @@ -4,18 +4,18 @@ import pytest from _pytest.fixtures import FixtureRequest +from metricflow_semantics.sql.sql_exprs import ( + SqlColumnReference, + SqlColumnReferenceExpression, + SqlComparison, + SqlComparisonExpression, +) from metricflow_semantics.sql.sql_join_type import SqlJoinType from metricflow_semantics.sql.sql_table import SqlTable from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration from metricflow.sql.optimizer.column_pruner import SqlColumnPrunerOptimizer from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer, SqlQueryPlanRenderer -from metricflow.sql.sql_exprs import ( - SqlColumnReference, - SqlColumnReferenceExpression, - SqlComparison, - SqlComparisonExpression, -) from metricflow.sql.sql_plan import ( SqlCteNode, SqlJoinDescription, diff --git a/tests_metricflow/sql/optimizer/test_cte_rewriting_sub_query_reducer.py b/tests_metricflow/sql/optimizer/test_cte_rewriting_sub_query_reducer.py index dbdec0e4ff..33b04e7d9d 100644 --- a/tests_metricflow/sql/optimizer/test_cte_rewriting_sub_query_reducer.py +++ b/tests_metricflow/sql/optimizer/test_cte_rewriting_sub_query_reducer.py @@ -4,18 +4,18 @@ import pytest from _pytest.fixtures import FixtureRequest +from metricflow_semantics.sql.sql_exprs import ( + SqlColumnReference, + SqlColumnReferenceExpression, + SqlComparison, + SqlComparisonExpression, +) from metricflow_semantics.sql.sql_join_type import SqlJoinType from metricflow_semantics.sql.sql_table import SqlTable from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration from metricflow.sql.optimizer.rewriting_sub_query_reducer import SqlRewritingSubQueryReducer from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer, SqlQueryPlanRenderer -from metricflow.sql.sql_exprs import ( - SqlColumnReference, - SqlColumnReferenceExpression, - SqlComparison, - SqlComparisonExpression, -) from metricflow.sql.sql_plan import ( SqlCteNode, SqlJoinDescription, diff --git a/tests_metricflow/sql/optimizer/test_cte_table_alias_simplifier.py b/tests_metricflow/sql/optimizer/test_cte_table_alias_simplifier.py index f2794f5ecd..aa631c27b9 100644 --- a/tests_metricflow/sql/optimizer/test_cte_table_alias_simplifier.py +++ b/tests_metricflow/sql/optimizer/test_cte_table_alias_simplifier.py @@ -2,18 +2,18 @@ import pytest from _pytest.fixtures import FixtureRequest +from metricflow_semantics.sql.sql_exprs import ( + SqlColumnReference, + SqlColumnReferenceExpression, + SqlComparison, + SqlComparisonExpression, +) from metricflow_semantics.sql.sql_join_type import SqlJoinType from metricflow_semantics.sql.sql_table import SqlTable from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration from metricflow.sql.optimizer.table_alias_simplifier import SqlTableAliasSimplifier from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer -from metricflow.sql.sql_exprs import ( - SqlColumnReference, - SqlColumnReferenceExpression, - SqlComparison, - SqlComparisonExpression, -) from metricflow.sql.sql_plan import ( SqlCteNode, SqlJoinDescription, diff --git a/tests_metricflow/sql/optimizer/test_rewriting_sub_query_reducer.py b/tests_metricflow/sql/optimizer/test_rewriting_sub_query_reducer.py index 221594f367..5c40afb77a 100644 --- a/tests_metricflow/sql/optimizer/test_rewriting_sub_query_reducer.py +++ b/tests_metricflow/sql/optimizer/test_rewriting_sub_query_reducer.py @@ -2,12 +2,7 @@ import pytest from _pytest.fixtures import FixtureRequest -from metricflow_semantics.sql.sql_join_type import SqlJoinType -from metricflow_semantics.sql.sql_table import SqlTable -from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration - -from metricflow.sql.optimizer.rewriting_sub_query_reducer import SqlRewritingSubQueryReducer -from metricflow.sql.sql_exprs import ( +from metricflow_semantics.sql.sql_exprs import ( SqlAggregateFunctionExpression, SqlColumnReference, SqlColumnReferenceExpression, @@ -17,6 +12,11 @@ SqlStringExpression, SqlStringLiteralExpression, ) +from metricflow_semantics.sql.sql_join_type import SqlJoinType +from metricflow_semantics.sql.sql_table import SqlTable +from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration + +from metricflow.sql.optimizer.rewriting_sub_query_reducer import SqlRewritingSubQueryReducer from metricflow.sql.sql_plan import ( SqlJoinDescription, SqlOrderByDescription, diff --git a/tests_metricflow/sql/optimizer/test_sub_query_reducer.py b/tests_metricflow/sql/optimizer/test_sub_query_reducer.py index 30726e8ac4..991e7c86ff 100644 --- a/tests_metricflow/sql/optimizer/test_sub_query_reducer.py +++ b/tests_metricflow/sql/optimizer/test_sub_query_reducer.py @@ -2,17 +2,17 @@ import pytest from _pytest.fixtures import FixtureRequest -from metricflow_semantics.sql.sql_join_type import SqlJoinType -from metricflow_semantics.sql.sql_table import SqlTable -from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration - -from metricflow.sql.optimizer.sub_query_reducer import SqlSubQueryReducer -from metricflow.sql.sql_exprs import ( +from metricflow_semantics.sql.sql_exprs import ( SqlColumnReference, SqlColumnReferenceExpression, SqlComparison, SqlComparisonExpression, ) +from metricflow_semantics.sql.sql_join_type import SqlJoinType +from metricflow_semantics.sql.sql_table import SqlTable +from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration + +from metricflow.sql.optimizer.sub_query_reducer import SqlSubQueryReducer from metricflow.sql.sql_plan import ( SqlJoinDescription, SqlOrderByDescription, diff --git a/tests_metricflow/sql/optimizer/test_table_alias_simplifier.py b/tests_metricflow/sql/optimizer/test_table_alias_simplifier.py index 1615b3795b..7a1de37504 100644 --- a/tests_metricflow/sql/optimizer/test_table_alias_simplifier.py +++ b/tests_metricflow/sql/optimizer/test_table_alias_simplifier.py @@ -2,18 +2,18 @@ import pytest from _pytest.fixtures import FixtureRequest +from metricflow_semantics.sql.sql_exprs import ( + SqlColumnReference, + SqlColumnReferenceExpression, + SqlComparison, + SqlComparisonExpression, +) from metricflow_semantics.sql.sql_join_type import SqlJoinType from metricflow_semantics.sql.sql_table import SqlTable from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration from metricflow.sql.optimizer.table_alias_simplifier import SqlTableAliasSimplifier from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer, SqlQueryPlanRenderer -from metricflow.sql.sql_exprs import ( - SqlColumnReference, - SqlColumnReferenceExpression, - SqlComparison, - SqlComparisonExpression, -) from metricflow.sql.sql_plan import ( SqlJoinDescription, SqlSelectColumn, diff --git a/tests_metricflow/sql/test_engine_specific_rendering.py b/tests_metricflow/sql/test_engine_specific_rendering.py index 987762006b..adc52cbcc4 100644 --- a/tests_metricflow/sql/test_engine_specific_rendering.py +++ b/tests_metricflow/sql/test_engine_specific_rendering.py @@ -5,11 +5,7 @@ import pytest from _pytest.fixtures import FixtureRequest from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity -from metricflow_semantics.sql.sql_table import SqlTable -from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration - -from metricflow.protocols.sql_client import SqlClient -from metricflow.sql.sql_exprs import ( +from metricflow_semantics.sql.sql_exprs import ( SqlAddTimeExpression, SqlCastToTimestampExpression, SqlColumnReference, @@ -21,6 +17,10 @@ SqlStringExpression, SqlStringLiteralExpression, ) +from metricflow_semantics.sql.sql_table import SqlTable +from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration + +from metricflow.protocols.sql_client import SqlClient from metricflow.sql.sql_plan import ( SqlJoinDescription, SqlOrderByDescription, diff --git a/tests_metricflow/sql/test_render_cte.py b/tests_metricflow/sql/test_render_cte.py index 326804327b..db32aa68de 100644 --- a/tests_metricflow/sql/test_render_cte.py +++ b/tests_metricflow/sql/test_render_cte.py @@ -3,16 +3,16 @@ import logging from _pytest.fixtures import FixtureRequest -from metricflow_semantics.sql.sql_join_type import SqlJoinType -from metricflow_semantics.sql.sql_table import SqlTable -from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration - -from metricflow.sql.sql_exprs import ( +from metricflow_semantics.sql.sql_exprs import ( SqlColumnReference, SqlColumnReferenceExpression, SqlComparison, SqlComparisonExpression, ) +from metricflow_semantics.sql.sql_join_type import SqlJoinType +from metricflow_semantics.sql.sql_table import SqlTable +from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration + from metricflow.sql.sql_plan import ( SqlCteNode, SqlJoinDescription, diff --git a/tests_metricflow/sql/test_sql_expr_render.py b/tests_metricflow/sql/test_sql_expr_render.py index 05b4103ea2..f2a578e8b2 100644 --- a/tests_metricflow/sql/test_sql_expr_render.py +++ b/tests_metricflow/sql/test_sql_expr_render.py @@ -8,10 +8,7 @@ from _pytest.fixtures import FixtureRequest from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity -from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration - -from metricflow.sql.render.expr_renderer import DefaultSqlExpressionRenderer -from metricflow.sql.sql_exprs import ( +from metricflow_semantics.sql.sql_exprs import ( SqlAggregateFunctionExpression, SqlBetweenExpression, SqlCastToTimestampExpression, @@ -34,6 +31,9 @@ SqlWindowFunctionExpression, SqlWindowOrderByArgument, ) +from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration + +from metricflow.sql.render.expr_renderer import DefaultSqlExpressionRenderer from tests_metricflow.snapshot_utils import assert_str_snapshot_equal logger = logging.getLogger(__name__) diff --git a/tests_metricflow/sql/test_sql_plan_render.py b/tests_metricflow/sql/test_sql_plan_render.py index c14d67db50..7e09d83bb8 100644 --- a/tests_metricflow/sql/test_sql_plan_render.py +++ b/tests_metricflow/sql/test_sql_plan_render.py @@ -5,12 +5,7 @@ import pytest from _pytest.fixtures import FixtureRequest -from metricflow_semantics.sql.sql_join_type import SqlJoinType -from metricflow_semantics.sql.sql_table import SqlTable, SqlTableType -from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration - -from metricflow.protocols.sql_client import SqlClient -from metricflow.sql.sql_exprs import ( +from metricflow_semantics.sql.sql_exprs import ( SqlAggregateFunctionExpression, SqlColumnReference, SqlColumnReferenceExpression, @@ -19,6 +14,11 @@ SqlFunction, SqlStringExpression, ) +from metricflow_semantics.sql.sql_join_type import SqlJoinType +from metricflow_semantics.sql.sql_table import SqlTable, SqlTableType +from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration + +from metricflow.protocols.sql_client import SqlClient from metricflow.sql.sql_plan import ( SqlCreateTableAsNode, SqlJoinDescription, diff --git a/tests_metricflow/sql_clients/test_date_time_operations.py b/tests_metricflow/sql_clients/test_date_time_operations.py index 0370ade60e..b0453b68ff 100644 --- a/tests_metricflow/sql_clients/test_date_time_operations.py +++ b/tests_metricflow/sql_clients/test_date_time_operations.py @@ -23,16 +23,16 @@ import pytest from dbt_semantic_interfaces.type_enums import TimeGranularity from dbt_semantic_interfaces.type_enums.date_part import DatePart - -from metricflow.data_table.mf_table import MetricFlowDataTable -from metricflow.protocols.sql_client import SqlClient -from metricflow.sql.sql_exprs import ( +from metricflow_semantics.sql.sql_exprs import ( SqlCastToTimestampExpression, SqlDateTruncExpression, SqlExtractExpression, SqlStringLiteralExpression, ) +from metricflow.data_table.mf_table import MetricFlowDataTable +from metricflow.protocols.sql_client import SqlClient + logger = logging.getLogger(__name__) From b1435ef2ac0678ab6709478d778ff956a6c76695 Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Wed, 18 Dec 2024 09:31:26 -0800 Subject: [PATCH 2/4] Add SQL exprs needed for custom offset window This includes a CASE expression, an integer expression, and some updates to the add time expression & the window function expression. --- .../metricflow_semantics/dag/id_prefix.py | 3 + .../metricflow_semantics/sql/sql_exprs.py | 237 +++++++++++++++++- metricflow/sql/render/big_query.py | 4 +- metricflow/sql/render/duckdb_renderer.py | 23 +- metricflow/sql/render/expr_renderer.py | 48 +++- metricflow/sql/render/postgres.py | 23 +- metricflow/sql/render/trino.py | 23 +- 7 files changed, 331 insertions(+), 30 deletions(-) diff --git a/metricflow-semantics/metricflow_semantics/dag/id_prefix.py b/metricflow-semantics/metricflow_semantics/dag/id_prefix.py index 8c2a6d1b4e..61fdcc5f7d 100644 --- a/metricflow-semantics/metricflow_semantics/dag/id_prefix.py +++ b/metricflow-semantics/metricflow_semantics/dag/id_prefix.py @@ -75,6 +75,9 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper): SQL_EXPR_BETWEEN_PREFIX = "betw" SQL_EXPR_WINDOW_FUNCTION_ID_PREFIX = "wfnc" SQL_EXPR_GENERATE_UUID_PREFIX = "uuid" + SQL_EXPR_CASE_PREFIX = "case" + SQL_EXPR_ARITHMETIC_PREFIX = "arit" + SQL_EXPR_INTEGER_PREFIX = "int" SQL_PLAN_SELECT_STATEMENT_ID_PREFIX = "ss" SQL_PLAN_TABLE_FROM_CLAUSE_ID_PREFIX = "tfc" diff --git a/metricflow-semantics/metricflow_semantics/sql/sql_exprs.py b/metricflow-semantics/metricflow_semantics/sql/sql_exprs.py index ec7866f001..926ea5a397 100644 --- a/metricflow-semantics/metricflow_semantics/sql/sql_exprs.py +++ b/metricflow-semantics/metricflow_semantics/sql/sql_exprs.py @@ -14,12 +14,13 @@ from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.period_agg import PeriodAggregation from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity +from typing_extensions import override + from metricflow_semantics.collection_helpers.merger import Mergeable from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow_semantics.dag.mf_dag import DagNode, DisplayedProperty from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from metricflow_semantics.visitor import Visitable, VisitorOutputT -from typing_extensions import override @dataclass(frozen=True, eq=False) @@ -237,6 +238,18 @@ def visit_window_function_expr(self, node: SqlWindowFunctionExpression) -> Visit def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> VisitorOutputT: # noqa: D102 pass + @abstractmethod + def visit_case_expr(self, node: SqlCaseExpression) -> VisitorOutputT: # noqa: D102 + pass + + @abstractmethod + def visit_arithmetic_expr(self, node: SqlArithmeticExpression) -> VisitorOutputT: # noqa: D102 + pass + + @abstractmethod + def visit_integer_expr(self, node: SqlIntegerExpression) -> VisitorOutputT: # noqa: D102 + pass + @dataclass(frozen=True, eq=False) class SqlStringExpression(SqlExpressionNode): @@ -375,6 +388,59 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self.literal_value == other.literal_value +@dataclass(frozen=True, eq=False) +class SqlIntegerExpression(SqlExpressionNode): + """An integer like 1.""" + + integer_value: int + + @staticmethod + def create(integer_value: int) -> SqlIntegerExpression: # noqa: D102 + return SqlIntegerExpression(parent_nodes=(), integer_value=integer_value) + + @classmethod + def id_prefix(cls) -> IdPrefix: # noqa: D102 + return StaticIdPrefix.SQL_EXPR_INTEGER_PREFIX + + def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 + return visitor.visit_integer_expr(self) + + @property + def description(self) -> str: # noqa: D102 + return f"Integer: {self.integer_value}" + + @property + def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 + return tuple(super().displayed_properties) + (DisplayedProperty("value", self.integer_value),) + + @property + def requires_parenthesis(self) -> bool: # noqa: D102 + return False + + @property + def bind_parameter_set(self) -> SqlBindParameterSet: # noqa: D102 + return SqlBindParameterSet() + + def __repr__(self) -> str: # noqa: D105 + return f"{self.__class__.__name__}(node_id={self.node_id}, integer_value={self.integer_value})" + + def rewrite( # noqa: D102 + self, + column_replacements: Optional[SqlColumnReplacements] = None, + should_render_table_alias: Optional[bool] = None, + ) -> SqlExpressionNode: + return self + + @property + def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 + return SqlExpressionTreeLineage(other_exprs=(self,)) + + def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 + if not isinstance(other, SqlIntegerExpression): + return False + return self.integer_value == other.integer_value + + @dataclass(frozen=True) class SqlColumnReference: """Used with string expressions to specify what columns are referred to in the string expression.""" @@ -950,11 +1016,18 @@ class SqlWindowFunction(Enum): FIRST_VALUE = "FIRST_VALUE" LAST_VALUE = "LAST_VALUE" AVERAGE = "AVG" + ROW_NUMBER = "ROW_NUMBER" + LAG = "LAG" @property def requires_ordering(self) -> bool: """Asserts whether or not ordering the window function will have an impact on the resulting value.""" - if self is SqlWindowFunction.FIRST_VALUE or self is SqlWindowFunction.LAST_VALUE: + if ( + self is SqlWindowFunction.FIRST_VALUE + or self is SqlWindowFunction.LAST_VALUE + or self is SqlWindowFunction.ROW_NUMBER + or self is SqlWindowFunction.LAG + ): return True elif self is SqlWindowFunction.AVERAGE: return False @@ -1106,7 +1179,8 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return ( self.sql_function == other.sql_function and self.order_by_args == other.order_by_args - and self._parents_match(other) + and self.partition_by_args == other.partition_by_args + and self.sql_function_args == other.sql_function_args ) @@ -1367,7 +1441,7 @@ def rewrite( # noqa: D102 ) -> SqlExpressionNode: return SqlAddTimeExpression.create( arg=self.arg.rewrite(column_replacements, should_render_table_alias), - count_expr=self.count_expr, + count_expr=self.count_expr.rewrite(column_replacements, should_render_table_alias), granularity=self.granularity, ) @@ -1719,3 +1793,158 @@ def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False + + +@dataclass(frozen=True, eq=False) +class SqlCaseExpression(SqlExpressionNode): + """Renders a CASE WHEN expression.""" + + when_to_then_exprs: Dict[SqlExpressionNode, SqlExpressionNode] + else_expr: Optional[SqlExpressionNode] + + @staticmethod + def create( # noqa: D102 + when_to_then_exprs: Dict[SqlExpressionNode, SqlExpressionNode], else_expr: Optional[SqlExpressionNode] = None + ) -> SqlCaseExpression: + parent_nodes: Tuple[SqlExpressionNode, ...] = () + for when, then in when_to_then_exprs.items(): + parent_nodes += (when,) + parent_nodes += (then,) + + if else_expr: + parent_nodes += (else_expr,) + + return SqlCaseExpression(parent_nodes=parent_nodes, when_to_then_exprs=when_to_then_exprs, else_expr=else_expr) + + @classmethod + def id_prefix(cls) -> IdPrefix: # noqa: D102 + return StaticIdPrefix.SQL_EXPR_CASE_PREFIX + + def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 + return visitor.visit_case_expr(self) + + @property + def description(self) -> str: # noqa: D102 + return "Case expression" + + @property + def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 + return super().displayed_properties + + @property + def requires_parenthesis(self) -> bool: # noqa: D102 + return False + + @property + def bind_parameter_set(self) -> SqlBindParameterSet: # noqa: D102 + return SqlBindParameterSet() + + def __repr__(self) -> str: # noqa: D105 + return f"{self.__class__.__name__}(node_id={self.node_id})" + + def rewrite( # noqa: D102 + self, + column_replacements: Optional[SqlColumnReplacements] = None, + should_render_table_alias: Optional[bool] = None, + ) -> SqlExpressionNode: + return SqlCaseExpression.create( + when_to_then_exprs={ + when.rewrite(column_replacements, should_render_table_alias): then.rewrite( + column_replacements, should_render_table_alias + ) + for when, then in self.when_to_then_exprs.items() + }, + else_expr=( + self.else_expr.rewrite(column_replacements, should_render_table_alias) if self.else_expr else None + ), + ) + + @property + def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 + return SqlExpressionTreeLineage.merge_iterable( + tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),) + ) + + def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 + if not isinstance(other, SqlCaseExpression): + return False + return self.when_to_then_exprs == other.when_to_then_exprs and self.else_expr == other.else_expr + + +class SqlArithmeticOperator(Enum): + """Arithmetic operator used to do math in a SQL expression.""" + + ADD = "+" + SUBTRACT = "-" + MULTIPLY = "*" + DIVIDE = "/" + + +@dataclass(frozen=True, eq=False) +class SqlArithmeticExpression(SqlExpressionNode): + """An arithmetic expression using +, -, *, /. + + e.g. my_table.my_column + my_table.other_column + + Attributes: + left_expr: The expression on the left side of the operator + operator: The operator to use on the expressions + right_expr: The expression on the right side of the operator + """ + + left_expr: SqlExpressionNode + operator: SqlArithmeticOperator + right_expr: SqlExpressionNode + + @staticmethod + def create( # noqa: D102 + left_expr: SqlExpressionNode, operator: SqlArithmeticOperator, right_expr: SqlExpressionNode + ) -> SqlArithmeticExpression: + return SqlArithmeticExpression( + parent_nodes=(left_expr, right_expr), left_expr=left_expr, operator=operator, right_expr=right_expr + ) + + @classmethod + def id_prefix(cls) -> IdPrefix: # noqa: D102 + return StaticIdPrefix.SQL_EXPR_ARITHMETIC_PREFIX + + def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 + return visitor.visit_arithmetic_expr(self) + + @property + def description(self) -> str: # noqa: D102 + return "Arithmetic Expression" + + @property + def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 + return tuple(super().displayed_properties) + ( + DisplayedProperty("left_expr", self.left_expr), + DisplayedProperty("operator", self.operator.value), + DisplayedProperty("right_expr", self.right_expr), + ) + + @property + def requires_parenthesis(self) -> bool: # noqa: D102 + return True + + def rewrite( # noqa: D102 + self, + column_replacements: Optional[SqlColumnReplacements] = None, + should_render_table_alias: Optional[bool] = None, + ) -> SqlExpressionNode: + return SqlArithmeticExpression.create( + left_expr=self.left_expr.rewrite(column_replacements, should_render_table_alias), + operator=self.operator, + right_expr=self.right_expr.rewrite(column_replacements, should_render_table_alias), + ) + + @property + def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 + return SqlExpressionTreeLineage.merge_iterable( + tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),) + ) + + def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 + if not isinstance(other, SqlArithmeticExpression): + return False + return self.operator == other.operator and self._parents_match(other) diff --git a/metricflow/sql/render/big_query.py b/metricflow/sql/render/big_query.py index 5745ab2f8e..186f231108 100644 --- a/metricflow/sql/render/big_query.py +++ b/metricflow/sql/render/big_query.py @@ -184,8 +184,8 @@ def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRender count = node.count_expr.accept(self) return SqlExpressionRenderResult( - sql=f"DATE_ADD(CAST({column.sql} AS {self.timestamp_data_type}), INTERVAL {count} {node.granularity.value})", - bind_parameter_set=column.bind_parameter_set, + sql=f"DATE_ADD(CAST({column.sql} AS {self.timestamp_data_type}), INTERVAL {count.sql} {node.granularity.value})", + bind_parameter_set=column.bind_parameter_set.merge(count.bind_parameter_set), ) @override diff --git a/metricflow/sql/render/duckdb_renderer.py b/metricflow/sql/render/duckdb_renderer.py index ecfca54f52..48d0c16722 100644 --- a/metricflow/sql/render/duckdb_renderer.py +++ b/metricflow/sql/render/duckdb_renderer.py @@ -7,7 +7,10 @@ from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from metricflow_semantics.sql.sql_exprs import ( SqlAddTimeExpression, + SqlArithmeticExpression, + SqlArithmeticOperator, SqlGenerateUuidExpression, + SqlIntegerExpression, SqlPercentileExpression, SqlPercentileFunctionType, SqlSubtractTimeIntervalExpression, @@ -56,17 +59,25 @@ def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpress @override def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult: """Render time delta expression for DuckDB, which requires slightly different syntax from other engines.""" - arg_rendered = node.arg.accept(self) - count_rendered = node.count_expr.accept(self).sql - granularity = node.granularity + count_expr = node.count_expr if granularity is TimeGranularity.QUARTER: granularity = TimeGranularity.MONTH - count_rendered = f"({count_rendered} * 3)" + count_expr = SqlArithmeticExpression.create( + left_expr=node.count_expr, + operator=SqlArithmeticOperator.MULTIPLY, + right_expr=SqlIntegerExpression.create(3), + ) + + arg_rendered = node.arg.accept(self) + count_rendered = count_expr.accept(self) + count_sql = f"({count_rendered.sql})" if count_expr.requires_parenthesis else count_rendered.sql return SqlExpressionRenderResult( - sql=f"{arg_rendered.sql} + INTERVAL {count_rendered} {granularity.value}", - bind_parameter_set=arg_rendered.bind_parameter_set, + sql=f"{arg_rendered.sql} + INTERVAL {count_sql} {granularity.value}", + bind_parameter_set=SqlBindParameterSet.merge_iterable( + (arg_rendered.bind_parameter_set, count_rendered.bind_parameter_set) + ), ) @override diff --git a/metricflow/sql/render/expr_renderer.py b/metricflow/sql/render/expr_renderer.py index 10e3d748b9..a89dc2abba 100644 --- a/metricflow/sql/render/expr_renderer.py +++ b/metricflow/sql/render/expr_renderer.py @@ -15,7 +15,10 @@ from metricflow_semantics.sql.sql_exprs import ( SqlAddTimeExpression, SqlAggregateFunctionExpression, + SqlArithmeticExpression, + SqlArithmeticOperator, SqlBetweenExpression, + SqlCaseExpression, SqlCastToTimestampExpression, SqlColumnAliasReferenceExpression, SqlColumnReferenceExpression, @@ -26,6 +29,7 @@ SqlExtractExpression, SqlFunction, SqlGenerateUuidExpression, + SqlIntegerExpression, SqlIsNullExpression, SqlLogicalExpression, SqlNullExpression, @@ -320,17 +324,25 @@ def visit_subtract_time_interval_expr( # noqa: D102 ) def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult: # noqa: D102 - arg_rendered = node.arg.accept(self) - count_rendered = node.count_expr.accept(self).sql - granularity = node.granularity + count_expr = node.count_expr if granularity is TimeGranularity.QUARTER: granularity = TimeGranularity.MONTH - count_rendered = f"({count_rendered} * 3)" + count_expr = SqlArithmeticExpression.create( + left_expr=node.count_expr, + operator=SqlArithmeticOperator.MULTIPLY, + right_expr=SqlIntegerExpression.create(3), + ) + + arg_rendered = node.arg.accept(self) + count_rendered = count_expr.accept(self) + count_sql = f"({count_rendered.sql})" if count_expr.requires_parenthesis else count_rendered.sql return SqlExpressionRenderResult( - sql=f"DATEADD({granularity.value}, {count_rendered}, {arg_rendered.sql})", - bind_parameter_set=arg_rendered.bind_parameter_set, + sql=f"DATEADD({granularity.value}, {count_sql}, {arg_rendered.sql})", + bind_parameter_set=SqlBindParameterSet.merge_iterable( + (arg_rendered.bind_parameter_set, count_rendered.bind_parameter_set) + ), ) def visit_ratio_computation_expr(self, node: SqlRatioComputationExpression) -> SqlExpressionRenderResult: @@ -438,3 +450,27 @@ def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpres sql="UUID()", bind_parameter_set=SqlBindParameterSet(), ) + + def visit_case_expr(self, node: SqlCaseExpression) -> SqlExpressionRenderResult: # noqa: D102 + sql = "CASE\n" + for when, then in node.when_to_then_exprs.items(): + sql += indent( + f"WHEN {self.render_sql_expr(when).sql}\n", indent_prefix=SqlRenderingConstants.INDENT + ) + indent( + f"THEN {self.render_sql_expr(then).sql}\n", + indent_prefix=SqlRenderingConstants.INDENT * 2, + ) + if node.else_expr: + sql += indent( + f"ELSE {self.render_sql_expr(node.else_expr).sql}\n", + indent_prefix=SqlRenderingConstants.INDENT, + ) + sql += "END" + return SqlExpressionRenderResult(sql=sql, bind_parameter_set=SqlBindParameterSet()) + + def visit_arithmetic_expr(self, node: SqlArithmeticExpression) -> SqlExpressionRenderResult: # noqa: D102 + sql = f"{self.render_sql_expr(node.left_expr).sql} {node.operator.value} {self.render_sql_expr(node.right_expr).sql}" + return SqlExpressionRenderResult(sql=sql, bind_parameter_set=SqlBindParameterSet()) + + def visit_integer_expr(self, node: SqlIntegerExpression) -> SqlExpressionRenderResult: # noqa: D102 + return SqlExpressionRenderResult(sql=str(node.integer_value), bind_parameter_set=SqlBindParameterSet()) diff --git a/metricflow/sql/render/postgres.py b/metricflow/sql/render/postgres.py index 2509dfc243..f5a4a88581 100644 --- a/metricflow/sql/render/postgres.py +++ b/metricflow/sql/render/postgres.py @@ -8,7 +8,10 @@ from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from metricflow_semantics.sql.sql_exprs import ( SqlAddTimeExpression, + SqlArithmeticExpression, + SqlArithmeticOperator, SqlGenerateUuidExpression, + SqlIntegerExpression, SqlPercentileExpression, SqlPercentileFunctionType, SqlSubtractTimeIntervalExpression, @@ -58,17 +61,25 @@ def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpress @override def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult: """Render time delta operations for PostgreSQL, which needs custom support for quarterly granularity.""" - arg_rendered = node.arg.accept(self) - count_rendered = node.count_expr.accept(self).sql - granularity = node.granularity + count_expr = node.count_expr if granularity is TimeGranularity.QUARTER: granularity = TimeGranularity.MONTH - count_rendered = f"({count_rendered} * 3)" + SqlArithmeticExpression.create( + left_expr=node.count_expr, + operator=SqlArithmeticOperator.MULTIPLY, + right_expr=SqlIntegerExpression.create(3), + ) + + arg_rendered = node.arg.accept(self) + count_rendered = count_expr.accept(self) + count_sql = f"({count_rendered.sql})" if count_expr.requires_parenthesis else count_rendered.sql return SqlExpressionRenderResult( - sql=f"{arg_rendered.sql} + MAKE_INTERVAL({granularity.value}s => {count_rendered})", - bind_parameter_set=arg_rendered.bind_parameter_set, + sql=f"{arg_rendered.sql} + MAKE_INTERVAL({granularity.value}s => CAST ({count_sql} AS INTEGER))", + bind_parameter_set=SqlBindParameterSet.merge_iterable( + (arg_rendered.bind_parameter_set, count_rendered.bind_parameter_set) + ), ) @override diff --git a/metricflow/sql/render/trino.py b/metricflow/sql/render/trino.py index bd3a581597..f0a8ea2da4 100644 --- a/metricflow/sql/render/trino.py +++ b/metricflow/sql/render/trino.py @@ -9,8 +9,11 @@ from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from metricflow_semantics.sql.sql_exprs import ( SqlAddTimeExpression, + SqlArithmeticExpression, + SqlArithmeticOperator, SqlBetweenExpression, SqlGenerateUuidExpression, + SqlIntegerExpression, SqlPercentileExpression, SqlPercentileFunctionType, SqlSubtractTimeIntervalExpression, @@ -63,17 +66,25 @@ def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpress @override def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult: """Render time delta for Trino, require granularity in quotes and function name change.""" - arg_rendered = node.arg.accept(self) - count_rendered = node.count_expr.accept(self).sql - granularity = node.granularity + count_expr = node.count_expr if granularity is TimeGranularity.QUARTER: granularity = TimeGranularity.MONTH - count_rendered = f"({count_rendered} * 3)" + SqlArithmeticExpression.create( + left_expr=node.count_expr, + operator=SqlArithmeticOperator.MULTIPLY, + right_expr=SqlIntegerExpression.create(3), + ) + + arg_rendered = node.arg.accept(self) + count_rendered = count_expr.accept(self) + count_sql = f"({count_rendered.sql})" if count_expr.requires_parenthesis else count_rendered.sql return SqlExpressionRenderResult( - sql=f"DATE_ADD('{granularity.value}', {count_rendered}, {arg_rendered.sql})", - bind_parameter_set=arg_rendered.bind_parameter_set, + sql=f"DATE_ADD('{granularity.value}', {count_sql}, {arg_rendered.sql})", + bind_parameter_set=SqlBindParameterSet.merge_iterable( + (arg_rendered.bind_parameter_set, count_rendered.bind_parameter_set) + ), ) @override From 9c1544a82cfb0176bdaeeb348b2985097a15a864 Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Wed, 18 Dec 2024 09:44:54 -0800 Subject: [PATCH 3/4] Update SQL engine snapshots for add time expr test Not totally relevant to this PR, but these seem to have been removed somehow. --- .../SqlPlan/BigQuery/test_add_time_expr__plan0.sql | 10 ++++++++++ .../SqlPlan/Databricks/test_add_time_expr__plan0.sql | 10 ++++++++++ .../SqlPlan/DuckDB/test_add_time_expr__plan0.sql | 10 ++++++++++ .../SqlPlan/Postgres/test_add_time_expr__plan0.sql | 10 ++++++++++ .../SqlPlan/Redshift/test_add_time_expr__plan0.sql | 10 ++++++++++ .../SqlPlan/Snowflake/test_add_time_expr__plan0.sql | 10 ++++++++++ .../SqlPlan/Trino/test_add_time_expr__plan0.sql | 10 ++++++++++ 7 files changed, 70 insertions(+) create mode 100644 tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/BigQuery/test_add_time_expr__plan0.sql create mode 100644 tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/Databricks/test_add_time_expr__plan0.sql create mode 100644 tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/DuckDB/test_add_time_expr__plan0.sql create mode 100644 tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/Postgres/test_add_time_expr__plan0.sql create mode 100644 tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/Redshift/test_add_time_expr__plan0.sql create mode 100644 tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/Snowflake/test_add_time_expr__plan0.sql create mode 100644 tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/Trino/test_add_time_expr__plan0.sql diff --git a/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/BigQuery/test_add_time_expr__plan0.sql b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/BigQuery/test_add_time_expr__plan0.sql new file mode 100644 index 0000000000..90fc09ace7 --- /dev/null +++ b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/BigQuery/test_add_time_expr__plan0.sql @@ -0,0 +1,10 @@ +test_name: test_add_time_expr +test_filename: test_engine_specific_rendering.py +docstring: + Tests rendering of the SqlAddTimeExpr in a query. +sql_engine: BigQuery +--- +-- Test Add Time Expression +SELECT + DATE_ADD(CAST('2020-01-01' AS DATETIME), INTERVAL SqlExpressionRenderResult(sql='1', bind_parameter_set=SqlBindParameterSet(param_items=())) quarter) AS add_time +FROM foo.bar a diff --git a/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/Databricks/test_add_time_expr__plan0.sql b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/Databricks/test_add_time_expr__plan0.sql new file mode 100644 index 0000000000..3f6f5e12d4 --- /dev/null +++ b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/Databricks/test_add_time_expr__plan0.sql @@ -0,0 +1,10 @@ +test_name: test_add_time_expr +test_filename: test_engine_specific_rendering.py +docstring: + Tests rendering of the SqlAddTimeExpr in a query. +sql_engine: Databricks +--- +-- Test Add Time Expression +SELECT + DATEADD(month, (1 * 3), '2020-01-01') AS add_time +FROM foo.bar a diff --git a/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/DuckDB/test_add_time_expr__plan0.sql b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/DuckDB/test_add_time_expr__plan0.sql new file mode 100644 index 0000000000..984e2096f6 --- /dev/null +++ b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/DuckDB/test_add_time_expr__plan0.sql @@ -0,0 +1,10 @@ +test_name: test_add_time_expr +test_filename: test_engine_specific_rendering.py +docstring: + Tests rendering of the SqlAddTimeExpr in a query. +sql_engine: DuckDB +--- +-- Test Add Time Expression +SELECT + '2020-01-01' + INTERVAL (1 * 3) month AS add_time +FROM foo.bar a diff --git a/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/Postgres/test_add_time_expr__plan0.sql b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/Postgres/test_add_time_expr__plan0.sql new file mode 100644 index 0000000000..3bac32e019 --- /dev/null +++ b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/Postgres/test_add_time_expr__plan0.sql @@ -0,0 +1,10 @@ +test_name: test_add_time_expr +test_filename: test_engine_specific_rendering.py +docstring: + Tests rendering of the SqlAddTimeExpr in a query. +sql_engine: Postgres +--- +-- Test Add Time Expression +SELECT + '2020-01-01' + MAKE_INTERVAL(months => (1)) AS add_time +FROM foo.bar a diff --git a/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/Redshift/test_add_time_expr__plan0.sql b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/Redshift/test_add_time_expr__plan0.sql new file mode 100644 index 0000000000..bac4dc733c --- /dev/null +++ b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/Redshift/test_add_time_expr__plan0.sql @@ -0,0 +1,10 @@ +test_name: test_add_time_expr +test_filename: test_engine_specific_rendering.py +docstring: + Tests rendering of the SqlAddTimeExpr in a query. +sql_engine: Redshift +--- +-- Test Add Time Expression +SELECT + DATEADD(month, (1 * 3), '2020-01-01') AS add_time +FROM foo.bar a diff --git a/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/Snowflake/test_add_time_expr__plan0.sql b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/Snowflake/test_add_time_expr__plan0.sql new file mode 100644 index 0000000000..b83e173387 --- /dev/null +++ b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/Snowflake/test_add_time_expr__plan0.sql @@ -0,0 +1,10 @@ +test_name: test_add_time_expr +test_filename: test_engine_specific_rendering.py +docstring: + Tests rendering of the SqlAddTimeExpr in a query. +sql_engine: Snowflake +--- +-- Test Add Time Expression +SELECT + DATEADD(month, (1 * 3), '2020-01-01') AS add_time +FROM foo.bar a diff --git a/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/Trino/test_add_time_expr__plan0.sql b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/Trino/test_add_time_expr__plan0.sql new file mode 100644 index 0000000000..61e35c5e04 --- /dev/null +++ b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlPlan/Trino/test_add_time_expr__plan0.sql @@ -0,0 +1,10 @@ +test_name: test_add_time_expr +test_filename: test_engine_specific_rendering.py +docstring: + Tests rendering of the SqlAddTimeExpr in a query. +sql_engine: Trino +--- +-- Test Add Time Expression +SELECT + DATE_ADD('month', (1), '2020-01-01') AS add_time +FROM foo.bar a From 73fb43b791a321ea8deadf2d35386ffa9ec33a84 Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Wed, 18 Dec 2024 15:12:42 -0800 Subject: [PATCH 4/4] Handle window function frame cause appropriately --- .../metricflow_semantics/sql/sql_exprs.py | 14 ++++++++++++++ metricflow/sql/render/expr_renderer.py | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/metricflow-semantics/metricflow_semantics/sql/sql_exprs.py b/metricflow-semantics/metricflow_semantics/sql/sql_exprs.py index 926ea5a397..7d432661fe 100644 --- a/metricflow-semantics/metricflow_semantics/sql/sql_exprs.py +++ b/metricflow-semantics/metricflow_semantics/sql/sql_exprs.py @@ -1034,6 +1034,20 @@ def requires_ordering(self) -> bool: else: assert_values_exhausted(self) + @property + def allows_frame_clause(self) -> bool: + """Whether the function allows a frame clause, e.g., 'ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING'.""" + if ( + self is SqlWindowFunction.FIRST_VALUE + or self is SqlWindowFunction.LAST_VALUE + or self is SqlWindowFunction.AVERAGE + ): + return True + if self is SqlWindowFunction.ROW_NUMBER or self is SqlWindowFunction.LAG: + return False + else: + assert_values_exhausted(self) + @classmethod def get_window_function_for_period_agg(cls, period_agg: PeriodAggregation) -> SqlWindowFunction: """Get the window function to use for given period agg option.""" diff --git a/metricflow/sql/render/expr_renderer.py b/metricflow/sql/render/expr_renderer.py index a89dc2abba..158c074ed0 100644 --- a/metricflow/sql/render/expr_renderer.py +++ b/metricflow/sql/render/expr_renderer.py @@ -428,7 +428,7 @@ def visit_window_function_expr(self, node: SqlWindowFunctionExpression) -> SqlEx ) ) - if len(order_by_args_rendered) > 0: + if len(order_by_args_rendered) > 0 and node.sql_function.allows_frame_clause: window_string_lines.append("ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING") window_string = "\n".join(window_string_lines)