Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Dec 4, 2024
1 parent fe9a711 commit e8fc9dd
Show file tree
Hide file tree
Showing 8 changed files with 351 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ class AggregationState(Enum):
NON_AGGREGATED = "NON_AGGREGATED"
PARTIAL = "PARTIAL"
COMPLETE = "COMPLETE"
# Might want to move these to a new enum?
FIRST_VALUE = "FIRST_VALUE"
LAST_VALUE = "LAST_VALUE"
ROW_NUMBER = "ROW_NUMBER"

def __repr__(self) -> str: # noqa: D105
return f"{self.__class__.__name__}.{self.name}"
88 changes: 75 additions & 13 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,10 @@ def _build_derived_metric_output_node(
)
if metric_spec.has_time_offset and queried_agg_time_dimension_specs:
# TODO: move this to a helper method
time_spine_node = self._build_time_spine_node(queried_agg_time_dimension_specs)
time_spine_node = self._build_time_spine_node(
queried_time_spine_specs=queried_agg_time_dimension_specs,
offset_window=metric_spec.offset_window,
)
output_node = JoinToTimeSpineNode.create(
parent_node=output_node,
time_spine_node=time_spine_node,
Expand Down Expand Up @@ -1633,7 +1636,10 @@ def _build_aggregated_measure_from_measure_source_node(
measure_properties=measure_properties, required_time_spine_specs=base_queried_agg_time_dimension_specs
)
required_time_spine_specs = (join_on_time_dimension_spec,) + base_queried_agg_time_dimension_specs
time_spine_node = self._build_time_spine_node(required_time_spine_specs)
time_spine_node = self._build_time_spine_node(
queried_time_spine_specs=required_time_spine_specs,
offset_window=before_aggregation_time_spine_join_description.offset_window,
)
unaggregated_measure_node = JoinToTimeSpineNode.create(
parent_node=unaggregated_measure_node,
time_spine_node=time_spine_node,
Expand Down Expand Up @@ -1846,6 +1852,7 @@ def _build_time_spine_node(
queried_time_spine_specs: Sequence[TimeDimensionSpec],
where_filter_specs: Sequence[WhereFilterSpec] = (),
time_range_constraint: Optional[TimeRangeConstraint] = None,
offset_window: Optional[MetricTimeWindow] = None,
) -> DataflowPlanNode:
"""Return the time spine node needed to satisfy the specs."""
required_time_spine_spec_set = self.__get_required_linkable_specs(
Expand All @@ -1854,18 +1861,73 @@ def _build_time_spine_node(
)
required_time_spine_specs = required_time_spine_spec_set.time_dimension_specs

# TODO: support multiple time spines here. Build node on the one with the smallest base grain.
# Then, pass custom_granularity_specs into _build_pre_aggregation_plan if they aren't satisfied by smallest time spine.
time_spine_source = self._choose_time_spine_source(required_time_spine_specs)
time_spine_node = TransformTimeDimensionsNode.create(
parent_node=self._choose_time_spine_read_node(time_spine_source),
requested_time_dimension_specs=required_time_spine_specs,
)
should_dedupe = False
if offset_window: # and it's a custom grain
# Are sets the right choice here?
all_queried_grains: Set[ExpandedTimeGranularity] = set()
queried_custom_specs: Tuple[TimeDimensionSpec, ...] = ()
queried_standard_specs: Tuple[TimeDimensionSpec, ...] = ()
for spec in queried_time_spine_specs:
all_queried_grains.add(spec.time_granularity)
if spec.time_granularity.is_custom_granularity:
queried_custom_specs += (spec,)
else:
queried_standard_specs += (spec,)

custom_grain_metric_time_spec = DataSet.metric_time_dimension_spec(
offset_window.granularity
) # this would be custom tho
time_spine_source = self._choose_time_spine_source(custom_grain_metric_time_spec)
time_spine_read_node = self._choose_time_spine_read_node(time_spine_source)
# TODO: make sure this is checking the correct granularity type once DSI is updated
if {spec.time_granularity for spec in queried_time_spine_specs} == {offset_window.granularity}:
# If querying with only the same grain as is used in the offset_window, can use a simpler plan.
offset_node = OffsetCustomGranularityNode.create(
parent_node=time_spine_read_node, offset_window=offset_window
)
time_spine_node = JoinToTimeSpineNode.create(
parent_node=offset_node,
time_spine_node=time_spine_read_node,
join_type=SqlJoinType.INNER,
join_on_time_dimension_spec=custom_grain_metric_time_spec,
)
else:
bounds_node = CustomGranularityBoundsNode.create(
parent_node=time_spine_read_node, offset_window=offset_window
)
# TODO: add a property to these specs to indicate whether they are offset columns or bounds columns.
filtered_bounds_node = FilterElementsNode.create(
parent_node=bounds_node, include_specs=bounds_node.specs, distinct=True
)
offset_bounds_node = OffsetCustomGranularityBoundsNode.create(parent_node=filtered_bounds_node)
time_spine_node = OffsetByCustomGranularityNode(
parent_node=offset_bounds_node, offset_window=offset_window
)
if queried_standard_specs:
time_spine_node = ApplyStandardGranularityNode.create(
parent_node=time_spine_node, time_dimension_specs=queried_standard_specs
)
# TODO: check if this join is needed for the same grain as is used in offset window. Later
for custom_spec in queried_custom_specs:
time_spine_node = JoinToCustomGranularityNode.create(
parent_node=time_spine_node, time_dimension_spec=custom_spec
)
# TODO: need TransformTimeDimensionsNode in either of the above paths?
else:
# TODO: support multiple time spines here. Build node on the one with the smallest base grain.
# Then, pass custom_granularity_specs into _build_pre_aggregation_plan if they aren't satisfied by smallest time spine.
time_spine_source = self._choose_time_spine_source(required_time_spine_specs)
time_spine_node = TransformTimeDimensionsNode.create(
parent_node=self._choose_time_spine_read_node(time_spine_source),
requested_time_dimension_specs=required_time_spine_specs,
)

# If the base grain of the time spine isn't selected, it will have duplicate rows that need deduping.
should_dedupe = ExpandedTimeGranularity.from_time_granularity(time_spine_source.base_granularity) not in {
spec.time_granularity for spec in queried_time_spine_specs
}
# If the base grain of the time spine isn't selected, it will have duplicate rows that need deduping.
should_dedupe = ExpandedTimeGranularity.from_time_granularity(time_spine_source.base_granularity) not in {
spec.time_granularity for spec in queried_time_spine_specs
}

# TODO: add a JoinToCustomGranularityNode here if needed to support another custom grain that is not covered by the initial time spine.

return self._build_pre_aggregation_plan(
source_node=time_spine_node,
Expand Down
117 changes: 117 additions & 0 deletions metricflow/dataflow/nodes/offset_by_custom_granularity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from __future__ import annotations

from abc import ABC
from dataclasses import dataclass
from typing import Optional, Sequence

from dbt_semantic_interfaces.protocols import MetricTimeWindow
from dbt_semantic_interfaces.type_enums import TimeGranularity
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.sql.sql_join_type import SqlJoinType
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor


@dataclass(frozen=True, eq=False)
class OffsetByCustomGranularityNode(DataflowPlanNode, ABC):
    """Dataflow node intended to offset its parent dataset by a custom granularity window.

    NOTE(review): this class is work in progress. Its fields, ID prefix, visitor hook and
    description still mirror those of a join-to-time-spine node; confirm which of them
    should be specialized for the custom-granularity offset before relying on this node.

    Attributes:
        time_spine_node: Node that produces the time spine dataset to join to.
        requested_agg_time_dimension_specs: Time dimensions requested in the query.
        join_on_time_dimension_spec: The time dimension to use in the join ON condition.
        join_type: Join type to use when joining to time spine.
        offset_window: Time window to offset the parent dataset by when joining to time spine.
        offset_to_grain: Granularity period to offset the parent dataset to when joining to time spine.
    """

    time_spine_node: DataflowPlanNode
    requested_agg_time_dimension_specs: Sequence[TimeDimensionSpec]
    join_on_time_dimension_spec: TimeDimensionSpec
    join_type: SqlJoinType
    offset_window: Optional[MetricTimeWindow]
    offset_to_grain: Optional[TimeGranularity]

    def __post_init__(self) -> None:  # noqa: D105
        super().__post_init__()
        # This node transforms exactly one parent dataset.
        assert len(self.parent_nodes) == 1

        # The two offset modes are mutually exclusive.
        assert not (
            self.offset_window and self.offset_to_grain
        ), "Can't set both offset_window and offset_to_grain when joining to time spine. Choose one or the other."
        assert (
            len(self.requested_agg_time_dimension_specs) > 0
        ), "Must have at least one value in requested_agg_time_dimension_specs for OffsetByCustomGranularityNode."

    @staticmethod
    def create(  # noqa: D102
        parent_node: DataflowPlanNode,
        time_spine_node: DataflowPlanNode,
        requested_agg_time_dimension_specs: Sequence[TimeDimensionSpec],
        join_on_time_dimension_spec: TimeDimensionSpec,
        join_type: SqlJoinType,
        offset_window: Optional[MetricTimeWindow] = None,
        offset_to_grain: Optional[TimeGranularity] = None,
    ) -> OffsetByCustomGranularityNode:
        # Build an instance of *this* class. The previous body instantiated JoinToTimeSpineNode,
        # which is not imported in this module and is not the node type callers asked for.
        return OffsetByCustomGranularityNode(
            parent_nodes=(parent_node,),
            time_spine_node=time_spine_node,
            # Normalize to a tuple so the frozen dataclass holds an immutable sequence.
            requested_agg_time_dimension_specs=tuple(requested_agg_time_dimension_specs),
            join_on_time_dimension_spec=join_on_time_dimension_spec,
            join_type=join_type,
            offset_window=offset_window,
            offset_to_grain=offset_to_grain,
        )

    @classmethod
    def id_prefix(cls) -> IdPrefix:  # noqa: D102
        # NOTE(review): reuses the join-to-time-spine prefix; a dedicated prefix is probably wanted.
        return StaticIdPrefix.DATAFLOW_NODE_JOIN_TO_TIME_SPINE_ID_PREFIX

    def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT:  # noqa: D102
        # NOTE(review): dispatches to the join-to-time-spine visitor method; confirm whether this
        # node needs its own visitor hook.
        return visitor.visit_join_to_time_spine_node(self)

    @property
    def description(self) -> str:  # noqa: D102
        # NOTE(review): description text is shared with the join-to-time-spine node.
        return """Join to Time Spine Dataset"""

    @property
    def displayed_properties(self) -> Sequence[DisplayedProperty]:  # noqa: D102
        props = tuple(super().displayed_properties) + (
            DisplayedProperty("requested_agg_time_dimension_specs", self.requested_agg_time_dimension_specs),
            DisplayedProperty("join_on_time_dimension_spec", self.join_on_time_dimension_spec),
            DisplayedProperty("join_type", self.join_type),
        )
        # The offset settings are only displayed when they are set.
        if self.offset_window:
            props += (DisplayedProperty("offset_window", self.offset_window),)
        if self.offset_to_grain:
            props += (DisplayedProperty("offset_to_grain", self.offset_to_grain),)
        return props

    @property
    def parent_node(self) -> DataflowPlanNode:  # noqa: D102
        return self.parent_nodes[0]

    def functionally_identical(self, other_node: DataflowPlanNode) -> bool:  # noqa: D102
        # Compares the node type and its configuration fields; time_spine_node is not compared.
        return (
            isinstance(other_node, self.__class__)
            and other_node.offset_window == self.offset_window
            and other_node.offset_to_grain == self.offset_to_grain
            and other_node.requested_agg_time_dimension_specs == self.requested_agg_time_dimension_specs
            and other_node.join_on_time_dimension_spec == self.join_on_time_dimension_spec
            and other_node.join_type == self.join_type
        )

    def with_new_parents(  # noqa: D102
        self, new_parent_nodes: Sequence[DataflowPlanNode]
    ) -> OffsetByCustomGranularityNode:
        assert len(new_parent_nodes) == 1
        # Rebuild the same node type with the replacement parent; all other fields carry over.
        return OffsetByCustomGranularityNode.create(
            parent_node=new_parent_nodes[0],
            time_spine_node=self.time_spine_node,
            requested_agg_time_dimension_specs=self.requested_agg_time_dimension_specs,
            offset_window=self.offset_window,
            offset_to_grain=self.offset_to_grain,
            join_type=self.join_type,
            join_on_time_dimension_spec=self.join_on_time_dimension_spec,
        )
13 changes: 8 additions & 5 deletions metricflow/dataset/sql_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from metricflow_semantics.specs.entity_spec import EntitySpec
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from typing_extensions import override
from dbt_semantic_interfaces.type_enums import DatePart
from metricflow_semantics.time.granularity import ExpandedTimeGranularity


from metricflow.dataset.dataset_classes import DataSet
from metricflow.sql.sql_plan import (
Expand Down Expand Up @@ -148,18 +151,18 @@ def instance_for_time_dimension(self, time_dimension_spec: TimeDimensionSpec) ->
return self.instances_for_time_dimensions((time_dimension_spec,))[0]

def instance_from_time_dimension_grain_and_date_part(
self, time_dimension_spec: TimeDimensionSpec
self, time_granularity: ExpandedTimeGranularity, date_part: Optional[DatePart]
) -> TimeDimensionInstance:
"""Find instance in dataset that matches the grain and date part of the given time dimension spec."""
"""Find instance in dataset that matches the given grain and date part."""
for time_dimension_instance in self.instance_set.time_dimension_instances:
if (
time_dimension_instance.spec.time_granularity == time_dimension_spec.time_granularity
and time_dimension_instance.spec.date_part == time_dimension_spec.date_part
time_dimension_instance.spec.time_granularity == time_granularity
and time_dimension_instance.spec.date_part == date_part
):
return time_dimension_instance

raise RuntimeError(
f"Did not find a time dimension instance with matching grain and date part for spec: {time_dimension_spec}\n"
f"Did not find a time dimension instance with grain {time_granularity} and date part {date_part}\n"
f"Instances available: {self.instance_set.time_dimension_instances}"
)

Expand Down
56 changes: 51 additions & 5 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1206,9 +1206,9 @@ def visit_metric_time_dimension_transform_node(self, node: MetricTimeDimensionTr
spec=metric_time_dimension_spec,
)
)
output_column_to_input_column[
metric_time_dimension_column_association.column_name
] = matching_time_dimension_instance.associated_column.column_name
output_column_to_input_column[metric_time_dimension_column_association.column_name] = (
matching_time_dimension_instance.associated_column.column_name
)

output_instance_set = InstanceSet(
measure_instances=tuple(output_measure_instances),
Expand Down Expand Up @@ -1472,7 +1472,9 @@ def visit_transform_time_dimensions_node(self, node: TransformTimeDimensionsNode
specs_to_remove_from_parent: Set[TimeDimensionSpec] = set()
for spec in node.requested_time_dimension_specs:
# Find the instance in the parent data set with matching grain & date part.
old_instance = parent_data_set.instance_from_time_dimension_grain_and_date_part(spec)
old_instance = parent_data_set.instance_from_time_dimension_grain_and_date_part(
time_granularity=spec.time_granularity, date_part=spec.date_part
)

# Build new instance & select column to match requested spec.
new_instance = TimeDimensionInstance(
Expand Down Expand Up @@ -1821,7 +1823,7 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S

def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> SqlDataSet: # noqa: D102
from_data_set = node.parent_node.accept(self)
parent_instance_set = from_data_set.instance_set # remove order by col
parent_instance_set = from_data_set.instance_set
parent_data_set_alias = self._next_unique_table_alias()

metric_instance = None
Expand Down Expand Up @@ -1948,6 +1950,50 @@ def strip_time_from_dt(ts: dt.datetime) -> dt.datetime:
),
)

def visit_cutom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> SqlDataSet:
    """Select the parent dataset plus window-function columns for the node's custom grain.

    Three columns are placed ahead of the parent's select columns: a ROW_NUMBER column,
    then FIRST_VALUE and LAST_VALUE of the grain's base column. All three are partitioned
    by the custom grain column and ordered by the base column. The instance set of the
    returned dataset is the parent's instance set, unchanged.
    """
    parent_data_set = node.parent_node.accept(self)
    unchanged_instance_set = parent_data_set.instance_set
    parent_alias = self._next_unique_table_alias()

    custom_grain = node.time_granularity  # ExpandedTimeGranularity
    # The looked-up instance is not used further; the call itself is kept as written.
    _matching_instance = parent_data_set.instance_from_time_dimension_grain_and_date_part(
        time_granularity=custom_grain, date_part=None
    )

    # Column references shared by every window function below.
    grain_column_ref = SqlColumnReferenceExpression.from_table_and_column_names(
        table_alias=parent_alias, column_name=custom_grain.name
    )
    base_column_ref = SqlColumnReferenceExpression.from_table_and_column_names(
        table_alias=parent_alias, column_name=custom_grain.base_column
    )

    # Start and end of each custom grain period.
    bound_columns = ()
    for window_function in (SqlWindowFunction.FIRST_VALUE, SqlWindowFunction.LAST_VALUE):
        bound_columns += (
            SqlWindowFunctionExpression.create(
                sql_function=window_function,
                sql_function_args=(base_column_ref,),
                partition_by_args=(grain_column_ref,),
                order_by_args=(base_column_ref,),
            ),
        )

    # Position of each row counted from the start of its custom grain period.
    row_number_column = SqlWindowFunctionExpression.create(
        sql_function=SqlWindowFunction.ROW_NUMBER,
        partition_by_args=(grain_column_ref,),
        order_by_args=(base_column_ref,),
    )

    output_select_columns = (
        (row_number_column,) + bound_columns + parent_data_set.checked_sql_select_node.select_columns
    )
    output_select_node = SqlSelectStatementNode.create(
        description="",  # TODO
        select_columns=output_select_columns,
        from_source=parent_data_set.checked_sql_select_node,
        from_source_alias=parent_alias,
    )
    # Instances are unchanged: the added window columns are not represented in the instance set.
    return SqlDataSet(
        instance_set=unchanged_instance_set,
        sql_select_node=output_select_node,
    )


class DataflowNodeToSqlCteVisitor(DataflowNodeToSqlSubqueryVisitor):
"""Similar to `DataflowNodeToSqlSubqueryVisitor`, except that this converts specific nodes to CTEs.
Expand Down
2 changes: 1 addition & 1 deletion metricflow/plan_conversion/sql_join_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def make_join_to_time_spine_join_description(
left_expr: SqlExpressionNode = SqlColumnReferenceExpression.create(
col_ref=SqlColumnReference(table_alias=time_spine_alias, column_name=agg_time_dimension_column_name)
)
if node.offset_window:
if node.offset_window: # and not node.offset_window.granularity.is_custom_granularity:
left_expr = SqlSubtractTimeIntervalExpression.create(
arg=left_expr, count=node.offset_window.count, granularity=node.offset_window.granularity
)
Expand Down
1 change: 1 addition & 0 deletions metricflow/sql/sql_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,7 @@ class SqlWindowFunction(Enum):
FIRST_VALUE = "FIRST_VALUE"
LAST_VALUE = "LAST_VALUE"
AVERAGE = "AVG"
ROW_NUMBER = "ROW_NUMBER"

@property
def requires_ordering(self) -> bool:
Expand Down
Loading

0 comments on commit e8fc9dd

Please sign in to comment.