diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index 21c381b1fc..1f314eaa78 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -164,8 +164,8 @@ def _next_unique_table_alias(self) -> str: def _make_time_spine_data_set( self, - metric_time_dimension_instance: TimeDimensionInstance, - metric_time_dimension_column_name: str, + time_dimension_instance: TimeDimensionInstance, + time_dimension_column_name: str, time_spine_source: TimeSpineSource, time_range_constraint: Optional[TimeRangeConstraint] = None, ) -> SqlDataSet: @@ -176,14 +176,14 @@ def _make_time_spine_data_set( """ time_spine_instance = ( TimeDimensionInstance( - defined_from=metric_time_dimension_instance.defined_from, + defined_from=time_dimension_instance.defined_from, associated_columns=( ColumnAssociation( - column_name=metric_time_dimension_column_name, + column_name=time_dimension_column_name, single_column_correlation_key=SingleColumnCorrelationKey(), ), ), - spec=metric_time_dimension_instance.spec, + spec=time_dimension_instance.spec, ), ) time_spine_instance_set = InstanceSet( @@ -193,7 +193,7 @@ def _make_time_spine_data_set( time_spine_table_alias = self._next_unique_table_alias() # If the requested granularity is the same as the granularity of the spine, do a direct select. - if metric_time_dimension_instance.spec.time_granularity == time_spine_source.time_column_granularity: + if time_dimension_instance.spec.time_granularity == time_spine_source.time_column_granularity: return SqlDataSet( instance_set=time_spine_instance_set, sql_select_node=SqlSelectStatementNode( @@ -207,7 +207,7 @@ def _make_time_spine_data_set( column_name=time_spine_source.time_column_name, ), ), - column_alias=metric_time_dimension_column_name, + column_alias=time_dimension_column_name, ), ), from_source=SqlTableFromClauseNode(sql_table=time_spine_source.spine_table), @@ -229,7 +229,7 @@ def _make_time_spine_data_set( select_columns = ( SqlSelectColumn( expr=SqlDateTruncExpression( - time_granularity=metric_time_dimension_instance.spec.time_granularity, + time_granularity=time_dimension_instance.spec.time_granularity, arg=SqlColumnReferenceExpression( SqlColumnReference( table_alias=time_spine_table_alias, @@ -237,7 +237,7 @@ def _make_time_spine_data_set( ), ), ), - column_alias=metric_time_dimension_column_name, + column_alias=time_dimension_column_name, ), ) return SqlDataSet( @@ -294,8 +294,8 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDat # Granularity of time_spine column should match granularity of metric_time column from parent dataset. assert metric_time_dimension_instance time_spine_data_set = self._make_time_spine_data_set( - metric_time_dimension_instance=metric_time_dimension_instance, - metric_time_dimension_column_name=metric_time_dimension_column_name, + time_dimension_instance=metric_time_dimension_instance, + time_dimension_column_name=metric_time_dimension_column_name, time_spine_source=self._time_spine_source, time_range_constraint=node.time_range_constraint, ) @@ -1312,25 +1312,40 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet parent_data_set = node.parent_node.accept(self) parent_alias = self._next_unique_table_alias() - # Build time spine dataset + # Validate time dimensions & choose the one to join time spine onto. + time_dimension_instance_to_join: Optional[TimeDimensionInstance] = None metric_time_dimension_instance: Optional[TimeDimensionInstance] = None + # If metric_time is requested, choose the one with the lowest granularity. for instance in parent_data_set.metric_time_dimension_instances: - if len(instance.spec.entity_links) == 0: - # Use the instance with the lowest granularity - if not metric_time_dimension_instance or ( - instance.spec.time_granularity < metric_time_dimension_instance.spec.time_granularity + if not metric_time_dimension_instance or ( + instance.spec.time_granularity < metric_time_dimension_instance.spec.time_granularity + ): + metric_time_dimension_instance = instance + if node.offset_window or node.offset_to_grain: + assert ( + metric_time_dimension_instance + ), "Can't query offset metric without metric time. Validations should have prevented this." + + # If there were no metric_time dimensions requested, choose the time dimension with the lowest granularity. + time_dimension_instance_to_join = metric_time_dimension_instance + if not time_dimension_instance_to_join: + for instance in parent_data_set.instance_set.time_dimension_instances: + if not time_dimension_instance_to_join or ( + instance.spec.time_granularity < time_dimension_instance_to_join.spec.time_granularity ): - metric_time_dimension_instance = instance + time_dimension_instance_to_join = instance assert ( - metric_time_dimension_instance - ), "Can't query offset metric without a time dimension. Validations should have prevented this." - metric_time_dimension_column_name = self.column_association_resolver.resolve_spec( - metric_time_dimension_instance.spec + time_dimension_instance_to_join # TODO: update validations to prevent this lol + ), "Can't join to time spine without a time dimension. Validations should have prevented this." + + # Build time spine dataset + time_dimension_column_name = self.column_association_resolver.resolve_spec( + time_dimension_instance_to_join.spec ).column_name time_spine_alias = self._next_unique_table_alias() time_spine_dataset = self._make_time_spine_data_set( - metric_time_dimension_instance=metric_time_dimension_instance, - metric_time_dimension_column_name=metric_time_dimension_column_name, + time_dimension_instance=time_dimension_instance_to_join, + time_dimension_column_name=time_dimension_column_name, time_spine_source=self._time_spine_source, time_range_constraint=node.time_range_constraint, ) @@ -1339,19 +1354,20 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet join_description = SqlQueryPlanJoinBuilder.make_join_to_time_spine_join_description( node=node, time_spine_alias=time_spine_alias, - metric_time_dimension_column_name=metric_time_dimension_column_name, + time_dimension_column_name=time_dimension_column_name, parent_sql_select_node=parent_data_set.sql_select_node, parent_alias=parent_alias, join_type=node.join_type, ) - # Use all instances EXCEPT metric_time from parent data set. - non_metric_time_parent_instance_set = InstanceSet( + # Use all instances EXCEPT joined time dimension from parent data set. + parent_instance_set_to_keep = InstanceSet( measure_instances=parent_data_set.instance_set.measure_instances, dimension_instances=parent_data_set.instance_set.dimension_instances, time_dimension_instances=tuple( time_dimension_instance for time_dimension_instance in parent_data_set.instance_set.time_dimension_instances + # TODO: replace logic below - remove the one selected. maybe do that logic above if time_dimension_instance.spec.element_name != DataSet.metric_time_dimension_reference().element_name ), entity_instances=parent_data_set.instance_set.entity_instances, @@ -1359,7 +1375,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet metadata_instances=parent_data_set.instance_set.metadata_instances, ) non_metric_time_select_columns = create_select_columns_for_instance_sets( - self._column_association_resolver, OrderedDict({parent_alias: non_metric_time_parent_instance_set}) + self._column_association_resolver, OrderedDict({parent_alias: parent_instance_set_to_keep}) ) # Use metric_time column from time spine. @@ -1417,7 +1433,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet metric_time_instance_set = InstanceSet(time_dimension_instances=tuple(metric_time_dimension_instances)) return SqlDataSet( - instance_set=InstanceSet.merge([metric_time_instance_set, non_metric_time_parent_instance_set]), + instance_set=InstanceSet.merge([metric_time_instance_set, parent_instance_set_to_keep]), sql_select_node=SqlSelectStatementNode( description=node.description, select_columns=tuple(metric_time_select_columns) + non_metric_time_select_columns, diff --git a/metricflow/plan_conversion/sql_join_builder.py b/metricflow/plan_conversion/sql_join_builder.py index 26b6f5223f..ee8c5ae00d 100644 --- a/metricflow/plan_conversion/sql_join_builder.py +++ b/metricflow/plan_conversion/sql_join_builder.py @@ -472,14 +472,14 @@ def make_cumulative_metric_time_range_join_description( def make_join_to_time_spine_join_description( node: JoinToTimeSpineNode, time_spine_alias: str, - metric_time_dimension_column_name: str, + time_dimension_column_name: str, parent_sql_select_node: SqlSelectStatementNode, parent_alias: str, join_type: SqlJoinType, ) -> SqlJoinDescription: """Build join expression used to join a metric to a time spine dataset.""" left_expr: SqlExpressionNode = SqlColumnReferenceExpression( - col_ref=SqlColumnReference(table_alias=time_spine_alias, column_name=metric_time_dimension_column_name) + col_ref=SqlColumnReference(table_alias=time_spine_alias, column_name=time_dimension_column_name) ) if node.offset_window: left_expr = SqlSubtractTimeIntervalExpression( @@ -495,7 +495,7 @@ def make_join_to_time_spine_join_description( left_expr=left_expr, comparison=SqlComparison.EQUALS, right_expr=SqlColumnReferenceExpression( - col_ref=SqlColumnReference(table_alias=parent_alias, column_name=metric_time_dimension_column_name) + col_ref=SqlColumnReference(table_alias=parent_alias, column_name=time_dimension_column_name) ), ), join_type=join_type,