Skip to content

Commit

Permalink
Update JoinToTimeSpineNode to allow non-metric time dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Oct 31, 2023
1 parent 7d3e57b commit f2d9980
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 31 deletions.
72 changes: 44 additions & 28 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ def _next_unique_table_alias(self) -> str:

def _make_time_spine_data_set(
self,
metric_time_dimension_instance: TimeDimensionInstance,
metric_time_dimension_column_name: str,
time_dimension_instance: TimeDimensionInstance,
time_dimension_column_name: str,
time_spine_source: TimeSpineSource,
time_range_constraint: Optional[TimeRangeConstraint] = None,
) -> SqlDataSet:
Expand All @@ -176,14 +176,14 @@ def _make_time_spine_data_set(
"""
time_spine_instance = (
TimeDimensionInstance(
defined_from=metric_time_dimension_instance.defined_from,
defined_from=time_dimension_instance.defined_from,
associated_columns=(
ColumnAssociation(
column_name=metric_time_dimension_column_name,
column_name=time_dimension_column_name,
single_column_correlation_key=SingleColumnCorrelationKey(),
),
),
spec=metric_time_dimension_instance.spec,
spec=time_dimension_instance.spec,
),
)
time_spine_instance_set = InstanceSet(
Expand All @@ -193,7 +193,7 @@ def _make_time_spine_data_set(
time_spine_table_alias = self._next_unique_table_alias()

# If the requested granularity is the same as the granularity of the spine, do a direct select.
if metric_time_dimension_instance.spec.time_granularity == time_spine_source.time_column_granularity:
if time_dimension_instance.spec.time_granularity == time_spine_source.time_column_granularity:
return SqlDataSet(
instance_set=time_spine_instance_set,
sql_select_node=SqlSelectStatementNode(
Expand All @@ -207,7 +207,7 @@ def _make_time_spine_data_set(
column_name=time_spine_source.time_column_name,
),
),
column_alias=metric_time_dimension_column_name,
column_alias=time_dimension_column_name,
),
),
from_source=SqlTableFromClauseNode(sql_table=time_spine_source.spine_table),
Expand All @@ -229,15 +229,15 @@ def _make_time_spine_data_set(
select_columns = (
SqlSelectColumn(
expr=SqlDateTruncExpression(
time_granularity=metric_time_dimension_instance.spec.time_granularity,
time_granularity=time_dimension_instance.spec.time_granularity,
arg=SqlColumnReferenceExpression(
SqlColumnReference(
table_alias=time_spine_table_alias,
column_name=time_spine_source.time_column_name,
),
),
),
column_alias=metric_time_dimension_column_name,
column_alias=time_dimension_column_name,
),
)
return SqlDataSet(
Expand Down Expand Up @@ -294,8 +294,8 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDat
# Granularity of time_spine column should match granularity of metric_time column from parent dataset.
assert metric_time_dimension_instance
time_spine_data_set = self._make_time_spine_data_set(
metric_time_dimension_instance=metric_time_dimension_instance,
metric_time_dimension_column_name=metric_time_dimension_column_name,
time_dimension_instance=metric_time_dimension_instance,
time_dimension_column_name=metric_time_dimension_column_name,
time_spine_source=self._time_spine_source,
time_range_constraint=node.time_range_constraint,
)
Expand Down Expand Up @@ -1312,25 +1312,40 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
parent_data_set = node.parent_node.accept(self)
parent_alias = self._next_unique_table_alias()

# Build time spine dataset
# Validate time dimensions & choose the one to join time spine onto.
time_dimension_instance_to_join: Optional[TimeDimensionInstance] = None
metric_time_dimension_instance: Optional[TimeDimensionInstance] = None
# If metric_time is requested, choose the one with the lowest granularity.
for instance in parent_data_set.metric_time_dimension_instances:
if len(instance.spec.entity_links) == 0:
# Use the instance with the lowest granularity
if not metric_time_dimension_instance or (
instance.spec.time_granularity < metric_time_dimension_instance.spec.time_granularity
if not metric_time_dimension_instance or (
instance.spec.time_granularity < metric_time_dimension_instance.spec.time_granularity
):
metric_time_dimension_instance = instance
if node.offset_window or node.offset_to_grain:
assert (
metric_time_dimension_instance
), "Can't query offset metric without metric time. Validations should have prevented this."

# If there were no metric_time dimensions requested, choose the time dimension with the lowest granularity.
time_dimension_instance_to_join = metric_time_dimension_instance
if not time_dimension_instance_to_join:
for instance in parent_data_set.instance_set.time_dimension_instances:
if not time_dimension_instance_to_join or (
instance.spec.time_granularity < time_dimension_instance_to_join.spec.time_granularity
):
metric_time_dimension_instance = instance
time_dimension_instance_to_join = instance
assert (
metric_time_dimension_instance
), "Can't query offset metric without a time dimension. Validations should have prevented this."
metric_time_dimension_column_name = self.column_association_resolver.resolve_spec(
metric_time_dimension_instance.spec
time_dimension_instance_to_join # TODO: update validations to prevent this lol
), "Can't join to time spine without a time dimension. Validations should have prevented this."

# Build time spine dataset
time_dimension_column_name = self.column_association_resolver.resolve_spec(
time_dimension_instance_to_join.spec
).column_name
time_spine_alias = self._next_unique_table_alias()
time_spine_dataset = self._make_time_spine_data_set(
metric_time_dimension_instance=metric_time_dimension_instance,
metric_time_dimension_column_name=metric_time_dimension_column_name,
time_dimension_instance=time_dimension_instance_to_join,
time_dimension_column_name=time_dimension_column_name,
time_spine_source=self._time_spine_source,
time_range_constraint=node.time_range_constraint,
)
Expand All @@ -1339,27 +1354,28 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
join_description = SqlQueryPlanJoinBuilder.make_join_to_time_spine_join_description(
node=node,
time_spine_alias=time_spine_alias,
metric_time_dimension_column_name=metric_time_dimension_column_name,
time_dimension_column_name=time_dimension_column_name,
parent_sql_select_node=parent_data_set.sql_select_node,
parent_alias=parent_alias,
join_type=node.join_type,
)

# Use all instances EXCEPT metric_time from parent data set.
non_metric_time_parent_instance_set = InstanceSet(
# Use all instances EXCEPT joined time dimension from parent data set.
parent_instance_set_to_keep = InstanceSet(
measure_instances=parent_data_set.instance_set.measure_instances,
dimension_instances=parent_data_set.instance_set.dimension_instances,
time_dimension_instances=tuple(
time_dimension_instance
for time_dimension_instance in parent_data_set.instance_set.time_dimension_instances
# TODO: replace logic below - remove the one selected. maybe do that logic above
if time_dimension_instance.spec.element_name != DataSet.metric_time_dimension_reference().element_name
),
entity_instances=parent_data_set.instance_set.entity_instances,
metric_instances=parent_data_set.instance_set.metric_instances,
metadata_instances=parent_data_set.instance_set.metadata_instances,
)
non_metric_time_select_columns = create_select_columns_for_instance_sets(
self._column_association_resolver, OrderedDict({parent_alias: non_metric_time_parent_instance_set})
self._column_association_resolver, OrderedDict({parent_alias: parent_instance_set_to_keep})
)

# Use metric_time column from time spine.
Expand Down Expand Up @@ -1417,7 +1433,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
metric_time_instance_set = InstanceSet(time_dimension_instances=tuple(metric_time_dimension_instances))

return SqlDataSet(
instance_set=InstanceSet.merge([metric_time_instance_set, non_metric_time_parent_instance_set]),
instance_set=InstanceSet.merge([metric_time_instance_set, parent_instance_set_to_keep]),
sql_select_node=SqlSelectStatementNode(
description=node.description,
select_columns=tuple(metric_time_select_columns) + non_metric_time_select_columns,
Expand Down
6 changes: 3 additions & 3 deletions metricflow/plan_conversion/sql_join_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,14 +472,14 @@ def make_cumulative_metric_time_range_join_description(
def make_join_to_time_spine_join_description(
node: JoinToTimeSpineNode,
time_spine_alias: str,
metric_time_dimension_column_name: str,
time_dimension_column_name: str,
parent_sql_select_node: SqlSelectStatementNode,
parent_alias: str,
join_type: SqlJoinType,
) -> SqlJoinDescription:
"""Build join expression used to join a metric to a time spine dataset."""
left_expr: SqlExpressionNode = SqlColumnReferenceExpression(
col_ref=SqlColumnReference(table_alias=time_spine_alias, column_name=metric_time_dimension_column_name)
col_ref=SqlColumnReference(table_alias=time_spine_alias, column_name=time_dimension_column_name)
)
if node.offset_window:
left_expr = SqlSubtractTimeIntervalExpression(
Expand All @@ -495,7 +495,7 @@ def make_join_to_time_spine_join_description(
left_expr=left_expr,
comparison=SqlComparison.EQUALS,
right_expr=SqlColumnReferenceExpression(
col_ref=SqlColumnReference(table_alias=parent_alias, column_name=metric_time_dimension_column_name)
col_ref=SqlColumnReference(table_alias=parent_alias, column_name=time_dimension_column_name)
),
),
join_type=join_type,
Expand Down

0 comments on commit f2d9980

Please sign in to comment.