Skip to content

Commit

Permalink
Dedupe time spine column-building logic & simplify code
Browse files Browse the repository at this point in the history
The biggest change in this commit is to remove all logic related to building select columns from visit_join_to_time_spine_node(). This logic was duplicated in _make_time_spine_data_set().
That logic is now consolidated to _make_time_spine_data_set(). This is laying the foundation for a change coming up the stack, which will replace _make_time_spine_data_set() with its own
node in the DataflowPlan. There is also some other cleanup in this commit to make visit_join_to_time_spine_node() simpler and more readable.
  • Loading branch information
courtneyholcomb committed Dec 9, 2024
1 parent ac16bbe commit 5a58307
Showing 1 changed file with 67 additions and 120 deletions.
187 changes: 67 additions & 120 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import datetime as dt
import logging
from collections import OrderedDict
from typing import Callable, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, TypeVar, Union
from typing import Callable, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, TypeVar

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.protocols.metric import MetricInputMeasure, MetricType
Expand Down Expand Up @@ -39,6 +39,7 @@
from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec
from metricflow_semantics.sql.sql_join_type import SqlJoinType
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.time.granularity import ExpandedTimeGranularity
from metricflow_semantics.time.time_constants import ISO8601_PYTHON_FORMAT, ISO8601_PYTHON_TS_FORMAT
from metricflow_semantics.time.time_spine_source import TIME_SPINE_DATA_SET_DESCRIPTION, TimeSpineSource
from typing_extensions import override
Expand Down Expand Up @@ -324,54 +325,54 @@ def _make_time_spine_data_set(
"""
time_spine_table_alias = self._next_unique_table_alias()

agg_time_dimension_specs = {instance.spec for instance in agg_time_dimension_instances}
queried_specs = {instance.spec for instance in agg_time_dimension_instances}
specs_required_for_where_constraints = {
spec
for constraint in time_spine_where_constraints
for spec in constraint.linkable_spec_set.time_dimension_specs
}
required_time_spine_specs = sorted( # sorted for consistency in snapshots
agg_time_dimension_specs.union(specs_required_for_where_constraints),
required_specs = sorted( # sorted for consistency in snapshots
queried_specs.union(specs_required_for_where_constraints),
key=lambda spec: (spec.element_name, spec.time_granularity.base_granularity.to_int()),
)
time_spine_sources = TimeSpineSource.choose_time_spine_sources(
required_time_spine_specs=list(required_time_spine_specs), time_spine_sources=self._time_spine_sources
required_time_spine_specs=required_specs, time_spine_sources=self._time_spine_sources
)
# TODO: handle multiple time spine joins
assert len(time_spine_sources) == 1, (
"Join to time spine with custom granularity currently only supports one custom granularity per query. "
"Full feature coming soon."
)
time_spine_source = time_spine_sources[0]
time_spine_base_granularity = ExpandedTimeGranularity.from_time_granularity(time_spine_source.base_granularity)

base_column_expr = SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_table_alias, column_name=time_spine_source.base_column
)
select_columns: Tuple[SqlSelectColumn, ...] = ()
apply_group_by = True
for agg_time_dimension_spec in required_time_spine_specs:
for agg_time_dimension_spec in required_specs:
column_alias = self._column_association_resolver.resolve_spec(agg_time_dimension_spec).column_name
# If the requested granularity is the same as the granularity of the spine, do a direct select.
agg_time_grain = agg_time_dimension_spec.time_granularity
if (
agg_time_grain.base_granularity == time_spine_source.base_granularity
and not agg_time_grain.is_custom_granularity
):
expr: SqlExpressionNode = base_column_expr
apply_group_by = False
# If there is a date_part selected, apply an EXTRACT() to the base column.
if agg_time_dimension_spec.date_part:
expr: SqlExpressionNode = SqlExtractExpression.create(
date_part=agg_time_dimension_spec.date_part, arg=base_column_expr
)
# If the requested granularity is the same as the granularity of the spine, do a direct select.
elif agg_time_grain == time_spine_base_granularity:
expr = base_column_expr
# If the granularity is custom, select the appropriate custom granularity column.
elif agg_time_grain.is_custom_granularity:
# If any dimensions require a custom granularity, select the appropriate column.
for custom_granularity in time_spine_source.custom_granularities:
expr = SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_table_alias, column_name=custom_granularity.parsed_column_name
)
# Otherwise, apply the requested standard granularity using a DATE_TRUNC() on the base column.
else:
# If any dimensions require a different standard granularity, apply a DATE_TRUNC() to the base column.
expr = SqlDateTruncExpression.create(
time_granularity=agg_time_grain.base_granularity, arg=base_column_expr
)
select_columns += (SqlSelectColumn(expr=expr, column_alias=column_alias),)
# TODO: also handle date part.

output_instance_set = InstanceSet(
time_dimension_instances=tuple(
Expand All @@ -385,16 +386,25 @@ def _make_time_spine_data_set(
associated_columns=(self._column_association_resolver.resolve_spec(spec),),
spec=spec,
)
for spec in required_time_spine_specs
for spec in queried_specs
]
)
)

# A group by will be needed to ensure unique rows unless the time spine base grain is included.
apply_group_by_in_inner_select_node = all(
spec.time_granularity != time_spine_base_granularity for spec in required_specs
)
apply_group_by_in_outer_select_node = apply_group_by_in_inner_select_node is False and all(
spec.time_granularity != time_spine_base_granularity for spec in queried_specs
)

inner_sql_select_node = SqlSelectStatementNode.create(
description=TIME_SPINE_DATA_SET_DESCRIPTION,
select_columns=select_columns,
from_source=SqlTableNode.create(sql_table=time_spine_source.spine_table),
from_source_alias=time_spine_table_alias,
group_bys=select_columns if apply_group_by else (),
group_bys=select_columns if apply_group_by_in_inner_select_node else (),
where=(
DataflowNodeToSqlSubqueryVisitor._make_time_range_comparison_expr(
table_alias=time_spine_table_alias,
Expand Down Expand Up @@ -429,17 +439,6 @@ def _make_time_spine_data_set(
column_resolver=self._column_association_resolver,
table_alias_to_instance_set=OrderedDict({inner_query_alias: outer_query_output_instance_set}),
)
# After removing unneeded instances from the inner query, a group by will be needed to ensure unique rows if the
# smallest grain from the inner query was removed.
smallest_grain_in_inner_query = sorted(
output_instance_set.time_dimension_instances,
key=lambda instance: instance.spec.time_granularity.base_granularity.to_int(),
)[0].spec.time_granularity.base_granularity
apply_group_by = True
for instance in outer_query_output_instance_set.time_dimension_instances:
if instance.spec.time_granularity.base_granularity == smallest_grain_in_inner_query:
apply_group_by = False
break

return SqlDataSet(
instance_set=output_instance_set,
Expand All @@ -449,7 +448,7 @@ def _make_time_spine_data_set(
from_source=inner_sql_select_node,
from_source_alias=inner_query_alias,
where=complete_outer_where_filter,
group_bys=outer_query_select_columns if apply_group_by else (),
group_bys=outer_query_select_columns if apply_group_by_in_outer_select_node else (),
),
)

Expand Down Expand Up @@ -1393,7 +1392,6 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
join_on_time_dimension_sample = included_metric_time_instances[0].spec
else:
join_on_time_dimension_sample = agg_time_dimension_instances[0].spec

agg_time_dimension_instance_for_join = self._choose_instance_for_time_spine_join(
[
instance
Expand All @@ -1402,11 +1400,13 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
and instance.spec.entity_links == join_on_time_dimension_sample.entity_links
]
)
if agg_time_dimension_instance_for_join not in agg_time_dimension_instances:
agg_time_dimension_instances = (agg_time_dimension_instance_for_join,) + agg_time_dimension_instances

# Build time spine data set with just the agg_time_dimension instance needed for the join.
time_spine_alias = self._next_unique_table_alias()
time_spine_dataset = self._make_time_spine_data_set(
agg_time_dimension_instances=(agg_time_dimension_instance_for_join,),
agg_time_dimension_instances=agg_time_dimension_instances,
time_range_constraint=node.time_range_constraint,
time_spine_where_constraints=node.time_spine_filters or (),
)
Expand All @@ -1422,105 +1422,52 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
parent_alias=parent_alias,
)

# Select all instances from the parent data set, EXCEPT agg_time_dimensions.
# The agg_time_dimensions will be selected from the time spine data set.
time_dimensions_to_select_from_parent: Tuple[TimeDimensionInstance, ...] = ()
time_dimensions_to_select_from_time_spine: Tuple[TimeDimensionInstance, ...] = ()
for time_dimension_instance in parent_data_set.instance_set.time_dimension_instances:
if time_dimension_instance in agg_time_dimension_instances:
time_dimensions_to_select_from_time_spine += (time_dimension_instance,)
else:
time_dimensions_to_select_from_parent += (time_dimension_instance,)
parent_instance_set = InstanceSet(
measure_instances=parent_data_set.instance_set.measure_instances,
dimension_instances=parent_data_set.instance_set.dimension_instances,
time_dimension_instances=time_dimensions_to_select_from_parent,
entity_instances=parent_data_set.instance_set.entity_instances,
metric_instances=parent_data_set.instance_set.metric_instances,
metadata_instances=parent_data_set.instance_set.metadata_instances,
)
parent_select_columns = create_simple_select_columns_for_instance_sets(
self._column_association_resolver, OrderedDict({parent_alias: parent_instance_set})
)

original_time_spine_dim_instance = time_spine_dataset.instance_for_time_dimension(
agg_time_dimension_instance_for_join.spec
)
time_spine_column_select_expr: Union[
SqlColumnReferenceExpression, SqlDateTruncExpression
] = SqlColumnReferenceExpression.create(
SqlColumnReference(
table_alias=time_spine_alias, column_name=original_time_spine_dim_instance.spec.qualified_name
)
)
# Remove time spine instances from parent instance set.
time_spine_instances = time_spine_dataset.instance_set
time_spine_specs = time_spine_instances.spec_set
parent_instance_set = parent_data_set.instance_set.transform(FilterElements(exclude_specs=time_spine_specs))

time_spine_select_columns = []
time_spine_dim_instances = []
where_filter: Optional[SqlExpressionNode] = None
# Build select columns
select_columns = create_simple_select_columns_for_instance_sets(
self._column_association_resolver,
OrderedDict({parent_alias: parent_instance_set, time_spine_alias: time_spine_dataset.instance_set}),
)

# If offset_to_grain is used, will need to filter down to rows that match selected granularities.
# Does not apply if one of the granularities selected matches the time spine column granularity.
where_filter: Optional[SqlExpressionNode] = None
need_where_filter = (
node.offset_to_grain
and original_time_spine_dim_instance.spec not in node.requested_agg_time_dimension_specs
and agg_time_dimension_instance_for_join.spec not in node.requested_agg_time_dimension_specs
)

# Add requested granularities (if different from time_spine) and date_parts to time spine column.
for parent_time_dimension_instance in time_dimensions_to_select_from_time_spine:
time_dimension_spec = parent_time_dimension_instance.spec
if (
time_dimension_spec.time_granularity.base_granularity.to_int()
< original_time_spine_dim_instance.spec.time_granularity.base_granularity.to_int()
):
raise RuntimeError(
f"Can't join to time spine for a time dimension with a smaller granularity than that of the time "
f"spine column. Got {time_dimension_spec.time_granularity} for time dimension, "
f"{original_time_spine_dim_instance.spec.time_granularity} for time spine."
)

# Apply grain to time spine select expression, unless grain already matches original time spine column.
should_skip_date_trunc = (
time_dimension_spec.time_granularity == original_time_spine_dim_instance.spec.time_granularity
or time_dimension_spec.time_granularity.is_custom_granularity
)
select_expr: SqlExpressionNode = (
time_spine_column_select_expr
if should_skip_date_trunc
else SqlDateTruncExpression.create(
time_granularity=time_dimension_spec.time_granularity.base_granularity,
arg=time_spine_column_select_expr,
)
)
# Filter down to one row per granularity period requested in the group by. Any other granularities
# included here will be filtered out in later nodes so should not be included in where filter.
if need_where_filter and time_dimension_spec in node.requested_agg_time_dimension_specs:
new_where_filter = SqlComparisonExpression.create(
left_expr=select_expr, comparison=SqlComparison.EQUALS, right_expr=time_spine_column_select_expr
)
where_filter = (
SqlLogicalExpression.create(operator=SqlLogicalOperator.OR, args=(where_filter, new_where_filter))
if where_filter
else new_where_filter
)

# Apply date_part to time spine column select expression.
if time_dimension_spec.date_part:
select_expr = SqlExtractExpression.create(date_part=time_dimension_spec.date_part, arg=select_expr)

time_spine_dim_instance = parent_time_dimension_instance.with_new_defined_from(
original_time_spine_dim_instance.defined_from
if need_where_filter:
join_column_expr = SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_alias,
column_name=agg_time_dimension_instance_for_join.associated_column.column_name,
)
time_spine_dim_instances.append(time_spine_dim_instance)
time_spine_select_columns.append(
SqlSelectColumn(expr=select_expr, column_alias=time_spine_dim_instance.associated_column.column_name)
)
time_spine_instance_set = InstanceSet(time_dimension_instances=tuple(time_spine_dim_instances))
for time_spine_instance in time_spine_instances.as_tuple:
# Filter down to one row per granularity period requested in the group by. Any other granularities
# included here will be filtered out in later nodes so should not be included in where filter.
if need_where_filter and time_spine_instance.spec in node.requested_agg_time_dimension_specs:
column_to_filter_expr = SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_alias, column_name=time_spine_instance.associated_column.column_name
)
new_where_filter = SqlComparisonExpression.create(
left_expr=column_to_filter_expr, comparison=SqlComparison.EQUALS, right_expr=join_column_expr
)
where_filter = (
SqlLogicalExpression.create(
operator=SqlLogicalOperator.OR, args=(where_filter, new_where_filter)
)
if where_filter
else new_where_filter
)

return SqlDataSet(
instance_set=InstanceSet.merge([time_spine_instance_set, parent_instance_set]),
instance_set=InstanceSet.merge([time_spine_dataset.instance_set, parent_instance_set]),
sql_select_node=SqlSelectStatementNode.create(
description=node.description,
select_columns=tuple(time_spine_select_columns) + parent_select_columns,
select_columns=select_columns,
from_source=time_spine_dataset.checked_sql_select_node,
from_source_alias=time_spine_alias,
join_descs=(join_description,),
Expand Down

0 comments on commit 5a58307

Please sign in to comment.