From c19072dac632a1619de383552f9c311bf2b12d55 Mon Sep 17 00:00:00 2001
From: Courtney Holcomb
Date: Wed, 18 Dec 2024 12:39:09 -0800
Subject: [PATCH] fixup! Add new dataflow plan nodes for custom offset windows

---
 metricflow/dataset/sql_dataset.py             | 24 +++++++++++++++----
 metricflow/plan_conversion/dataflow_to_sql.py |  6 ++---
 metricflow/sql/sql_plan.py                    | 12 +++++++++-
 .../source_scan/test_source_scan_optimizer.py |  8 +++++++
 4 files changed, 41 insertions(+), 9 deletions(-)

diff --git a/metricflow/dataset/sql_dataset.py b/metricflow/dataset/sql_dataset.py
index afa559387..c7f4803b8 100644
--- a/metricflow/dataset/sql_dataset.py
+++ b/metricflow/dataset/sql_dataset.py
@@ -4,6 +4,7 @@
 from typing import List, Optional, Sequence, Tuple
 
 from dbt_semantic_interfaces.references import SemanticModelReference
+from dbt_semantic_interfaces.type_enums import DatePart
 from metricflow_semantics.assert_one_arg import assert_exactly_one_arg_set
 from metricflow_semantics.instances import EntityInstance, InstanceSet, MdoInstance, TimeDimensionInstance
 from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
@@ -12,6 +13,7 @@
 from metricflow_semantics.specs.entity_spec import EntitySpec
 from metricflow_semantics.specs.instance_spec import InstanceSpec
 from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
+from metricflow_semantics.sql.sql_exprs import SqlWindowFunction
 from typing_extensions import override
 
 from metricflow.dataset.dataset_classes import DataSet
@@ -165,18 +167,30 @@ def instance_for_spec(self, spec: InstanceSpec) -> MdoInstance:
         )
 
     def instance_from_time_dimension_grain_and_date_part(
-        self, time_dimension_spec: TimeDimensionSpec
+        self, time_granularity_name: str, date_part: Optional[DatePart]
     ) -> TimeDimensionInstance:
-        """Find instance in dataset that matches the grain and date part of the given time dimension spec."""
+        """Find instance in dataset that matches the given grain and date part."""
         for time_dimension_instance in self.instance_set.time_dimension_instances:
             if (
-                time_dimension_instance.spec.time_granularity == time_dimension_spec.time_granularity
-                and time_dimension_instance.spec.date_part == time_dimension_spec.date_part
+                time_dimension_instance.spec.time_granularity.name == time_granularity_name
+                and time_dimension_instance.spec.date_part == date_part
+                and time_dimension_instance.spec.window_function is None
             ):
                 return time_dimension_instance
 
         raise RuntimeError(
-            f"Did not find a time dimension instance with matching grain and date part for spec: {time_dimension_spec}\n"
+            f"Did not find a time dimension instance with grain '{time_granularity_name}' and date part {date_part}\n"
+            f"Instances available: {self.instance_set.time_dimension_instances}"
+        )
+
+    def instance_from_window_function(self, window_function: SqlWindowFunction) -> TimeDimensionInstance:
+        """Find instance in dataset that matches the given window function."""
+        for time_dimension_instance in self.instance_set.time_dimension_instances:
+            if time_dimension_instance.spec.window_function is window_function:
+                return time_dimension_instance
+
+        raise RuntimeError(
+            f"Did not find a time dimension instance with window function {window_function}.\n"
             f"Instances available: {self.instance_set.time_dimension_instances}"
         )
 
diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py
index aa4e34e61..7fe36dc04 100644
--- a/metricflow/plan_conversion/dataflow_to_sql.py
+++ b/metricflow/plan_conversion/dataflow_to_sql.py
@@ -1276,9 +1276,9 @@ def visit_metric_time_dimension_transform_node(self, node: MetricTimeDimensionTr
                     spec=metric_time_dimension_spec,
                 )
             )
-            output_column_to_input_column[metric_time_dimension_column_association.column_name] = (
-                matching_time_dimension_instance.associated_column.column_name
-            )
+            output_column_to_input_column[
+                metric_time_dimension_column_association.column_name
+            ] = matching_time_dimension_instance.associated_column.column_name
 
         output_instance_set = InstanceSet(
             measure_instances=tuple(output_measure_instances),
diff --git a/metricflow/sql/sql_plan.py b/metricflow/sql/sql_plan.py
index ff0b34c65..85e79b144 100644
--- a/metricflow/sql/sql_plan.py
+++ b/metricflow/sql/sql_plan.py
@@ -9,7 +9,7 @@
 
 from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
 from metricflow_semantics.dag.mf_dag import DagId, DagNode, DisplayedProperty, MetricFlowDag
-from metricflow_semantics.sql.sql_exprs import SqlExpressionNode
+from metricflow_semantics.sql.sql_exprs import SqlColumnReferenceExpression, SqlExpressionNode
 from metricflow_semantics.sql.sql_join_type import SqlJoinType
 from metricflow_semantics.sql.sql_table import SqlTable
 from metricflow_semantics.visitor import VisitorOutputT
@@ -102,6 +102,16 @@ class SqlSelectColumn:
     # Always require a column alias for simplicity.
     column_alias: str
 
+    @staticmethod
+    def from_table_and_column_names(table_alias: str, column_name: str) -> SqlSelectColumn:
+        """Create a column that selects a column from a table by name."""
+        return SqlSelectColumn(
+            expr=SqlColumnReferenceExpression.from_table_and_column_names(
+                column_name=column_name, table_alias=table_alias
+            ),
+            column_alias=column_name,
+        )
+
 
 @dataclass(frozen=True)
 class SqlJoinDescription:
diff --git a/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py b/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py
index 05770806a..a66e5a9e5 100644
--- a/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py
+++ b/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py
@@ -24,6 +24,7 @@
 from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
 from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
 from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
+from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
 from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
 from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
 from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
@@ -32,6 +33,7 @@
 from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
 from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
 from metricflow.dataflow.nodes.min_max import MinMaxNode
+from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
 from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
 from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
 from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
@@ -114,6 +116,12 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
     def visit_alias_specs_node(self, node: AliasSpecsNode) -> int:  # noqa: D102
         return self._sum_parents(node)
 
+    def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> int:  # noqa: D102
+        return self._sum_parents(node)
+
+    def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularityNode) -> int:  # noqa: D102
+        return self._sum_parents(node)
+
     def count_source_nodes(self, dataflow_plan: DataflowPlan) -> int:  # noqa: D102
         return dataflow_plan.sink_node.accept(self)
 
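
Note (not part of the patch): the reworked lookup helpers in metricflow/dataset/sql_dataset.py above match a time dimension instance by granularity name and date part while skipping any instance that carries a window function, or match directly on the window function itself. Below is a minimal standalone sketch of that selection logic; the stub class and function names are hypothetical stand-ins, not MetricFlow APIs.

    # Simplified stand-ins for TimeDimensionInstance/TimeDimensionSpec to show the
    # two lookup rules added in sql_dataset.py.
    from __future__ import annotations

    from dataclasses import dataclass
    from typing import Optional, Sequence


    @dataclass(frozen=True)
    class TimeDimensionInstanceStub:
        granularity_name: str           # e.g. "day" or a custom grain like "martian_week"
        date_part: Optional[str]        # e.g. "month", or None
        window_function: Optional[str]  # e.g. "FIRST_VALUE", or None


    def instance_for_grain_and_date_part(
        instances: Sequence[TimeDimensionInstanceStub], granularity_name: str, date_part: Optional[str]
    ) -> TimeDimensionInstanceStub:
        """Match grain + date part, skipping columns that hold a window-function result."""
        for instance in instances:
            if (
                instance.granularity_name == granularity_name
                and instance.date_part == date_part
                and instance.window_function is None
            ):
                return instance
        raise RuntimeError(f"No instance with grain '{granularity_name}' and date part {date_part}.")


    def instance_for_window_function(
        instances: Sequence[TimeDimensionInstanceStub], window_function: str
    ) -> TimeDimensionInstanceStub:
        """Match the instance produced by a specific window function."""
        for instance in instances:
            if instance.window_function == window_function:
                return instance
        raise RuntimeError(f"No instance with window function {window_function}.")


    if __name__ == "__main__":
        instances = [
            TimeDimensionInstanceStub("day", None, None),
            TimeDimensionInstanceStub("day", None, "FIRST_VALUE"),
        ]
        assert instance_for_grain_and_date_part(instances, "day", None).window_function is None
        assert instance_for_window_function(instances, "FIRST_VALUE").window_function == "FIRST_VALUE"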
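
Note (not part of the patch): SqlSelectColumn.from_table_and_column_names, added in metricflow/sql/sql_plan.py above, is a convenience constructor that builds a `table_alias.column_name` reference expression and reuses the column name as the alias. The sketch below mirrors that pattern with hypothetical stand-in classes rather than the real SqlSelectColumn/SqlColumnReferenceExpression.

    # Standalone sketch (hypothetical names) of the "build a select column from a
    # table alias + column name" helper added to SqlSelectColumn in this patch.
    from __future__ import annotations

    from dataclasses import dataclass


    @dataclass(frozen=True)
    class ColumnReferenceExpr:
        """Stand-in for SqlColumnReferenceExpression: renders `table_alias.column_name`."""

        table_alias: str
        column_name: str

        def render(self) -> str:
            return f"{self.table_alias}.{self.column_name}"


    @dataclass(frozen=True)
    class SelectColumn:
        """Stand-in for SqlSelectColumn: an expression plus a required alias."""

        expr: ColumnReferenceExpr
        column_alias: str

        @staticmethod
        def from_table_and_column_names(table_alias: str, column_name: str) -> SelectColumn:
            # Mirrors the new helper: the alias defaults to the column name itself.
            return SelectColumn(
                expr=ColumnReferenceExpr(table_alias=table_alias, column_name=column_name),
                column_alias=column_name,
            )


    if __name__ == "__main__":
        col = SelectColumn.from_table_and_column_names(table_alias="subq_1", column_name="metric_time__day")
        print(f"{col.expr.render()} AS {col.column_alias}")  # subq_1.metric_time__day AS metric_time__day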