diff --git a/metricflow/sql/render/duckdb_renderer.py b/metricflow/sql/render/duckdb_renderer.py index 833bbedcf..3e03b7eca 100644 --- a/metricflow/sql/render/duckdb_renderer.py +++ b/metricflow/sql/render/duckdb_renderer.py @@ -44,7 +44,7 @@ def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpress count = node.count granularity = node.granularity - if granularity == TimeGranularity.QUARTER: + if granularity is TimeGranularity.QUARTER: granularity = TimeGranularity.MONTH count *= 3 @@ -60,9 +60,9 @@ def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRender count_rendered = node.count_expr.accept(self).sql granularity = node.granularity - if granularity == TimeGranularity.QUARTER: + if granularity is TimeGranularity.QUARTER: granularity = TimeGranularity.MONTH - count_rendered = f"{count_rendered} * 3" + count_rendered = f"({count_rendered} * 3)" return SqlExpressionRenderResult( sql=f"{arg_rendered.sql} + INTERVAL {count_rendered} {granularity.value}", diff --git a/metricflow/sql/render/expr_renderer.py b/metricflow/sql/render/expr_renderer.py index c54ecb8ee..a387a2da0 100644 --- a/metricflow/sql/render/expr_renderer.py +++ b/metricflow/sql/render/expr_renderer.py @@ -311,7 +311,7 @@ def visit_subtract_time_interval_expr( # noqa: D102 count = node.count granularity = node.granularity - if granularity == TimeGranularity.QUARTER: + if granularity is TimeGranularity.QUARTER: granularity = TimeGranularity.MONTH count *= 3 return SqlExpressionRenderResult( @@ -324,9 +324,9 @@ def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRender count_rendered = node.count_expr.accept(self).sql granularity = node.granularity - if granularity == TimeGranularity.QUARTER: + if granularity is TimeGranularity.QUARTER: granularity = TimeGranularity.MONTH - count_rendered = f"{count_rendered} * 3" + count_rendered = f"({count_rendered} * 3)" return SqlExpressionRenderResult( sql=f"DATEADD({granularity.value}, {count_rendered}, {arg_rendered.sql})", diff --git a/metricflow/sql/render/postgres.py b/metricflow/sql/render/postgres.py index 565289974..8ced43020 100644 --- a/metricflow/sql/render/postgres.py +++ b/metricflow/sql/render/postgres.py @@ -47,7 +47,7 @@ def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpress count = node.count granularity = node.granularity - if granularity == TimeGranularity.QUARTER: + if granularity is TimeGranularity.QUARTER: granularity = TimeGranularity.MONTH count *= 3 return SqlExpressionRenderResult( @@ -62,9 +62,9 @@ def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRender count_rendered = node.count_expr.accept(self).sql granularity = node.granularity - if granularity == TimeGranularity.QUARTER: + if granularity is TimeGranularity.QUARTER: granularity = TimeGranularity.MONTH - count_rendered = f"{count_rendered} * 3" + count_rendered = f"({count_rendered} * 3)" return SqlExpressionRenderResult( sql=f"{arg_rendered.sql} + MAKE_INTERVAL({granularity.value}s => {count_rendered})", diff --git a/metricflow/sql/render/trino.py b/metricflow/sql/render/trino.py index ac902aa17..e6aff3150 100644 --- a/metricflow/sql/render/trino.py +++ b/metricflow/sql/render/trino.py @@ -52,7 +52,7 @@ def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpress count = node.count granularity = node.granularity - if granularity == TimeGranularity.QUARTER: + if granularity is TimeGranularity.QUARTER: granularity = TimeGranularity.MONTH count *= 3 return SqlExpressionRenderResult( @@ -67,9 +67,9 @@ def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRender count_rendered = node.count_expr.accept(self).sql granularity = node.granularity - if granularity == TimeGranularity.QUARTER: + if granularity is TimeGranularity.QUARTER: granularity = TimeGranularity.MONTH - count_rendered = f"{count_rendered} * 3" + count_rendered = f"({count_rendered} * 3)" return SqlExpressionRenderResult( sql=f"DATE_ADD('{granularity.value}', {count_rendered}, {arg_rendered.sql})", diff --git a/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/DuckDB/test_add_time_expr__plan0.sql b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/DuckDB/test_add_time_expr__plan0.sql new file mode 100644 index 000000000..984e2096f --- /dev/null +++ b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/DuckDB/test_add_time_expr__plan0.sql @@ -0,0 +1,10 @@ +test_name: test_add_time_expr +test_filename: test_engine_specific_rendering.py +docstring: + Tests rendering of the SqlAddTimeExpr in a query. +sql_engine: DuckDB +--- +-- Test Add Time Expression +SELECT + '2020-01-01' + INTERVAL (1 * 3) month AS add_time +FROM foo.bar a diff --git a/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/Redshift/test_add_time_expr__plan0.sql b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/Redshift/test_add_time_expr__plan0.sql new file mode 100644 index 000000000..5188a324e --- /dev/null +++ b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/Redshift/test_add_time_expr__plan0.sql @@ -0,0 +1,10 @@ +test_name: test_add_time_expr +test_filename: test_engine_specific_rendering.py +docstring: + Tests rendering of the SqlAddTimeExpr in a query. +sql_engine: Redshift +--- +-- Test Approximate Discrete Percentile Expression +SELECT + DATEADD(month, (1 * 3), '2020-01-01') AS add_time +FROM foo.bar a diff --git a/tests_metricflow/sql/test_engine_specific_rendering.py b/tests_metricflow/sql/test_engine_specific_rendering.py index 60c5a97ca..987762006 100644 --- a/tests_metricflow/sql/test_engine_specific_rendering.py +++ b/tests_metricflow/sql/test_engine_specific_rendering.py @@ -4,11 +4,13 @@ import pytest from _pytest.fixtures import FixtureRequest +from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow_semantics.sql.sql_table import SqlTable from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration from metricflow.protocols.sql_client import SqlClient from metricflow.sql.sql_exprs import ( + SqlAddTimeExpression, SqlCastToTimestampExpression, SqlColumnReference, SqlColumnReferenceExpression, @@ -16,6 +18,7 @@ SqlPercentileExpression, SqlPercentileExpressionArgument, SqlPercentileFunctionType, + SqlStringExpression, SqlStringLiteralExpression, ) from metricflow.sql.sql_plan import ( @@ -295,3 +298,42 @@ def test_approximate_discrete_percentile_expr( plan_id="plan0", sql_client=sql_client, ) + + +@pytest.mark.sql_engine_snapshot +def test_add_time_expr( + request: FixtureRequest, + mf_test_configuration: MetricFlowTestConfiguration, + sql_client: SqlClient, +) -> None: + """Tests rendering of the SqlAddTimeExpr in a query.""" + select_columns = [ + SqlSelectColumn( + expr=SqlAddTimeExpression.create( + arg=SqlStringLiteralExpression.create( + "2020-01-01", + ), + count_expr=SqlStringExpression.create( + "1", + ), + granularity=TimeGranularity.QUARTER, + ), + column_alias="add_time", + ), + ] + + from_source = SqlTableNode.create(sql_table=SqlTable(schema_name="foo", table_name="bar")) + from_source_alias = "a" + + assert_rendered_sql_equal( + request=request, + mf_test_configuration=mf_test_configuration, + sql_plan_node=SqlSelectStatementNode.create( + description="Test Add Time Expression", + select_columns=tuple(select_columns), + from_source=from_source, + from_source_alias=from_source_alias, + ), + plan_id="plan0", + sql_client=sql_client, + )