From 7de2bd30f91d6ae03782026fea77e4e35b97be7b Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Wed, 18 Dec 2024 12:25:58 -0800 Subject: [PATCH] Add new dataflow plan nodes for custom offset windows --- .../metricflow_semantics/dag/id_prefix.py | 2 + .../dataflow/builder/dataflow_plan_builder.py | 4 +- metricflow/dataflow/dataflow_plan_visitor.py | 22 ++ .../nodes/custom_granularity_bounds.py | 64 ++++ .../nodes/offset_by_custom_granularity.py | 95 ++++++ .../optimizer/predicate_pushdown_optimizer.py | 12 + .../source_scan/cm_branch_combiner.py | 14 + .../source_scan/source_scan_optimizer.py | 14 + metricflow/dataset/sql_dataset.py | 24 +- metricflow/execution/dataflow_to_execution.py | 12 + metricflow/plan_conversion/dataflow_to_sql.py | 306 +++++++++++++++++- metricflow/sql/sql_plan.py | 12 +- .../source_scan/test_source_scan_optimizer.py | 8 + 13 files changed, 581 insertions(+), 8 deletions(-) create mode 100644 metricflow/dataflow/nodes/custom_granularity_bounds.py create mode 100644 metricflow/dataflow/nodes/offset_by_custom_granularity.py diff --git a/metricflow-semantics/metricflow_semantics/dag/id_prefix.py b/metricflow-semantics/metricflow_semantics/dag/id_prefix.py index 61fdcc5f7..285052d3a 100644 --- a/metricflow-semantics/metricflow_semantics/dag/id_prefix.py +++ b/metricflow-semantics/metricflow_semantics/dag/id_prefix.py @@ -56,6 +56,8 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper): DATAFLOW_NODE_JOIN_CONVERSION_EVENTS_PREFIX = "jce" DATAFLOW_NODE_WINDOW_REAGGREGATION_ID_PREFIX = "wr" DATAFLOW_NODE_ALIAS_SPECS_ID_PREFIX = "as" + DATAFLOW_NODE_CUSTOM_GRANULARITY_BOUNDS_ID_PREFIX = "cgb" + DATAFLOW_NODE_OFFSET_BY_CUSTOMG_GRANULARITY_ID_PREFIX = "obcg" SQL_EXPR_COLUMN_REFERENCE_ID_PREFIX = "cr" SQL_EXPR_COMPARISON_ID_PREFIX = "cmp" diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index b44271fa9..e7456d777 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -1883,7 +1883,9 @@ def _build_time_spine_node( parent_node=read_node, change_specs=tuple( SpecToAlias( - input_spec=time_spine_data_set.instance_from_time_dimension_grain_and_date_part(required_spec).spec, + input_spec=time_spine_data_set.instance_from_time_dimension_grain_and_date_part( + time_granularity_name=required_spec.time_granularity.name, date_part=required_spec.date_part + ).spec, output_spec=required_spec, ) for required_spec in required_time_spine_specs diff --git a/metricflow/dataflow/dataflow_plan_visitor.py b/metricflow/dataflow/dataflow_plan_visitor.py index 412170a53..1e3a86bfe 100644 --- a/metricflow/dataflow/dataflow_plan_visitor.py +++ b/metricflow/dataflow/dataflow_plan_visitor.py @@ -15,6 +15,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode + from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -23,6 +24,7 @@ from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode from metricflow.dataflow.nodes.min_max import MinMaxNode + from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode @@ -126,6 +128,16 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod def visit_alias_specs_node(self, node: AliasSpecsNode) -> VisitorOutputT: # noqa: D102 raise NotImplementedError + @abstractmethod + def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> VisitorOutputT: # noqa: D102 + raise NotImplementedError + + @abstractmethod + def visit_offset_by_custom_granularity_node( # noqa: D102 + self, node: OffsetByCustomGranularityNode + ) -> VisitorOutputT: + raise NotImplementedError + class DataflowPlanNodeVisitorWithDefaultHandler(DataflowPlanNodeVisitor[VisitorOutputT], Generic[VisitorOutputT]): """Similar to `DataflowPlanNodeVisitor`, but with an abstract default handler that gets called for each node. @@ -222,3 +234,13 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod @override def visit_alias_specs_node(self, node: AliasSpecsNode) -> VisitorOutputT: # noqa: D102 return self._default_handler(node) + + @override + def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> VisitorOutputT: # noqa: D102 + return self._default_handler(node) + + @override + def visit_offset_by_custom_granularity_node( # noqa: D102 + self, node: OffsetByCustomGranularityNode + ) -> VisitorOutputT: + return self._default_handler(node) diff --git a/metricflow/dataflow/nodes/custom_granularity_bounds.py b/metricflow/dataflow/nodes/custom_granularity_bounds.py new file mode 100644 index 000000000..5dbde2a88 --- /dev/null +++ b/metricflow/dataflow/nodes/custom_granularity_bounds.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from abc import ABC +from dataclasses import dataclass +from typing import Sequence + +from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix +from metricflow_semantics.dag.mf_dag import DisplayedProperty +from metricflow_semantics.visitor import VisitorOutputT + +from metricflow.dataflow.dataflow_plan import DataflowPlanNode +from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor + + +@dataclass(frozen=True, eq=False) +class CustomGranularityBoundsNode(DataflowPlanNode, ABC): + """Calculate the start and end of a custom granularity period and each row number within that period.""" + + custom_granularity_name: str + + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 1 + + @staticmethod + def create( # noqa: D102 + parent_node: DataflowPlanNode, custom_granularity_name: str + ) -> CustomGranularityBoundsNode: + return CustomGranularityBoundsNode(parent_nodes=(parent_node,), custom_granularity_name=custom_granularity_name) + + @classmethod + def id_prefix(cls) -> IdPrefix: # noqa: D102 + return StaticIdPrefix.DATAFLOW_NODE_CUSTOM_GRANULARITY_BOUNDS_ID_PREFIX + + def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 + return visitor.visit_custom_granularity_bounds_node(self) + + @property + def description(self) -> str: # noqa: D102 + return """Calculate Custom Granularity Bounds""" + + @property + def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 + return tuple(super().displayed_properties) + ( + DisplayedProperty("custom_granularity_name", self.custom_granularity_name), + ) + + @property + def parent_node(self) -> DataflowPlanNode: # noqa: D102 + return self.parent_nodes[0] + + def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102 + return ( + isinstance(other_node, self.__class__) + and other_node.custom_granularity_name == self.custom_granularity_name + ) + + def with_new_parents( # noqa: D102 + self, new_parent_nodes: Sequence[DataflowPlanNode] + ) -> CustomGranularityBoundsNode: + assert len(new_parent_nodes) == 1 + return CustomGranularityBoundsNode.create( + parent_node=new_parent_nodes[0], custom_granularity_name=self.custom_granularity_name + ) diff --git a/metricflow/dataflow/nodes/offset_by_custom_granularity.py b/metricflow/dataflow/nodes/offset_by_custom_granularity.py new file mode 100644 index 000000000..8161af3c2 --- /dev/null +++ b/metricflow/dataflow/nodes/offset_by_custom_granularity.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from abc import ABC +from dataclasses import dataclass +from typing import Optional, Sequence + +from dbt_semantic_interfaces.protocols.metric import MetricTimeWindow +from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix +from metricflow_semantics.dag.mf_dag import DisplayedProperty +from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec +from metricflow_semantics.visitor import VisitorOutputT + +from metricflow.dataflow.dataflow_plan import DataflowPlanNode +from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode +from metricflow.dataflow.nodes.filter_elements import FilterElementsNode + + +@dataclass(frozen=True, eq=False) +class OffsetByCustomGranularityNode(DataflowPlanNode, ABC): + """For a given custom grain, offset its base grain by the requested number of custom grain periods. + + Only accepts CustomGranularityBoundsNode as parent node. + """ + + offset_window: MetricTimeWindow + required_time_spine_specs: Sequence[TimeDimensionSpec] + custom_granularity_bounds_node: CustomGranularityBoundsNode + filter_elements_node: FilterElementsNode + + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + + @staticmethod + def create( # noqa: D102 + custom_granularity_bounds_node: CustomGranularityBoundsNode, + filter_elements_node: FilterElementsNode, + offset_window: MetricTimeWindow, + required_time_spine_specs: Sequence[TimeDimensionSpec], + ) -> OffsetByCustomGranularityNode: + return OffsetByCustomGranularityNode( + parent_nodes=(custom_granularity_bounds_node, filter_elements_node), + custom_granularity_bounds_node=custom_granularity_bounds_node, + filter_elements_node=filter_elements_node, + offset_window=offset_window, + required_time_spine_specs=required_time_spine_specs, + ) + + @classmethod + def id_prefix(cls) -> IdPrefix: # noqa: D102 + return StaticIdPrefix.DATAFLOW_NODE_OFFSET_BY_CUSTOMG_GRANULARITY_ID_PREFIX + + def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 + return visitor.visit_offset_by_custom_granularity_node(self) + + @property + def description(self) -> str: # noqa: D102 + return """Offset Base Granularity By Custom Granularity Period(s)""" + + @property + def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 + return tuple(super().displayed_properties) + ( + DisplayedProperty("offset_window", self.offset_window), + DisplayedProperty("required_time_spine_specs", self.required_time_spine_specs), + ) + + def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102 + return ( + isinstance(other_node, self.__class__) + and other_node.offset_window == self.offset_window + and other_node.required_time_spine_specs == self.required_time_spine_specs + ) + + def with_new_parents( # noqa: D102 + self, new_parent_nodes: Sequence[DataflowPlanNode] + ) -> OffsetByCustomGranularityNode: + custom_granularity_bounds_node: Optional[CustomGranularityBoundsNode] = None + filter_elements_node: Optional[FilterElementsNode] = None + for parent_node in new_parent_nodes: + if isinstance(parent_node, CustomGranularityBoundsNode): + custom_granularity_bounds_node = parent_node + elif isinstance(parent_node, FilterElementsNode): + filter_elements_node = parent_node + assert custom_granularity_bounds_node and filter_elements_node, ( + "Can't rewrite OffsetByCustomGranularityNode because the node requires a CustomGranularityBoundsNode and a " + f"FilterElementsNode as parents. Instead, got: {new_parent_nodes}" + ) + + return OffsetByCustomGranularityNode( + parent_nodes=tuple(new_parent_nodes), + custom_granularity_bounds_node=custom_granularity_bounds_node, + filter_elements_node=filter_elements_node, + offset_window=self.offset_window, + required_time_spine_specs=self.required_time_spine_specs, + ) diff --git a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py index 223964af4..97a9eae41 100644 --- a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py +++ b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py @@ -23,6 +23,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -31,6 +32,7 @@ from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode from metricflow.dataflow.nodes.min_max import MinMaxNode +from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode @@ -472,6 +474,16 @@ def visit_join_to_custom_granularity_node( # noqa: D102 def visit_alias_specs_node(self, node: AliasSpecsNode) -> OptimizeBranchResult: # noqa: D102 raise NotImplementedError + def visit_custom_granularity_bounds_node( # noqa: D102 + self, node: CustomGranularityBoundsNode + ) -> OptimizeBranchResult: + raise NotImplementedError + + def visit_offset_by_custom_granularity_node( # noqa: D102 + self, node: OffsetByCustomGranularityNode + ) -> OptimizeBranchResult: + raise NotImplementedError + def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranchResult: """Handles pushdown state propagation for the standard join node type. diff --git a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py index 3209e34b8..d153899d9 100644 --- a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py +++ b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py @@ -17,6 +17,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -25,6 +26,7 @@ from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode from metricflow.dataflow.nodes.min_max import MinMaxNode +from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode @@ -472,3 +474,15 @@ def visit_min_max_node(self, node: MinMaxNode) -> ComputeMetricsBranchCombinerRe def visit_alias_specs_node(self, node: AliasSpecsNode) -> ComputeMetricsBranchCombinerResult: # noqa: D102 self._log_visit_node_type(node) return self._default_handler(node) + + def visit_custom_granularity_bounds_node( # noqa: D102 + self, node: CustomGranularityBoundsNode + ) -> ComputeMetricsBranchCombinerResult: + self._log_visit_node_type(node) + return self._default_handler(node) + + def visit_offset_by_custom_granularity_node( # noqa: D102 + self, node: OffsetByCustomGranularityNode + ) -> ComputeMetricsBranchCombinerResult: + self._log_visit_node_type(node) + return self._default_handler(node) diff --git a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py index 95c0aeec3..4fea885d5 100644 --- a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py +++ b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py @@ -19,6 +19,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -27,6 +28,7 @@ from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode from metricflow.dataflow.nodes.min_max import MinMaxNode +from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode @@ -363,3 +365,15 @@ def visit_min_max_node(self, node: MinMaxNode) -> OptimizeBranchResult: # noqa: def visit_alias_specs_node(self, node: AliasSpecsNode) -> OptimizeBranchResult: # noqa: D102 self._log_visit_node_type(node) return self._default_base_output_handler(node) + + def visit_custom_granularity_bounds_node( # noqa: D102 + self, node: CustomGranularityBoundsNode + ) -> OptimizeBranchResult: + self._log_visit_node_type(node) + return self._default_base_output_handler(node) + + def visit_offset_by_custom_granularity_node( # noqa: D102 + self, node: OffsetByCustomGranularityNode + ) -> OptimizeBranchResult: + self._log_visit_node_type(node) + return self._default_base_output_handler(node) diff --git a/metricflow/dataset/sql_dataset.py b/metricflow/dataset/sql_dataset.py index afa559387..c7f4803b8 100644 --- a/metricflow/dataset/sql_dataset.py +++ b/metricflow/dataset/sql_dataset.py @@ -4,6 +4,7 @@ from typing import List, Optional, Sequence, Tuple from dbt_semantic_interfaces.references import SemanticModelReference +from dbt_semantic_interfaces.type_enums import DatePart from metricflow_semantics.assert_one_arg import assert_exactly_one_arg_set from metricflow_semantics.instances import EntityInstance, InstanceSet, MdoInstance, TimeDimensionInstance from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat @@ -12,6 +13,7 @@ from metricflow_semantics.specs.entity_spec import EntitySpec from metricflow_semantics.specs.instance_spec import InstanceSpec from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec +from metricflow_semantics.sql.sql_exprs import SqlWindowFunction from typing_extensions import override from metricflow.dataset.dataset_classes import DataSet @@ -165,18 +167,30 @@ def instance_for_spec(self, spec: InstanceSpec) -> MdoInstance: ) def instance_from_time_dimension_grain_and_date_part( - self, time_dimension_spec: TimeDimensionSpec + self, time_granularity_name: str, date_part: Optional[DatePart] ) -> TimeDimensionInstance: - """Find instance in dataset that matches the grain and date part of the given time dimension spec.""" + """Find instance in dataset that matches the given grain and date part.""" for time_dimension_instance in self.instance_set.time_dimension_instances: if ( - time_dimension_instance.spec.time_granularity == time_dimension_spec.time_granularity - and time_dimension_instance.spec.date_part == time_dimension_spec.date_part + time_dimension_instance.spec.time_granularity.name == time_granularity_name + and time_dimension_instance.spec.date_part == date_part + and time_dimension_instance.spec.window_function is None ): return time_dimension_instance raise RuntimeError( - f"Did not find a time dimension instance with matching grain and date part for spec: {time_dimension_spec}\n" + f"Did not find a time dimension instance with grain '{time_granularity_name}' and date part {date_part}\n" + f"Instances available: {self.instance_set.time_dimension_instances}" + ) + + def instance_from_window_function(self, window_function: SqlWindowFunction) -> TimeDimensionInstance: + """Find instance in dataset that matches the given window function.""" + for time_dimension_instance in self.instance_set.time_dimension_instances: + if time_dimension_instance.spec.window_function is window_function: + return time_dimension_instance + + raise RuntimeError( + f"Did not find a time dimension instance with window function {window_function}.\n" f"Instances available: {self.instance_set.time_dimension_instances}" ) diff --git a/metricflow/execution/dataflow_to_execution.py b/metricflow/execution/dataflow_to_execution.py index 1c2aa2fd9..8d1f5cfb3 100644 --- a/metricflow/execution/dataflow_to_execution.py +++ b/metricflow/execution/dataflow_to_execution.py @@ -16,6 +16,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -24,6 +25,7 @@ from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode from metricflow.dataflow.nodes.min_max import MinMaxNode +from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode @@ -205,3 +207,13 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod @override def visit_alias_specs_node(self, node: AliasSpecsNode) -> ConvertToExecutionPlanResult: raise NotImplementedError + + @override + def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> ConvertToExecutionPlanResult: + raise NotImplementedError + + @override + def visit_offset_by_custom_granularity_node( + self, node: OffsetByCustomGranularityNode + ) -> ConvertToExecutionPlanResult: + raise NotImplementedError diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index cd6b92a2b..b03d76e4d 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -6,6 +6,7 @@ from typing import Callable, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, TypeVar from dbt_semantic_interfaces.enum_extension import assert_values_exhausted +from dbt_semantic_interfaces.naming.keywords import DUNDER from dbt_semantic_interfaces.protocols.metric import MetricInputMeasure, MetricType from dbt_semantic_interfaces.references import MetricModelReference, SemanticModelElementReference from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType @@ -38,8 +39,12 @@ from metricflow_semantics.specs.spec_set import InstanceSpecSet from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec from metricflow_semantics.sql.sql_exprs import ( + SqlAddTimeExpression, SqlAggregateFunctionExpression, + SqlArithmeticExpression, + SqlArithmeticOperator, SqlBetweenExpression, + SqlCaseExpression, SqlColumnReference, SqlColumnReferenceExpression, SqlComparison, @@ -50,6 +55,7 @@ SqlFunction, SqlFunctionExpression, SqlGenerateUuidExpression, + SqlIntegerExpression, SqlLogicalExpression, SqlLogicalOperator, SqlRatioComputationExpression, @@ -77,6 +83,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -85,6 +92,7 @@ from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode from metricflow.dataflow.nodes.min_max import MinMaxNode +from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode @@ -1888,7 +1896,7 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> SqlDataSet: # noqa: D102 from_data_set = node.parent_node.accept(self) - parent_instance_set = from_data_set.instance_set # remove order by col + parent_instance_set = from_data_set.instance_set parent_data_set_alias = self._next_unique_table_alias() metric_instance = None @@ -2015,6 +2023,290 @@ def strip_time_from_dt(ts: dt.datetime) -> dt.datetime: ), ) + def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> SqlDataSet: + """Build columns that will be needed for custom offset windows. + + This includes columns that represent the start and end of a custom grain period, as well as the row number of the base + grain within each period. For example, the columns might look like: + + SELECT + {{ existing columns }}, + FIRST_VALUE(ds) OVER (PARTITION BY fiscal_quarter ORDER BY ds) AS ds__fiscal_quarter__first_value, + LAST_VALUE(ds) OVER (PARTITION BY fiscal_quarter ORDER BY ds) AS ds__fiscal_quarter__last_value, + ROW_NUMBER() OVER (PARTITION BY fiscal_quarter ORDER BY ds) AS ds__day__row_number + FROM time_spine_read_node + """ + parent_data_set = node.parent_node.accept(self) + parent_instance_set = parent_data_set.instance_set + parent_data_set_alias = self._next_unique_table_alias() + + custom_granularity_name = node.custom_granularity_name + time_spine = self._get_time_spine_for_custom_granularity(custom_granularity_name) + custom_grain_instance_from_parent = parent_data_set.instance_from_time_dimension_grain_and_date_part( + time_granularity_name=custom_granularity_name, date_part=None + ) + base_grain_instance_from_parent = parent_data_set.instance_from_time_dimension_grain_and_date_part( + time_granularity_name=time_spine.base_granularity.value, date_part=None + ) + custom_column_expr = SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=parent_data_set_alias, + column_name=custom_grain_instance_from_parent.associated_column.column_name, + ) + base_column_expr = SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=parent_data_set_alias, column_name=base_grain_instance_from_parent.associated_column.column_name + ) + + new_instances: Tuple[TimeDimensionInstance, ...] = tuple() + new_select_columns: Tuple[SqlSelectColumn, ...] = tuple() + + # Build columns that get start and end of the custom grain period. + # Ex: "FIRST_VALUE(ds) OVER (PARTITION BY martian_day ORDER BY ds) AS ds__fiscal_quarter__first_value" + for window_func in (SqlWindowFunction.FIRST_VALUE, SqlWindowFunction.LAST_VALUE): + new_instance = custom_grain_instance_from_parent.with_new_spec( + new_spec=custom_grain_instance_from_parent.spec.with_window_function(window_func), + column_association_resolver=self._column_association_resolver, + ) + select_column = SqlSelectColumn( + expr=SqlWindowFunctionExpression.create( + sql_function=window_func, + sql_function_args=(base_column_expr,), + partition_by_args=(custom_column_expr,), + order_by_args=(SqlWindowOrderByArgument(base_column_expr),), + ), + column_alias=new_instance.associated_column.column_name, + ) + new_instances += (new_instance,) + new_select_columns += (select_column,) + + # Build a column that tracks the row number for the base grain column within the custom grain period. + # This will be offset by 1 to represent the number of base grain periods since the start of the custom grain period. + # Ex: "ROW_NUMBER() OVER (PARTITION BY martian_day ORDER BY ds) AS ds__day__row_number" + new_instance = base_grain_instance_from_parent.with_new_spec( + new_spec=base_grain_instance_from_parent.spec.with_window_function(SqlWindowFunction.ROW_NUMBER), + column_association_resolver=self._column_association_resolver, + ) + window_func_expr = SqlWindowFunctionExpression.create( + sql_function=SqlWindowFunction.ROW_NUMBER, + partition_by_args=(custom_column_expr,), + order_by_args=(SqlWindowOrderByArgument(base_column_expr),), + ) + new_select_column = SqlSelectColumn( + expr=window_func_expr, + column_alias=new_instance.associated_column.column_name, + ) + new_instances += (new_instance,) + new_select_columns += (new_select_column,) + + return SqlDataSet( + instance_set=InstanceSet.merge([InstanceSet(time_dimension_instances=new_instances), parent_instance_set]), + sql_select_node=SqlSelectStatementNode.create( + description=node.description, + select_columns=parent_data_set.checked_sql_select_node.select_columns + new_select_columns, + from_source=parent_data_set.checked_sql_select_node, + from_source_alias=parent_data_set_alias, + ), + ) + + def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularityNode) -> SqlDataSet: + """For a given custom grain, offset its base grain by the requested number of custom grain periods. + + Example: if the custom grain is `fiscal_quarter` with a base grain of DAY and we're offsetting by 1 period, the + output SQL should look something like this: + + SELECT + CASE + WHEN DATEADD(day, ds__day__row_number - 1, ds__fiscal_quarter__first_value__offset) <= ds__fiscal_quarter__last_value__offset + THEN DATEADD(day, ds__day__row_number - 1, ds__fiscal_quarter__first_value__offset) + ELSE ds__fiscal_quarter__last_value__offset + END AS date_day + FROM custom_granularity_bounds_node + INNER JOIN filter_elements_node ON filter_elements_node.fiscal_quarter = custom_granularity_bounds_node.fiscal_quarter + """ + bounds_data_set = node.custom_granularity_bounds_node.accept(self) + bounds_instance_set = bounds_data_set.instance_set + bounds_data_set_alias = self._next_unique_table_alias() + filter_elements_data_set = node.filter_elements_node.accept(self) + filter_elements_instance_set = filter_elements_data_set.instance_set + filter_elements_data_set_alias = self._next_unique_table_alias() + offset_window = node.offset_window + custom_grain_name = offset_window.granularity + base_grain = ExpandedTimeGranularity.from_time_granularity( + self._get_time_spine_for_custom_granularity(custom_grain_name).base_granularity + ) + + # Find the required instances in the parent data sets. + first_value_instance: Optional[TimeDimensionInstance] = None + last_value_instance: Optional[TimeDimensionInstance] = None + row_number_instance: Optional[TimeDimensionInstance] = None + custom_grain_instance: Optional[TimeDimensionInstance] = None + base_grain_instance: Optional[TimeDimensionInstance] = None + for instance in filter_elements_instance_set.time_dimension_instances: + if instance.spec.window_function is SqlWindowFunction.FIRST_VALUE: + first_value_instance = instance + elif instance.spec.window_function is SqlWindowFunction.LAST_VALUE: + last_value_instance = instance + elif instance.spec.time_granularity.name == custom_grain_name: + custom_grain_instance = instance + if custom_grain_instance and first_value_instance and last_value_instance: + break + for instance in bounds_instance_set.time_dimension_instances: + if instance.spec.window_function is SqlWindowFunction.ROW_NUMBER: + row_number_instance = instance + elif instance.spec.time_granularity == base_grain and instance.spec.date_part is None: + base_grain_instance = instance + if base_grain_instance and row_number_instance: + break + assert ( + custom_grain_instance + and base_grain_instance + and first_value_instance + and last_value_instance + and row_number_instance + ), ( + "Did not find all required time dimension instances in parent data sets for OffsetByCustomGranularityNode. " + f"This indicates internal misconfiguration. Got custom grain instance: {custom_grain_instance}; base grain " + f"instance: {base_grain_instance}; first value instance: {first_value_instance}; last value instance: " + f"{last_value_instance}; row number instance: {row_number_instance}\n" + f"Available instances:{bounds_instance_set.time_dimension_instances}." + ) + + # First, build a subquery that offsets the first and last value columns. + custom_grain_column_name = custom_grain_instance.associated_column.column_name + custom_grain_column = SqlSelectColumn.from_table_and_column_names( + column_name=custom_grain_column_name, table_alias=filter_elements_data_set_alias + ) + first_value_offset_column, last_value_offset_column = tuple( + SqlSelectColumn( + expr=SqlWindowFunctionExpression.create( + sql_function=SqlWindowFunction.LEAD, + sql_function_args=( + SqlColumnReferenceExpression.from_table_and_column_names( + column_name=instance.associated_column.column_name, + table_alias=filter_elements_data_set_alias, + ), + SqlIntegerExpression.create(node.offset_window.count), + ), + order_by_args=(SqlWindowOrderByArgument(custom_grain_column.expr),), + ), + column_alias=f"{instance.associated_column.column_name}{DUNDER}offset", + ) + for instance in (first_value_instance, last_value_instance) + ) + offset_bounds_subquery_alias = self._next_unique_table_alias() + offset_bounds_subquery = SqlSelectStatementNode.create( + description="Offset Custom Granularity Bounds", + select_columns=(custom_grain_column, first_value_offset_column, last_value_offset_column), + from_source=filter_elements_data_set.checked_sql_select_node, + from_source_alias=filter_elements_data_set_alias, + ) + offset_bounds_subquery_alias = self._next_unique_table_alias() + + # Offset the base column by the requested window. If the offset date is not within the offset custom grain period, + # default to the last value in that period. + first_value_offset_expr, last_value_offset_expr = [ + SqlColumnReferenceExpression.from_table_and_column_names( + column_name=offset_column.column_alias, table_alias=offset_bounds_subquery_alias + ) + for offset_column in (first_value_offset_column, last_value_offset_column) + ] + offset_base_grain_expr = SqlAddTimeExpression.create( + arg=first_value_offset_expr, + count_expr=SqlArithmeticExpression.create( + left_expr=SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=bounds_data_set_alias, column_name=row_number_instance.associated_column.column_name + ), + operator=SqlArithmeticOperator.SUBTRACT, + right_expr=SqlIntegerExpression.create(1), + ), + granularity=base_grain.base_granularity, + ) + is_below_last_value_expr = SqlComparisonExpression.create( + left_expr=offset_base_grain_expr, + comparison=SqlComparison.LESS_THAN_OR_EQUALS, + right_expr=last_value_offset_expr, + ) + offset_base_instance = base_grain_instance.with_new_spec( + # LEAD isn't quite accurate here, but this will differentiate the offset instance (and column) from the original one. + new_spec=base_grain_instance.spec.with_window_function(SqlWindowFunction.LEAD), + column_association_resolver=self._column_association_resolver, + ) + offset_base_column = SqlSelectColumn( + expr=SqlCaseExpression.create( + when_to_then_exprs={is_below_last_value_expr: offset_base_grain_expr}, + else_expr=last_value_offset_expr, + ), + column_alias=offset_base_instance.associated_column.column_name, + ) + original_base_grain_column = SqlSelectColumn.from_table_and_column_names( + column_name=base_grain_instance.associated_column.column_name, table_alias=bounds_data_set_alias + ) + join_desc = SqlJoinDescription( + right_source=offset_bounds_subquery, + right_source_alias=offset_bounds_subquery_alias, + join_type=SqlJoinType.INNER, + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=bounds_data_set_alias, column_name=custom_grain_column_name + ), + comparison=SqlComparison.EQUALS, + right_expr=SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=offset_bounds_subquery_alias, column_name=custom_grain_column_name + ), + ), + ) + offset_base_grain_subquery = SqlSelectStatementNode.create( + description=node.description, + select_columns=(original_base_grain_column, offset_base_column), + from_source=bounds_data_set.checked_sql_select_node, + from_source_alias=bounds_data_set_alias, + join_descs=(join_desc,), + ) + offset_base_grain_subquery_alias = self._next_unique_table_alias() + + # Apply standard grains & date parts requested in the query. Use base grain for any custom grains. + standard_grain_instances: Tuple[TimeDimensionInstance, ...] = () + standard_grain_columns: Tuple[SqlSelectColumn, ...] = () + offset_base_column_ref = SqlSelectColumn( + expr=SqlColumnReferenceExpression.from_table_and_column_names( + column_name=offset_base_instance.associated_column.column_name, + table_alias=offset_base_grain_subquery_alias, + ), + column_alias=base_grain_instance.associated_column.column_name, + ) + for spec in node.required_time_spine_specs: + new_instance = base_grain_instance.with_new_spec( + new_spec=spec, column_association_resolver=self._column_association_resolver + ) + standard_grain_instances += (new_instance,) + if spec.date_part: + expr: SqlExpressionNode = SqlExtractExpression.create( + date_part=spec.date_part, arg=offset_base_column_ref.expr + ) + elif spec.time_granularity.base_granularity == base_grain.base_granularity: + expr = offset_base_column_ref.expr + else: + expr = SqlDateTruncExpression.create( + time_granularity=spec.time_granularity.base_granularity, arg=offset_base_column_ref.expr + ) + standard_grain_columns += ( + SqlSelectColumn(expr=expr, column_alias=new_instance.associated_column.column_name), + ) + + # Need to keep the non-offset base grain column in the output. This will be used to join to the source data set. + non_offset_base_grain_column = SqlSelectColumn.from_table_and_column_names( + column_name=base_grain_instance.associated_column.column_name, table_alias=offset_base_grain_subquery_alias + ) + + return SqlDataSet( + instance_set=InstanceSet(time_dimension_instances=(base_grain_instance,) + standard_grain_instances), + sql_select_node=SqlSelectStatementNode.create( + description="Apply Requested Granularities", + select_columns=(non_offset_base_grain_column,) + standard_grain_columns, + from_source=offset_base_grain_subquery, + from_source_alias=offset_base_grain_subquery_alias, + ), + ) + class DataflowNodeToSqlCteVisitor(DataflowNodeToSqlSubqueryVisitor): """Similar to `DataflowNodeToSqlSubqueryVisitor`, except that this converts specific nodes to CTEs. @@ -2210,5 +2502,17 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod def visit_alias_specs_node(self, node: AliasSpecsNode) -> SqlDataSet: # noqa: D102 return self._default_handler(node=node, node_to_select_subquery_function=super().visit_alias_specs_node) + @override + def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> SqlDataSet: # noqa: D102 + return self._default_handler( + node=node, node_to_select_subquery_function=super().visit_custom_granularity_bounds_node + ) + + @override + def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularityNode) -> SqlDataSet: # noqa: D102 + return self._default_handler( + node=node, node_to_select_subquery_function=super().visit_offset_by_custom_granularity_node + ) + DataflowNodeT = TypeVar("DataflowNodeT", bound=DataflowPlanNode) diff --git a/metricflow/sql/sql_plan.py b/metricflow/sql/sql_plan.py index ff0b34c65..85e79b144 100644 --- a/metricflow/sql/sql_plan.py +++ b/metricflow/sql/sql_plan.py @@ -9,7 +9,7 @@ from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow_semantics.dag.mf_dag import DagId, DagNode, DisplayedProperty, MetricFlowDag -from metricflow_semantics.sql.sql_exprs import SqlExpressionNode +from metricflow_semantics.sql.sql_exprs import SqlColumnReferenceExpression, SqlExpressionNode from metricflow_semantics.sql.sql_join_type import SqlJoinType from metricflow_semantics.sql.sql_table import SqlTable from metricflow_semantics.visitor import VisitorOutputT @@ -102,6 +102,16 @@ class SqlSelectColumn: # Always require a column alias for simplicity. column_alias: str + @staticmethod + def from_table_and_column_names(table_alias: str, column_name: str) -> SqlSelectColumn: + """Create a column that selects a column from a table by name.""" + return SqlSelectColumn( + expr=SqlColumnReferenceExpression.from_table_and_column_names( + column_name=column_name, table_alias=table_alias + ), + column_alias=column_name, + ) + @dataclass(frozen=True) class SqlJoinDescription: diff --git a/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py b/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py index 05770806a..a66e5a9e5 100644 --- a/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py +++ b/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py @@ -24,6 +24,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -32,6 +33,7 @@ from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode from metricflow.dataflow.nodes.min_max import MinMaxNode +from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode @@ -114,6 +116,12 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod def visit_alias_specs_node(self, node: AliasSpecsNode) -> int: # noqa: D102 return self._sum_parents(node) + def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> int: # noqa: D102 + return self._sum_parents(node) + + def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularityNode) -> int: # noqa: D102 + return self._sum_parents(node) + def count_source_nodes(self, dataflow_plan: DataflowPlan) -> int: # noqa: D102 return dataflow_plan.sink_node.accept(self)