Skip to content

Commit

Permalink
Join to Time Spine & Fill Nulls (#832)
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb authored Nov 2, 2023
1 parent 0104913 commit 7bbb704
Show file tree
Hide file tree
Showing 126 changed files with 24,388 additions and 58 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231031-155842.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Join to time spine and fill nulls when requested on metric input measures.
time: 2023-10-31T15:58:42.748645-07:00
custom:
Author: courtneyholcomb
Issue: "759"
22 changes: 20 additions & 2 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,10 +820,11 @@ def _build_aggregated_measures_from_measure_source_node(
assert metric_time_dimension_specs, "Joining to time spine requires querying with metric time."
join_to_time_spine_node = JoinToTimeSpineNode(
parent_node=time_range_node or measure_recipe.source_node,
metric_time_dimension_specs=metric_time_dimension_specs,
requested_metric_time_dimension_specs=metric_time_dimension_specs,
time_range_constraint=time_range_constraint,
offset_window=metric_spec.offset_window,
offset_to_grain=metric_spec.offset_to_grain,
join_type=SqlJoinType.INNER,
)

# Only get the required measure and the local linkable instances so that aggregations work correctly.
Expand Down Expand Up @@ -911,7 +912,24 @@ def _build_aggregated_measures_from_measure_source_node(
(InstanceSpecSet(measure_specs=measure_specs), queried_linkable_specs.as_spec_set)
),
)
return AggregateMeasuresNode(
aggregate_measures_node = AggregateMeasuresNode(
parent_node=pre_aggregate_node,
metric_input_measure_specs=tuple(metric_input_measure_specs),
)

join_aggregated_measure_to_time_spine = False
for metric_input_measure in metric_input_measure_specs:
if metric_input_measure.join_to_timespine:
join_aggregated_measure_to_time_spine = True
break

# Only join to time spine if metric time was requested in the query.
if join_aggregated_measure_to_time_spine and metric_time_dimension_requested:
return JoinToTimeSpineNode(
parent_node=aggregate_measures_node,
requested_metric_time_dimension_specs=metric_time_dimension_specs,
time_range_constraint=time_range_constraint,
join_type=SqlJoinType.LEFT_OUTER,
)
else:
return aggregate_measures_node
25 changes: 18 additions & 7 deletions metricflow/dataflow/dataflow_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,8 @@ class JoinToTimeSpineNode(BaseOutput, ABC):
def __init__(
self,
parent_node: BaseOutput,
metric_time_dimension_specs: List[TimeDimensionSpec],
requested_metric_time_dimension_specs: List[TimeDimensionSpec],
join_type: SqlJoinType,
time_range_constraint: Optional[TimeRangeConstraint] = None,
offset_window: Optional[MetricTimeWindow] = None,
offset_to_grain: Optional[TimeGranularity] = None,
Expand All @@ -709,7 +710,7 @@ def __init__(
Args:
parent_node: Node that returns desired dataset to join to time spine.
metric_time_dimension_specs: Metric time dimensions requested in query. Used to determine granularities.
requested_metric_time_dimension_specs: Time dimensions requested in query. Used to determine granularities.
time_range_constraint: Time range to constrain the time spine to.
offset_window: Time window to offset the parent dataset by when joining to time spine.
offset_to_grain: Granularity period to offset the parent dataset to when joining to time spine.
Expand All @@ -720,10 +721,11 @@ def __init__(
offset_window and offset_to_grain
), "Can't set both offset_window and offset_to_grain when joining to time spine. Choose one or the other."
self._parent_node = parent_node
self._metric_time_dimension_specs = metric_time_dimension_specs
self._requested_metric_time_dimension_specs = requested_metric_time_dimension_specs
self._offset_window = offset_window
self._offset_to_grain = offset_to_grain
self._time_range_constraint = time_range_constraint
self._join_type = join_type

super().__init__(node_id=self.create_unique_id(), parent_nodes=[self._parent_node])

Expand All @@ -732,9 +734,9 @@ def id_prefix(cls) -> str: # noqa: D
return DATAFLOW_NODE_JOIN_TO_TIME_SPINE_ID_PREFIX

@property
def metric_time_dimension_specs(self) -> List[TimeDimensionSpec]: # noqa: D
def requested_metric_time_dimension_specs(self) -> List[TimeDimensionSpec]: # noqa: D
"""Time dimension specs to use when creating time spine table."""
return self._metric_time_dimension_specs
return self._requested_metric_time_dimension_specs

@property
def time_range_constraint(self) -> Optional[TimeRangeConstraint]: # noqa: D
Expand All @@ -751,6 +753,11 @@ def offset_to_grain(self) -> Optional[TimeGranularity]: # noqa: D
"""Time range constraint to apply when querying time spine table."""
return self._offset_to_grain

@property
def join_type(self) -> SqlJoinType:  # noqa: D
    """SQL join type used when joining the parent dataset to the time spine (e.g. INNER or LEFT_OUTER)."""
    return self._join_type

def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT:  # noqa: D
    """Visitor-pattern dispatch: route to the visitor's handler for this node type."""
    return visitor.visit_join_to_time_spine_node(self)

Expand All @@ -761,9 +768,11 @@ def description(self) -> str: # noqa: D
@property
def displayed_properties(self) -> List[DisplayedProperty]: # noqa: D
return super().displayed_properties + [
DisplayedProperty("requested_metric_time_dimension_specs", self._requested_metric_time_dimension_specs),
DisplayedProperty("time_range_constraint", self._time_range_constraint),
DisplayedProperty("offset_window", self._offset_window),
DisplayedProperty("offset_to_grain", self._offset_to_grain),
DisplayedProperty("join_type", self._join_type),
]

@property
Expand All @@ -776,17 +785,19 @@ def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa:
and other_node.time_range_constraint == self.time_range_constraint
and other_node.offset_window == self.offset_window
and other_node.offset_to_grain == self.offset_to_grain
and other_node.metric_time_dimension_specs == self.metric_time_dimension_specs
and other_node.requested_metric_time_dimension_specs == self.requested_metric_time_dimension_specs
and other_node.join_type == self.join_type
)

def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> JoinToTimeSpineNode: # noqa: D
assert len(new_parent_nodes) == 1
return JoinToTimeSpineNode(
parent_node=new_parent_nodes[0],
metric_time_dimension_specs=self.metric_time_dimension_specs,
requested_metric_time_dimension_specs=self.requested_metric_time_dimension_specs,
time_range_constraint=self.time_range_constraint,
offset_window=self.offset_window,
offset_to_grain=self.offset_to_grain,
join_type=self.join_type,
)


Expand Down
2 changes: 2 additions & 0 deletions metricflow/model/semantics/metric_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ def measures_for_metric(
column_association_resolver=column_association_resolver,
).create_from_where_filter_intersection(input_measure.filter),
alias=input_measure.alias,
join_to_timespine=input_measure.join_to_timespine,
fill_nulls_with=input_measure.fill_nulls_with,
)
input_measure_specs.append(spec)

Expand Down
101 changes: 56 additions & 45 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import List, Optional, Sequence, Union

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.protocols.metric import MetricType
from dbt_semantic_interfaces.protocols.metric import MetricInputMeasure, MetricType
from dbt_semantic_interfaces.references import MetricModelReference
from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType

Expand Down Expand Up @@ -81,6 +81,7 @@
SqlQueryOptimizerConfiguration,
)
from metricflow.sql.sql_exprs import (
SqlAggregateFunctionExpression,
SqlBetweenExpression,
SqlColumnReference,
SqlColumnReferenceExpression,
Expand All @@ -89,6 +90,7 @@
SqlDateTruncExpression,
SqlExpressionNode,
SqlExtractExpression,
SqlFunction,
SqlFunctionExpression,
SqlLogicalExpression,
SqlLogicalOperator,
Expand Down Expand Up @@ -633,6 +635,7 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet:
metric = self._metric_lookup.get_metric(metric_spec.as_reference)

metric_expr: Optional[SqlExpressionNode] = None
input_measure: Optional[MetricInputMeasure] = None
if metric.type is MetricType.RATIO:
numerator = metric.type_params.numerator
denominator = metric.type_params.denominator
Expand Down Expand Up @@ -664,33 +667,26 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet:
if len(metric.input_measures) > 0:
assert (
len(metric.input_measures) == 1
), "Measure proxy metrics should always source from exactly 1 measure."
), "Simple metrics should always source from exactly 1 measure."
input_measure = metric.input_measures[0]
expr = self._column_association_resolver.resolve_spec(
MeasureSpec(
element_name=metric.input_measures[0].post_aggregation_measure_reference.element_name
)
MeasureSpec(element_name=input_measure.post_aggregation_measure_reference.element_name)
).column_name
else:
expr = metric.name
# Use a column reference to improve query optimization.
metric_expr = SqlColumnReferenceExpression(
SqlColumnReference(
table_alias=from_data_set_alias,
column_name=expr,
)
metric_expr = self.__make_col_reference_or_coalesce_expr(
column_name=expr, input_measure=input_measure, from_data_set_alias=from_data_set_alias
)
elif metric.type is MetricType.CUMULATIVE:
assert (
len(metric.measure_references) == 1
), "Cumulative metrics should always source from exactly 1 measure."
input_measure = metric.input_measures[0]
expr = self._column_association_resolver.resolve_spec(
MeasureSpec(element_name=metric.input_measures[0].post_aggregation_measure_reference.element_name)
MeasureSpec(element_name=input_measure.post_aggregation_measure_reference.element_name)
).column_name
metric_expr = SqlColumnReferenceExpression(
SqlColumnReference(
table_alias=from_data_set_alias,
column_name=expr,
)
metric_expr = self.__make_col_reference_or_coalesce_expr(
column_name=expr, input_measure=input_measure, from_data_set_alias=from_data_set_alias
)
elif metric.type is MetricType.DERIVED:
assert metric.type_params.expr
Expand Down Expand Up @@ -734,6 +730,21 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet:
),
)

def __make_col_reference_or_coalesce_expr(
    self, column_name: str, input_measure: Optional[MetricInputMeasure], from_data_set_alias: str
) -> SqlExpressionNode:
    """Build the SELECT expression for a metric's measure column.

    Produces a plain column reference (which keeps the generated SQL friendly to
    query optimization). When the input measure configures ``fill_nulls_with``,
    the reference is wrapped in COALESCE so nulls are replaced by that value.
    """
    column_ref_expr: SqlExpressionNode = SqlColumnReferenceExpression(
        SqlColumnReference(table_alias=from_data_set_alias, column_name=column_name)
    )
    # No fill value requested -> return the bare column reference.
    if not input_measure or input_measure.fill_nulls_with is None:
        return column_ref_expr
    return SqlAggregateFunctionExpression(
        sql_function=SqlFunction.COALESCE,
        sql_function_args=[column_ref_expr, SqlStringExpression(str(input_measure.fill_nulls_with))],
    )

def visit_order_by_limit_node(self, node: OrderByLimitNode) -> SqlDataSet: # noqa: D
from_data_set: SqlDataSet = node.parent_node.accept(self)
output_instance_set = from_data_set.instance_set
Expand Down Expand Up @@ -1312,7 +1323,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
metric_time_dimension_instance = instance
assert (
metric_time_dimension_instance
), "Can't query offset metric without a time dimension. Validations should have prevented this."
), "Can't join to time spine without metric time. Validations should have prevented this."
metric_time_dimension_column_name = self.column_association_resolver.resolve_spec(
metric_time_dimension_instance.spec
).column_name
Expand Down Expand Up @@ -1346,33 +1357,33 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
metric_instances=parent_data_set.instance_set.metric_instances,
metadata_instances=parent_data_set.instance_set.metadata_instances,
)
non_metric_time_select_columns = create_select_columns_for_instance_sets(
parent_select_columns = create_select_columns_for_instance_sets(
self._column_association_resolver, OrderedDict({parent_alias: non_metric_time_parent_instance_set})
)

# Use metric_time column from time spine.
# Use time instance from time spine to replace metric_time instances.
assert (
len(time_spine_dataset.instance_set.time_dimension_instances) == 1
and len(time_spine_dataset.sql_select_node.select_columns) == 1
), "Time spine dataset not configured properly. Expected exactly one column."
time_dim_instance = time_spine_dataset.instance_set.time_dimension_instances[0]
time_spine_dim_instance = time_spine_dataset.instance_set.time_dimension_instances[0]
time_spine_column_select_expr: Union[
SqlColumnReferenceExpression, SqlDateTruncExpression
] = SqlColumnReferenceExpression(
SqlColumnReference(table_alias=time_spine_alias, column_name=time_dim_instance.spec.qualified_name)
SqlColumnReference(table_alias=time_spine_alias, column_name=time_spine_dim_instance.spec.qualified_name)
)

# Add requested granularities (skip for default granularity) and date_parts.
metric_time_select_columns = []
metric_time_dimension_instances = []
# Add requested granularities (if different from time_spine) and date_parts to time spine column.
time_spine_select_columns = []
time_spine_dim_instances = []
where: Optional[SqlExpressionNode] = None
for metric_time_dimension_spec in node.metric_time_dimension_specs:
# Apply granularity to SQL.
if metric_time_dimension_spec.time_granularity == self._time_spine_source.time_column_granularity:
for requested_time_dimension_spec in node.requested_metric_time_dimension_specs:
# Apply granularity to time spine column select expression.
if requested_time_dimension_spec.time_granularity == time_spine_dim_instance.spec.time_granularity:
select_expr: SqlExpressionNode = time_spine_column_select_expr
else:
select_expr = SqlDateTruncExpression(
time_granularity=metric_time_dimension_spec.time_granularity, arg=time_spine_column_select_expr
time_granularity=requested_time_dimension_spec.time_granularity, arg=time_spine_column_select_expr
)
if node.offset_to_grain:
# Filter down to one row per granularity period
Expand All @@ -1383,32 +1394,32 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
where = new_filter
else:
where = SqlLogicalExpression(operator=SqlLogicalOperator.OR, args=(where, new_filter))
# Apply date_part to SQL.
if metric_time_dimension_spec.date_part:
select_expr = SqlExtractExpression(date_part=metric_time_dimension_spec.date_part, arg=select_expr)
# Apply date_part to time spine column select expression.
if requested_time_dimension_spec.date_part:
select_expr = SqlExtractExpression(date_part=requested_time_dimension_spec.date_part, arg=select_expr)
time_dim_spec = TimeDimensionSpec(
element_name=time_dim_instance.spec.element_name,
entity_links=time_dim_instance.spec.entity_links,
time_granularity=metric_time_dimension_spec.time_granularity,
date_part=metric_time_dimension_spec.date_part,
aggregation_state=time_dim_instance.spec.aggregation_state,
element_name=time_spine_dim_instance.spec.element_name,
entity_links=time_spine_dim_instance.spec.entity_links,
time_granularity=requested_time_dimension_spec.time_granularity,
date_part=requested_time_dimension_spec.date_part,
aggregation_state=time_spine_dim_instance.spec.aggregation_state,
)
time_dim_instance = TimeDimensionInstance(
defined_from=time_dim_instance.defined_from,
time_spine_dim_instance = TimeDimensionInstance(
defined_from=time_spine_dim_instance.defined_from,
associated_columns=(self._column_association_resolver.resolve_spec(time_dim_spec),),
spec=time_dim_spec,
)
metric_time_dimension_instances.append(time_dim_instance)
metric_time_select_columns.append(
SqlSelectColumn(expr=select_expr, column_alias=time_dim_instance.associated_column.column_name)
time_spine_dim_instances.append(time_spine_dim_instance)
time_spine_select_columns.append(
SqlSelectColumn(expr=select_expr, column_alias=time_spine_dim_instance.associated_column.column_name)
)
metric_time_instance_set = InstanceSet(time_dimension_instances=tuple(metric_time_dimension_instances))
time_spine_instance_set = InstanceSet(time_dimension_instances=tuple(time_spine_dim_instances))

return SqlDataSet(
instance_set=InstanceSet.merge([metric_time_instance_set, non_metric_time_parent_instance_set]),
instance_set=InstanceSet.merge([time_spine_instance_set, non_metric_time_parent_instance_set]),
sql_select_node=SqlSelectStatementNode(
description=node.description,
select_columns=tuple(metric_time_select_columns) + non_metric_time_select_columns,
select_columns=tuple(time_spine_select_columns) + parent_select_columns,
from_source=time_spine_dataset.sql_select_node,
from_source_alias=time_spine_alias,
joins_descs=(join_description,),
Expand Down
2 changes: 1 addition & 1 deletion metricflow/plan_conversion/sql_join_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,5 +497,5 @@ def make_join_to_time_spine_join_description(
col_ref=SqlColumnReference(table_alias=parent_alias, column_name=metric_time_dimension_column_name)
),
),
join_type=SqlJoinType.INNER,
join_type=node.join_type,
)
2 changes: 2 additions & 0 deletions metricflow/specs/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,8 @@ class MetricInputMeasureSpec(SerializableDataclass):
measure_spec: MeasureSpec
constraint: Optional[WhereFilterSpec] = None
alias: Optional[str] = None
join_to_timespine: bool = False
fill_nulls_with: Optional[int] = None

@property
def post_aggregation_spec(self) -> MeasureSpec:
Expand Down
Loading

0 comments on commit 7bbb704

Please sign in to comment.