diff --git a/metricflow-semantics/tests_metricflow_semantics/model/semantics/test_metric_lookup.py b/metricflow-semantics/tests_metricflow_semantics/model/semantics/test_metric_lookup.py index d9942eeb4..b69c82d62 100644 --- a/metricflow-semantics/tests_metricflow_semantics/model/semantics/test_metric_lookup.py +++ b/metricflow-semantics/tests_metricflow_semantics/model/semantics/test_metric_lookup.py @@ -27,12 +27,7 @@ def test_min_queryable_time_granularity_for_different_agg_time_grains( # noqa: def test_custom_offset_window_for_metric( simple_semantic_manifest_lookup: SemanticManifestLookup, ) -> None: - """Test offset window with custom grain supplied. - - TODO: As of now, the functionality of an offset window with a custom grain is not supported in MF. - This test is added to show that at least the parsing is successful using a custom grain offset window. - Once support for that is added in MF + relevant tests, this test can be removed. - """ + """Test offset window with custom grain supplied.""" metric = simple_semantic_manifest_lookup.metric_lookup.get_metric(MetricReference("bookings_offset_martian_day")) assert len(metric.input_metrics) == 1 diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index e7456d777..f66667cbf 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -54,6 +54,7 @@ from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec from metricflow_semantics.specs.where_filter.where_filter_spec_set import WhereFilterSpecSet from metricflow_semantics.specs.where_filter.where_filter_transform import WhereSpecFactory +from metricflow_semantics.sql.sql_exprs import SqlWindowFunction from metricflow_semantics.sql.sql_join_type import SqlJoinType from metricflow_semantics.sql.sql_table import SqlTable from metricflow_semantics.time.dateutil_adjuster import DateutilTimePeriodAdjuster @@ -84,6 +85,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -92,6 +94,7 @@ from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode from metricflow.dataflow.nodes.min_max import MinMaxNode +from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode @@ -658,13 +661,22 @@ def _build_derived_metric_output_node( ) if metric_spec.has_time_offset and queried_agg_time_dimension_specs: # TODO: move this to a helper method - time_spine_node = self._build_time_spine_node(queried_agg_time_dimension_specs) + time_spine_node = self._build_time_spine_node( + queried_time_spine_specs=queried_agg_time_dimension_specs, + offset_window=metric_spec.offset_window, + ) output_node = JoinToTimeSpineNode.create( metric_source_node=output_node, time_spine_node=time_spine_node, requested_agg_time_dimension_specs=queried_agg_time_dimension_specs, join_on_time_dimension_spec=self._sort_by_base_granularity(queried_agg_time_dimension_specs)[0], - offset_window=metric_spec.offset_window, + offset_window=( + metric_spec.offset_window + if metric_spec.offset_window + and metric_spec.offset_window.granularity + not in self._semantic_model_lookup.custom_granularity_names + else None + ), offset_to_grain=metric_spec.offset_to_grain, join_type=SqlJoinType.INNER, ) @@ -1651,13 +1663,22 @@ def _build_aggregated_measure_from_measure_source_node( required_time_spine_specs = base_queried_agg_time_dimension_specs if join_on_time_dimension_spec not in required_time_spine_specs: required_time_spine_specs = (join_on_time_dimension_spec,) + required_time_spine_specs - time_spine_node = self._build_time_spine_node(required_time_spine_specs) + time_spine_node = self._build_time_spine_node( + queried_time_spine_specs=required_time_spine_specs, + offset_window=before_aggregation_time_spine_join_description.offset_window, + ) unaggregated_measure_node = JoinToTimeSpineNode.create( metric_source_node=unaggregated_measure_node, time_spine_node=time_spine_node, requested_agg_time_dimension_specs=base_queried_agg_time_dimension_specs, join_on_time_dimension_spec=join_on_time_dimension_spec, - offset_window=before_aggregation_time_spine_join_description.offset_window, + offset_window=( + before_aggregation_time_spine_join_description.offset_window + if before_aggregation_time_spine_join_description.offset_window + and before_aggregation_time_spine_join_description.offset_window.granularity + not in self._semantic_model_lookup.custom_granularity_names + else None + ), offset_to_grain=before_aggregation_time_spine_join_description.offset_to_grain, join_type=before_aggregation_time_spine_join_description.join_type, ) @@ -1864,6 +1885,7 @@ def _build_time_spine_node( queried_time_spine_specs: Sequence[TimeDimensionSpec], where_filter_specs: Sequence[WhereFilterSpec] = (), time_range_constraint: Optional[TimeRangeConstraint] = None, + offset_window: Optional[MetricTimeWindow] = None, ) -> DataflowPlanNode: """Return the time spine node needed to satisfy the specs.""" required_time_spine_spec_set = self.__get_required_linkable_specs( @@ -1872,30 +1894,35 @@ def _build_time_spine_node( ) required_time_spine_specs = required_time_spine_spec_set.time_dimension_specs - # TODO: support multiple time spines here. Build node on the one with the smallest base grain. - # Then, pass custom_granularity_specs into _build_pre_aggregation_plan if they aren't satisfied by smallest time spine. - time_spine_source = self._choose_time_spine_source(required_time_spine_specs) - read_node = self._choose_time_spine_read_node(time_spine_source) - time_spine_data_set = self._node_data_set_resolver.get_output_data_set(read_node) - - # Change the column aliases to match the specs that were requested in the query. - time_spine_node = AliasSpecsNode.create( - parent_node=read_node, - change_specs=tuple( - SpecToAlias( - input_spec=time_spine_data_set.instance_from_time_dimension_grain_and_date_part( - time_granularity_name=required_spec.time_granularity.name, date_part=required_spec.date_part - ).spec, - output_spec=required_spec, - ) - for required_spec in required_time_spine_specs - ), - ) - - # If the base grain of the time spine isn't selected, it will have duplicate rows that need deduping. - should_dedupe = ExpandedTimeGranularity.from_time_granularity(time_spine_source.base_granularity) not in { - spec.time_granularity for spec in queried_time_spine_specs - } + should_dedupe = False + if offset_window and offset_window.granularity in self._semantic_model_lookup._custom_granularities: + time_spine_node = self._build_custom_offset_time_spine_node( + offset_window=offset_window, required_time_spine_specs=required_time_spine_specs + ) + else: + # For simpler time spine queries, choose the appropriate time spine node and apply requested aliases. + time_spine_source = self._choose_time_spine_source(required_time_spine_specs) + # TODO: support multiple time spines here. Build node on the one with the smallest base grain. + # Then, pass custom_granularity_specs into _build_pre_aggregation_plan if they aren't satisfied by smallest time spine. + read_node = self._choose_time_spine_read_node(time_spine_source) + time_spine_data_set = self._node_data_set_resolver.get_output_data_set(read_node) + # Change the column aliases to match the specs that were requested in the query. + time_spine_node = AliasSpecsNode.create( + parent_node=read_node, + change_specs=tuple( + SpecToAlias( + input_spec=time_spine_data_set.instance_from_time_dimension_grain_and_date_part( + time_granularity_name=required_spec.time_granularity.name, date_part=required_spec.date_part + ).spec, + output_spec=required_spec, + ) + for required_spec in required_time_spine_specs + ), + ) + # If the base grain of the time spine isn't selected, it will have duplicate rows that need deduping. + should_dedupe = ExpandedTimeGranularity.from_time_granularity(time_spine_source.base_granularity) not in { + spec.time_granularity for spec in queried_time_spine_specs + } return self._build_pre_aggregation_plan( source_node=time_spine_node, @@ -1905,6 +1932,44 @@ def _build_time_spine_node( distinct=should_dedupe, ) + def _build_custom_offset_time_spine_node( + self, offset_window: MetricTimeWindow, required_time_spine_specs: Tuple[TimeDimensionSpec, ...] + ) -> DataflowPlanNode: + # Build time spine node that offsets agg time dimensions by a custom grain. + custom_grain = self._semantic_model_lookup._custom_granularities[offset_window.granularity] + time_spine_source = self._choose_time_spine_source((DataSet.metric_time_dimension_spec(custom_grain),)) + time_spine_read_node = self._choose_time_spine_read_node(time_spine_source) + if {spec.time_granularity for spec in required_time_spine_specs} == {custom_grain}: + # If querying with only the same grain as is used in the offset_window, can use a simpler plan. + raise NotImplementedError + else: + # For custom offset windows queried with other granularities, first, build CustomGranularityBoundsNode. + # This will be used twice in the output node, and ideally will be turned into a CTE. + bounds_node = CustomGranularityBoundsNode.create( + parent_node=time_spine_read_node, custom_granularity_name=custom_grain.name + ) + # Build a FilterElementsNode from bounds node to get required unique rows. + bounds_data_set = self._node_data_set_resolver.get_output_data_set(bounds_node) + bounds_specs = tuple( + bounds_data_set.instance_from_window_function(window_func).spec + for window_func in (SqlWindowFunction.FIRST_VALUE, SqlWindowFunction.LAST_VALUE) + ) + custom_grain_spec = bounds_data_set.instance_from_time_dimension_grain_and_date_part( + time_granularity_name=custom_grain.name, date_part=None + ).spec + filter_elements_node = FilterElementsNode.create( + parent_node=bounds_node, + include_specs=InstanceSpecSet(time_dimension_specs=(custom_grain_spec,) + bounds_specs), + distinct=True, + ) + # Pass both the CustomGranularityBoundsNode and the FilterElementsNode into the OffsetByCustomGranularityNode. + return OffsetByCustomGranularityNode.create( + custom_granularity_bounds_node=bounds_node, + filter_elements_node=filter_elements_node, + offset_window=offset_window, + required_time_spine_specs=required_time_spine_specs, + ) + def _sort_by_base_granularity(self, time_dimension_specs: Sequence[TimeDimensionSpec]) -> List[TimeDimensionSpec]: """Sort the time dimensions by their base granularity.