From 0f3cf0b5dcdcc2eb2ac75517fee09855f0ea7fc9 Mon Sep 17 00:00:00 2001
From: Courtney Holcomb
Date: Mon, 11 Sep 2023 19:07:40 -0700
Subject: [PATCH] Handle time offset with date_part

---
 metricflow/plan_conversion/dataflow_to_sql.py  | 46 +++++++++++--------
 metricflow/query/query_parser.py               |  1 +
 .../integration/test_cases/itest_metrics.yaml  | 30 ++++++++++--
 .../test/integration/test_configured_cases.py  | 19 ++++++++
 4 files changed, 73 insertions(+), 23 deletions(-)

diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py
index 6dc0db4c23..84b2e0d261 100644
--- a/metricflow/plan_conversion/dataflow_to_sql.py
+++ b/metricflow/plan_conversion/dataflow_to_sql.py
@@ -89,6 +89,7 @@
     SqlComparisonExpression,
     SqlDateTruncExpression,
     SqlExpressionNode,
+    SqlExtractExpression,
     SqlFunctionExpression,
     SqlLogicalExpression,
     SqlLogicalOperator,
@@ -292,7 +293,8 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode[SqlDataSet
             metric_time_dimension_spec
         ).column_name
 
-        # assemble dataset with metric_time_dimension to join
+        # Assemble time_spine dataset with metric_time_dimension to join.
+        # Granularity of time_spine column should match granularity of metric_time column from parent dataset.
         assert metric_time_dimension_instance
         time_spine_data_set = self._make_time_spine_data_set(
             metric_time_dimension_instance=metric_time_dimension_instance,
@@ -1354,11 +1356,11 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode[SourceDataSetT
             len(time_spine_dataset.instance_set.time_dimension_instances) == 1
             and len(time_spine_dataset.sql_select_node.select_columns) == 1
         ), "Time spine dataset not configured properly. Expected exactly one column."
-        original_time_dim_instance = time_spine_dataset.instance_set.time_dimension_instances[0]
+        time_dim_instance = time_spine_dataset.instance_set.time_dimension_instances[0]
         time_spine_column_select_expr: Union[
             SqlColumnReferenceExpression, SqlDateTruncExpression
         ] = SqlColumnReferenceExpression(
-            SqlColumnReference(table_alias=time_spine_alias, column_name=original_time_dim_instance.spec.qualified_name)
+            SqlColumnReference(table_alias=time_spine_alias, column_name=time_dim_instance.spec.qualified_name)
         )
 
         # Add requested granularities (skip for default granularity) and date_parts.
@@ -1366,28 +1368,15 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode[SourceDataSetT
         metric_time_dimension_instances = []
         where: Optional[SqlExpressionNode] = None
         for metric_time_dimension_spec in node.metric_time_dimension_specs:
+            # Apply granularity to SQL.
             if metric_time_dimension_spec.time_granularity == self._time_spine_source.time_column_granularity:
-                select_expr = time_spine_column_select_expr
-                time_dim_instance = original_time_dim_instance
-                column_alias = original_time_dim_instance.associated_column.column_name
+                select_expr: SqlExpressionNode = time_spine_column_select_expr
             else:
                 select_expr = SqlDateTruncExpression(
                     time_granularity=metric_time_dimension_spec.time_granularity, arg=time_spine_column_select_expr
                 )
-                new_time_dim_spec = TimeDimensionSpec(
-                    element_name=original_time_dim_instance.spec.element_name,
-                    entity_links=original_time_dim_instance.spec.entity_links,
-                    time_granularity=metric_time_dimension_spec.time_granularity,
-                    date_part=metric_time_dimension_spec.date_part,
-                    aggregation_state=original_time_dim_instance.spec.aggregation_state,
-                )
-                time_dim_instance = TimeDimensionInstance(
-                    defined_from=original_time_dim_instance.defined_from,
-                    associated_columns=(self._column_association_resolver.resolve_spec(new_time_dim_spec),),
-                    spec=new_time_dim_spec,
-                )
-                column_alias = time_dim_instance.associated_column.column_name
             if node.offset_to_grain:
+                # TODO: allow offset_to_grain with granularity & date_part? What's the expected behavior?
                 # Filter down to one row per granularity period
                 new_filter = SqlComparisonExpression(
                     left_expr=select_expr, comparison=SqlComparison.EQUALS, right_expr=time_spine_column_select_expr
@@ -1396,8 +1385,25 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode[SourceDataSetT
                     where = new_filter
                 else:
                     where = SqlLogicalExpression(operator=SqlLogicalOperator.OR, args=(where, new_filter))
+            # Apply date_part to SQL.
+            if metric_time_dimension_spec.date_part:
+                select_expr = SqlExtractExpression(date_part=metric_time_dimension_spec.date_part, arg=select_expr)
+            time_dim_spec = TimeDimensionSpec(
+                element_name=time_dim_instance.spec.element_name,
+                entity_links=time_dim_instance.spec.entity_links,
+                time_granularity=metric_time_dimension_spec.time_granularity,
+                date_part=metric_time_dimension_spec.date_part,
+                aggregation_state=time_dim_instance.spec.aggregation_state,
+            )
+            time_dim_instance = TimeDimensionInstance(
+                defined_from=time_dim_instance.defined_from,
+                associated_columns=(self._column_association_resolver.resolve_spec(time_dim_spec),),
+                spec=time_dim_spec,
+            )
             metric_time_dimension_instances.append(time_dim_instance)
-            metric_time_select_columns.append(SqlSelectColumn(expr=select_expr, column_alias=column_alias))
+            metric_time_select_columns.append(
+                SqlSelectColumn(expr=select_expr, column_alias=time_dim_instance.associated_column.column_name)
+            )
         metric_time_instance_set = InstanceSet(time_dimension_instances=tuple(metric_time_dimension_instances))
 
         return SqlDataSet(
diff --git a/metricflow/query/query_parser.py b/metricflow/query/query_parser.py
index 6f437117ac..956f3933fc 100644
--- a/metricflow/query/query_parser.py
+++ b/metricflow/query/query_parser.py
@@ -797,6 +797,7 @@ def _parse_order_by(
                     element_name=parsed_name.element_name,
                     entity_links=entity_links,
                     time_granularity=parsed_name.time_granularity,
+                    date_part=parsed_name.date_part,
                 ),
                 descending=descending,
             )
diff --git a/metricflow/test/integration/test_cases/itest_metrics.yaml b/metricflow/test/integration/test_cases/itest_metrics.yaml
index ced9b011a1..f08254aaa1 100644
--- a/metricflow/test/integration/test_cases/itest_metrics.yaml
+++ b/metricflow/test/integration/test_cases/itest_metrics.yaml
@@ -1053,7 +1053,31 @@ integration_test:
   check_query: |
     SELECT
       SUM(1) AS bookings
-      , EXTRACT(DAYOFWEEK FROM ds) AS metric_time__extract_dayofweek
+      , {{ render_extract("ds", DatePart.DAYOFWEEK) }} AS metric_time__extract_dayofweek
     FROM {{ source_schema }}.fct_bookings
-    GROUP BY EXTRACT(DAYOFWEEK FROM ds);
-# TODO: test with cumulative metric, offset metric, others?
+    GROUP BY {{ render_extract("ds", DatePart.DAYOFWEEK) }};
+---
+integration_test:
+  name: derived_metric_offset_window_and_date_part
+  description: Tests a derived metric offset query with window and date_part
+  model: SIMPLE_MODEL
+  metrics: ["bookings_5_day_lag"]
+  group_bys: ["metric_time__extract_month"]
+  check_query: |
+    SELECT
+      {{ render_extract("a.ds", DatePart.MONTH) }} AS metric_time__extract_month
+      , SUM(b.bookings_5_day_lag) AS bookings_5_day_lag
+    FROM {{ mf_time_spine_source }} a
+    INNER JOIN (
+      SELECT
+        ds AS metric_time__day
+        , 1 AS bookings_5_day_lag
+      FROM {{ source_schema }}.fct_bookings
+    ) b
+    ON {{ render_date_sub("a", "ds", 5, TimeGranularity.DAY) }} = b.metric_time__day
+    GROUP BY metric_time__extract_month
+
+# TODO:
+# test each date part syntax with each engine
+# dataflow plan tests?
+# dataflow to sql tests?
diff --git a/metricflow/test/integration/test_configured_cases.py b/metricflow/test/integration/test_configured_cases.py
index 12b6aac0f3..cde625fde7 100644
--- a/metricflow/test/integration/test_configured_cases.py
+++ b/metricflow/test/integration/test_configured_cases.py
@@ -23,6 +23,7 @@
     SqlColumnReference,
     SqlColumnReferenceExpression,
     SqlDateTruncExpression,
+    SqlExtractExpression,
     SqlPercentileExpression,
     SqlPercentileExpressionArgument,
     SqlPercentileFunctionType,
@@ -39,6 +40,7 @@
 from metricflow.test.time.configurable_time_source import (
     ConfigurableTimeSource,
 )
+from metricflow.time.date_part import DatePart
 
 logger = logging.getLogger(__name__)
 
@@ -100,6 +102,19 @@ def render_date_trunc(self, expr: str, granularity: TimeGranularity) -> str:
         )
         return self._sql_client.sql_query_plan_renderer.expr_renderer.render_sql_expr(renderable_expr).sql
 
+    def render_extract(self, expr: str, date_part: DatePart) -> str:
+        """Return the EXTRACT call that extracts the given date_part from the given expr."""
+        renderable_expr = SqlExtractExpression(
+            date_part=date_part,
+            arg=SqlCastToTimestampExpression(
+                arg=SqlStringExpression(
+                    sql_expr=expr,
+                    requires_parenthesis=False,
+                )
+            ),
+        )
+        return self._sql_client.sql_query_plan_renderer.expr_renderer.render_sql_expr(renderable_expr).sql
+
     def render_percentile_expr(
         self, expr: str, percentile: float, use_discrete_percentile: bool, use_approximate_percentile: bool
     ) -> str:
@@ -252,8 +267,10 @@ def test_case(
         source_schema=mf_test_session_state.mf_source_schema,
         render_time_constraint=check_query_helpers.render_time_constraint,
         TimeGranularity=TimeGranularity,
+        DatePart=DatePart,
         render_date_sub=check_query_helpers.render_date_sub,
         render_date_trunc=check_query_helpers.render_date_trunc,
+        render_extract=check_query_helpers.render_extract,
         render_percentile_expr=check_query_helpers.render_percentile_expr,
         mf_time_spine_source=semantic_manifest_lookup.time_spine_source.spine_table.sql,
         double_data_type_name=check_query_helpers.double_data_type_name,
@@ -277,8 +294,10 @@ def test_case(
         source_schema=mf_test_session_state.mf_source_schema,
         render_time_constraint=check_query_helpers.render_time_constraint,
         TimeGranularity=TimeGranularity,
+        DatePart=DatePart,
         render_date_sub=check_query_helpers.render_date_sub,
         render_date_trunc=check_query_helpers.render_date_trunc,
+        render_extract=check_query_helpers.render_extract,
         render_percentile_expr=check_query_helpers.render_percentile_expr,
         mf_time_spine_source=semantic_manifest_lookup.time_spine_source.spine_table.sql,
         double_data_type_name=check_query_helpers.double_data_type_name,
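
Reviewer note (not part of the applied patch): a rough sketch of the query
shape this change produces when a derived offset metric is grouped by a
date_part, mirroring the new derived_metric_offset_window_and_date_part
check_query above. Table and alias names are illustrative assumptions written
in generic SQL, not output captured from any engine:

    -- Illustrative sketch, not engine-verified output. After this change, the
    -- time spine column is truncated to any non-default granularity first
    -- (DATE_TRUNC) and the requested date_part is then applied on top
    -- (EXTRACT); the offset join itself is unchanged.
    SELECT
      EXTRACT(month FROM a.ds) AS metric_time__extract_month
      , SUM(b.bookings) AS bookings_5_day_lag
    FROM mf_time_spine a
    INNER JOIN (
      SELECT
        ds AS metric_time__day
        , 1 AS bookings
      FROM fct_bookings
    ) b
      ON a.ds - INTERVAL '5 day' = b.metric_time__day
    GROUP BY EXTRACT(month FROM a.ds)

Because the requested granularity here matches the time spine's default (day),
no DATE_TRUNC appears; a non-default granularity would render as
EXTRACT(... FROM DATE_TRUNC(..., a.ds)).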