Skip to content

Commit

Permalink
Build DataflowPlan for custom offset window with most grains
Browse files Browse the repository at this point in the history
This is the dataflow plan that will be used if the custom grain is queried with any grains that aren't the same as the grain used in the offset window.
  • Loading branch information
courtneyholcomb committed Dec 18, 2024
1 parent 57ba604 commit 1a5b5e4
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,7 @@ def test_min_queryable_time_granularity_for_different_agg_time_grains( # noqa:
def test_custom_offset_window_for_metric(
simple_semantic_manifest_lookup: SemanticManifestLookup,
) -> None:
"""Test offset window with custom grain supplied.
TODO: As of now, the functionality of an offset window with a custom grain is not supported in MF.
This test is added to show that at least the parsing is successful using a custom grain offset window.
Once support for that is added in MF + relevant tests, this test can be removed.
"""
"""Test offset window with custom grain supplied."""
metric = simple_semantic_manifest_lookup.metric_lookup.get_metric(MetricReference("bookings_offset_martian_day"))

assert len(metric.input_metrics) == 1
Expand Down
121 changes: 93 additions & 28 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec
from metricflow_semantics.specs.where_filter.where_filter_spec_set import WhereFilterSpecSet
from metricflow_semantics.specs.where_filter.where_filter_transform import WhereSpecFactory
from metricflow_semantics.sql.sql_exprs import SqlWindowFunction
from metricflow_semantics.sql.sql_join_type import SqlJoinType
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.time.dateutil_adjuster import DateutilTimePeriodAdjuster
Expand Down Expand Up @@ -84,6 +85,7 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand All @@ -92,6 +94,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
Expand Down Expand Up @@ -658,13 +661,22 @@ def _build_derived_metric_output_node(
)
if metric_spec.has_time_offset and queried_agg_time_dimension_specs:
# TODO: move this to a helper method
time_spine_node = self._build_time_spine_node(queried_agg_time_dimension_specs)
time_spine_node = self._build_time_spine_node(
queried_time_spine_specs=queried_agg_time_dimension_specs,
offset_window=metric_spec.offset_window,
)
output_node = JoinToTimeSpineNode.create(
metric_source_node=output_node,
time_spine_node=time_spine_node,
requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
join_on_time_dimension_spec=self._sort_by_base_granularity(queried_agg_time_dimension_specs)[0],
offset_window=metric_spec.offset_window,
offset_window=(
metric_spec.offset_window
if metric_spec.offset_window
and metric_spec.offset_window.granularity
not in self._semantic_model_lookup.custom_granularity_names
else None
),
offset_to_grain=metric_spec.offset_to_grain,
join_type=SqlJoinType.INNER,
)
Expand Down Expand Up @@ -1651,13 +1663,22 @@ def _build_aggregated_measure_from_measure_source_node(
required_time_spine_specs = base_queried_agg_time_dimension_specs
if join_on_time_dimension_spec not in required_time_spine_specs:
required_time_spine_specs = (join_on_time_dimension_spec,) + required_time_spine_specs
time_spine_node = self._build_time_spine_node(required_time_spine_specs)
time_spine_node = self._build_time_spine_node(
queried_time_spine_specs=required_time_spine_specs,
offset_window=before_aggregation_time_spine_join_description.offset_window,
)
unaggregated_measure_node = JoinToTimeSpineNode.create(
metric_source_node=unaggregated_measure_node,
time_spine_node=time_spine_node,
requested_agg_time_dimension_specs=base_queried_agg_time_dimension_specs,
join_on_time_dimension_spec=join_on_time_dimension_spec,
offset_window=before_aggregation_time_spine_join_description.offset_window,
offset_window=(
before_aggregation_time_spine_join_description.offset_window
if before_aggregation_time_spine_join_description.offset_window
and before_aggregation_time_spine_join_description.offset_window.granularity
not in self._semantic_model_lookup.custom_granularity_names
else None
),
offset_to_grain=before_aggregation_time_spine_join_description.offset_to_grain,
join_type=before_aggregation_time_spine_join_description.join_type,
)
Expand Down Expand Up @@ -1864,6 +1885,7 @@ def _build_time_spine_node(
queried_time_spine_specs: Sequence[TimeDimensionSpec],
where_filter_specs: Sequence[WhereFilterSpec] = (),
time_range_constraint: Optional[TimeRangeConstraint] = None,
offset_window: Optional[MetricTimeWindow] = None,
) -> DataflowPlanNode:
"""Return the time spine node needed to satisfy the specs."""
required_time_spine_spec_set = self.__get_required_linkable_specs(
Expand All @@ -1872,30 +1894,35 @@ def _build_time_spine_node(
)
required_time_spine_specs = required_time_spine_spec_set.time_dimension_specs

# TODO: support multiple time spines here. Build node on the one with the smallest base grain.
# Then, pass custom_granularity_specs into _build_pre_aggregation_plan if they aren't satisfied by smallest time spine.
time_spine_source = self._choose_time_spine_source(required_time_spine_specs)
read_node = self._choose_time_spine_read_node(time_spine_source)
time_spine_data_set = self._node_data_set_resolver.get_output_data_set(read_node)

# Change the column aliases to match the specs that were requested in the query.
time_spine_node = AliasSpecsNode.create(
parent_node=read_node,
change_specs=tuple(
SpecToAlias(
input_spec=time_spine_data_set.instance_from_time_dimension_grain_and_date_part(
time_granularity_name=required_spec.time_granularity.name, date_part=required_spec.date_part
).spec,
output_spec=required_spec,
)
for required_spec in required_time_spine_specs
),
)

# If the base grain of the time spine isn't selected, it will have duplicate rows that need deduping.
should_dedupe = ExpandedTimeGranularity.from_time_granularity(time_spine_source.base_granularity) not in {
spec.time_granularity for spec in queried_time_spine_specs
}
should_dedupe = False
if offset_window and offset_window.granularity in self._semantic_model_lookup._custom_granularities:
time_spine_node = self._build_custom_offset_time_spine_node(
offset_window=offset_window, required_time_spine_specs=required_time_spine_specs
)
else:
# For simpler time spine queries, choose the appropriate time spine node and apply requested aliases.
time_spine_source = self._choose_time_spine_source(required_time_spine_specs)
# TODO: support multiple time spines here. Build node on the one with the smallest base grain.
# Then, pass custom_granularity_specs into _build_pre_aggregation_plan if they aren't satisfied by smallest time spine.
read_node = self._choose_time_spine_read_node(time_spine_source)
time_spine_data_set = self._node_data_set_resolver.get_output_data_set(read_node)
# Change the column aliases to match the specs that were requested in the query.
time_spine_node = AliasSpecsNode.create(
parent_node=read_node,
change_specs=tuple(
SpecToAlias(
input_spec=time_spine_data_set.instance_from_time_dimension_grain_and_date_part(
time_granularity_name=required_spec.time_granularity.name, date_part=required_spec.date_part
).spec,
output_spec=required_spec,
)
for required_spec in required_time_spine_specs
),
)
# If the base grain of the time spine isn't selected, it will have duplicate rows that need deduping.
should_dedupe = ExpandedTimeGranularity.from_time_granularity(time_spine_source.base_granularity) not in {
spec.time_granularity for spec in queried_time_spine_specs
}

return self._build_pre_aggregation_plan(
source_node=time_spine_node,
Expand All @@ -1905,6 +1932,44 @@ def _build_time_spine_node(
distinct=should_dedupe,
)

def _build_custom_offset_time_spine_node(
self, offset_window: MetricTimeWindow, required_time_spine_specs: Tuple[TimeDimensionSpec, ...]
) -> DataflowPlanNode:
# Build time spine node that offsets agg time dimensions by a custom grain.
custom_grain = self._semantic_model_lookup._custom_granularities[offset_window.granularity]
time_spine_source = self._choose_time_spine_source((DataSet.metric_time_dimension_spec(custom_grain),))
time_spine_read_node = self._choose_time_spine_read_node(time_spine_source)
if {spec.time_granularity for spec in required_time_spine_specs} == {custom_grain}:
# If querying with only the same grain as is used in the offset_window, can use a simpler plan.
raise NotImplementedError
else:
# For custom offset windows queried with other granularities, first, build CustomGranularityBoundsNode.
# This will be used twice in the output node, and ideally will be turned into a CTE.
bounds_node = CustomGranularityBoundsNode.create(
parent_node=time_spine_read_node, custom_granularity_name=custom_grain.name
)
# Build a FilterElementsNode from bounds node to get required unique rows.
bounds_data_set = self._node_data_set_resolver.get_output_data_set(bounds_node)
bounds_specs = tuple(
bounds_data_set.instance_from_window_function(window_func).spec
for window_func in (SqlWindowFunction.FIRST_VALUE, SqlWindowFunction.LAST_VALUE)
)
custom_grain_spec = bounds_data_set.instance_from_time_dimension_grain_and_date_part(
time_granularity_name=custom_grain.name, date_part=None
).spec
filter_elements_node = FilterElementsNode.create(
parent_node=bounds_node,
include_specs=InstanceSpecSet(time_dimension_specs=(custom_grain_spec,) + bounds_specs),
distinct=True,
)
# Pass both the CustomGranularityBoundsNode and the FilterElementsNode into the OffsetByCustomGranularityNode.
return OffsetByCustomGranularityNode.create(
custom_granularity_bounds_node=bounds_node,
filter_elements_node=filter_elements_node,
offset_window=offset_window,
required_time_spine_specs=required_time_spine_specs,
)

def _sort_by_base_granularity(self, time_dimension_specs: Sequence[TimeDimensionSpec]) -> List[TimeDimensionSpec]:
"""Sort the time dimensions by their base granularity.
Expand Down

0 comments on commit 1a5b5e4

Please sign in to comment.