From 1b712618b96a39767904e33ebcadab6bfc7a48a0 Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Wed, 18 Dec 2024 10:04:24 -0800 Subject: [PATCH] Add window_function attribute to TimeDimensionSpec This will allow us to track which specs have had a window function applied between DataflowPlan nodes --- .../dunder_column_association_resolver.py | 5 +++++ .../specs/time_dimension_spec.py | 22 +++++++++++++++++++ .../collection_helpers/test_pretty_print.py | 1 + 3 files changed, 28 insertions(+) diff --git a/metricflow-semantics/metricflow_semantics/specs/dunder_column_association_resolver.py b/metricflow-semantics/metricflow_semantics/specs/dunder_column_association_resolver.py index 8f2fead24..ad852f383 100644 --- a/metricflow-semantics/metricflow_semantics/specs/dunder_column_association_resolver.py +++ b/metricflow-semantics/metricflow_semantics/specs/dunder_column_association_resolver.py @@ -54,6 +54,11 @@ def visit_time_dimension_spec(self, time_dimension_spec: TimeDimensionSpec) -> C if time_dimension_spec.aggregation_state else "" ) + + ( + f"{DUNDER}{time_dimension_spec.window_function.value.lower()}" + if time_dimension_spec.window_function + else "" + ) ) def visit_entity_spec(self, entity_spec: EntitySpec) -> ColumnAssociation: # noqa: D102 diff --git a/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py b/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py index fd47c80a6..dec834adc 100644 --- a/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py +++ b/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py @@ -15,6 +15,7 @@ from metricflow_semantics.naming.linkable_spec_name import StructuredLinkableSpecName from metricflow_semantics.specs.dimension_spec import DimensionSpec from metricflow_semantics.specs.instance_spec import InstanceSpecVisitor +from metricflow_semantics.sql.sql_exprs import SqlWindowFunction from metricflow_semantics.time.granularity import ExpandedTimeGranularity from metricflow_semantics.visitor import VisitorOutputT @@ -91,6 +92,8 @@ class TimeDimensionSpec(DimensionSpec): # noqa: D101 # Used for semi-additive joins. Some more thought is needed, but this may be useful in InstanceSpec. aggregation_state: Optional[AggregationState] = None + window_function: Optional[SqlWindowFunction] = None + @property def without_first_entity_link(self) -> TimeDimensionSpec: # noqa: D102 assert len(self.entity_links) > 0, f"Spec does not have any entity links: {self}" @@ -99,6 +102,8 @@ def without_first_entity_link(self) -> TimeDimensionSpec: # noqa: D102 entity_links=self.entity_links[1:], time_granularity=self.time_granularity, date_part=self.date_part, + aggregation_state=self.aggregation_state, + window_function=self.window_function, ) @property @@ -108,6 +113,8 @@ def without_entity_links(self) -> TimeDimensionSpec: # noqa: D102 time_granularity=self.time_granularity, date_part=self.date_part, entity_links=(), + aggregation_state=self.aggregation_state, + window_function=self.window_function, ) @property @@ -153,6 +160,7 @@ def with_grain(self, time_granularity: ExpandedTimeGranularity) -> TimeDimension time_granularity=time_granularity, date_part=self.date_part, aggregation_state=self.aggregation_state, + window_function=self.window_function, ) def with_base_grain(self) -> TimeDimensionSpec: # noqa: D102 @@ -162,6 +170,7 @@ def with_base_grain(self) -> TimeDimensionSpec: # noqa: D102 time_granularity=ExpandedTimeGranularity.from_time_granularity(self.time_granularity.base_granularity), date_part=self.date_part, aggregation_state=self.aggregation_state, + window_function=self.window_function, ) def with_grain_and_date_part( # noqa: D102 @@ -173,6 +182,7 @@ def with_grain_and_date_part( # noqa: D102 time_granularity=time_granularity, date_part=date_part, aggregation_state=self.aggregation_state, + window_function=self.window_function, ) def with_aggregation_state(self, aggregation_state: AggregationState) -> TimeDimensionSpec: # noqa: D102 @@ -182,6 +192,17 @@ def with_aggregation_state(self, aggregation_state: AggregationState) -> TimeDim time_granularity=self.time_granularity, date_part=self.date_part, aggregation_state=aggregation_state, + window_function=self.window_function, + ) + + def with_window_function(self, window_function: SqlWindowFunction) -> TimeDimensionSpec: # noqa: D102 + return TimeDimensionSpec( + element_name=self.element_name, + entity_links=self.entity_links, + time_granularity=self.time_granularity, + date_part=self.date_part, + aggregation_state=self.aggregation_state, + window_function=window_function, ) def comparison_key(self, exclude_fields: Sequence[TimeDimensionSpecField] = ()) -> TimeDimensionSpecComparisonKey: @@ -243,6 +264,7 @@ def with_entity_prefix(self, entity_prefix: EntityReference) -> TimeDimensionSpe time_granularity=self.time_granularity, date_part=self.date_part, aggregation_state=self.aggregation_state, + window_function=self.window_function, ) @staticmethod diff --git a/metricflow-semantics/tests_metricflow_semantics/collection_helpers/test_pretty_print.py b/metricflow-semantics/tests_metricflow_semantics/collection_helpers/test_pretty_print.py index c09422caa..86a4c446c 100644 --- a/metricflow-semantics/tests_metricflow_semantics/collection_helpers/test_pretty_print.py +++ b/metricflow-semantics/tests_metricflow_semantics/collection_helpers/test_pretty_print.py @@ -47,6 +47,7 @@ def test_classes() -> None: # noqa: D103 time_granularity=ExpandedTimeGranularity(name='day', base_granularity=DAY), date_part=None, aggregation_state=None, + window_function=None, ) """ ).rstrip()