From 195d21599c5d6082a7e1e1e50978ce4e8bf2d798 Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Thu, 21 Nov 2024 15:20:50 -0800 Subject: [PATCH] Consolidate logic around choosing appropriate time spine source This will change later when we support multiple time spine nodes per query. For now, move the error to the core function so that we don't need to do error handling everywhere this gets used (which will be several places further up the stack). Also adds a helper to improve readability. --- .../time/time_spine_source.py | 12 +++++++++--- .../dataflow/builder/dataflow_plan_builder.py | 18 +++++++++--------- metricflow/plan_conversion/dataflow_to_sql.py | 8 +------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/metricflow-semantics/metricflow_semantics/time/time_spine_source.py b/metricflow-semantics/metricflow_semantics/time/time_spine_source.py index 8ca810e4d..c995c5ad8 100644 --- a/metricflow-semantics/metricflow_semantics/time/time_spine_source.py +++ b/metricflow-semantics/metricflow_semantics/time/time_spine_source.py @@ -99,10 +99,10 @@ def build_custom_granularities(time_spine_sources: Sequence[TimeSpineSource]) -> } @staticmethod - def choose_time_spine_sources( + def choose_time_spine_source( required_time_spine_specs: Sequence[TimeDimensionSpec], time_spine_sources: Dict[TimeGranularity, TimeSpineSource], - ) -> Sequence[TimeSpineSource]: + ) -> TimeSpineSource: """Determine which time spine sources to use to satisfy the given specs. Custom grains can only use the time spine where they are defined. For standard grains, this will choose the time @@ -145,7 +145,13 @@ def choose_time_spine_sources( if not required_time_spines.intersection(set(compatible_time_spines_for_standard_grains.values())): required_time_spines.add(time_spine_sources[max(compatible_time_spines_for_standard_grains)]) - return tuple(required_time_spines) + if len(required_time_spines) != 1: + raise RuntimeError( + "Multiple time spines are required to satisfy the specs, but only one is supported per query currently. " + f"Multiple will be supported in the future. Time spines required: {required_time_spines}." + ) + + return required_time_spines.pop() @property def data_set_description(self) -> str: diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index c45c12f66..49cd16455 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -1044,15 +1044,8 @@ def _find_source_node_recipe_non_cached( ) # If metric_time is requested without metrics, choose appropriate time spine node to select those values from. if linkable_specs_to_satisfy.metric_time_specs: - time_spine_sources = TimeSpineSource.choose_time_spine_sources( - required_time_spine_specs=linkable_specs_to_satisfy.metric_time_specs, - time_spine_sources=self._source_node_builder.time_spine_sources, - ) - assert len(time_spine_sources) == 1, ( - "Exactly one time spine source should have been selected for base grains." - "This indicates internal misconfiguration." - ) - time_spine_node = self._source_node_set.time_spine_nodes[time_spine_sources[0].base_granularity] + time_spine_source = self._choose_time_spine_source(linkable_specs_to_satisfy.metric_time_specs) + time_spine_node = self._source_node_set.time_spine_nodes[time_spine_source.base_granularity] candidate_nodes_for_right_side_of_join += [time_spine_node] candidate_nodes_for_left_side_of_join += [time_spine_node] default_join_type = SqlJoinType.FULL_OUTER @@ -1828,3 +1821,10 @@ def _build_semi_additive_join_node( agg_by_function=non_additive_dimension_spec.window_choice, queried_time_dimension_spec=queried_time_dimension_spec, ) + + def _choose_time_spine_source(self, required_time_spine_specs: Sequence[TimeDimensionSpec]) -> TimeSpineSource: + """Choose the time spine source that can satisfy the required time spine specs.""" + return TimeSpineSource.choose_time_spine_source( + required_time_spine_specs=required_time_spine_specs, + time_spine_sources=self._source_node_builder.time_spine_sources, + ) diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index a467453de..96688e629 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -335,15 +335,9 @@ def _make_time_spine_data_set( ] required_specs = queried_specs + specs_required_for_where_constraints - time_spine_sources = TimeSpineSource.choose_time_spine_sources( + time_spine_source = TimeSpineSource.choose_time_spine_source( required_time_spine_specs=required_specs, time_spine_sources=self._time_spine_sources ) - # TODO: handle multiple time spine joins - assert len(time_spine_sources) == 1, ( - "Join to time spine with custom granularity currently only supports one custom granularity per query. " - "Full feature coming soon." - ) - time_spine_source = time_spine_sources[0] time_spine_base_granularity = ExpandedTimeGranularity.from_time_granularity(time_spine_source.base_granularity) base_column_expr = SqlColumnReferenceExpression.from_table_and_column_names(