Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Dec 11, 2024
1 parent a3129fc commit c907b88
Show file tree
Hide file tree
Showing 12 changed files with 326 additions and 61 deletions.
2 changes: 1 addition & 1 deletion metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1913,7 +1913,7 @@ def _build_time_spine_node(
pass
else:
time_spine_node: DataflowPlanNode = CustomGranularityBoundsNode.create(
parent_node=time_spine_read_node, custom_granularity_name=offset_window.granularity
parent_node=time_spine_read_node, offset_window=offset_window
)
# # need to add a property to these specs to indicate that they are offset or bounds or something
# filtered_bounds_node = FilterElementsNode.create(
Expand Down
9 changes: 9 additions & 0 deletions metricflow/dataflow/dataflow_plan_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand Down Expand Up @@ -126,6 +127,10 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
def visit_alias_specs_node(self, node: AliasSpecsNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError


class DataflowPlanNodeVisitorWithDefaultHandler(DataflowPlanNodeVisitor[VisitorOutputT], Generic[VisitorOutputT]):
"""Similar to `DataflowPlanNodeVisitor`, but with an abstract default handler that gets called for each node.
Expand Down Expand Up @@ -222,3 +227,7 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
@override
def visit_alias_specs_node(self, node: AliasSpecsNode) -> VisitorOutputT: # noqa: D102
return self._default_handler(node)

@override
def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> VisitorOutputT: # noqa: D102
return self._default_handler(node)
22 changes: 8 additions & 14 deletions metricflow/dataflow/nodes/custom_granularity_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass
from typing import Sequence

from dbt_semantic_interfaces.protocols.metric import MetricTimeWindow
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.visitor import VisitorOutputT
Expand All @@ -12,28 +13,28 @@
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor


# TODO: rename node & file probably & docstring
@dataclass(frozen=True, eq=False)
class CustomGranularityBoundsNode(DataflowPlanNode, ABC):
"""Calculate the start and end of a custom granularity period and each row number within that period."""

custom_granularity_name: str
offset_window: MetricTimeWindow

def __post_init__(self) -> None: # noqa: D105
super().__post_init__()
assert len(self.parent_nodes) == 1

@staticmethod
def create( # noqa: D102
parent_node: DataflowPlanNode, custom_granularity_name: str
parent_node: DataflowPlanNode, offset_window: MetricTimeWindow
) -> CustomGranularityBoundsNode:
return CustomGranularityBoundsNode(parent_nodes=(parent_node,), custom_granularity_name=custom_granularity_name)
return CustomGranularityBoundsNode(parent_nodes=(parent_node,), offset_window=offset_window)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.DATAFLOW_NODE_CUSTOM_GRANULARITY_BOUNDS_ID_PREFIX

def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
# Type checking not working here
return visitor.visit_custom_granularity_bounds_node(self)

@property
Expand All @@ -42,24 +43,17 @@ def description(self) -> str: # noqa: D102

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + (
DisplayedProperty("custom_granularity_name", self.custom_granularity_name),
)
return tuple(super().displayed_properties) + (DisplayedProperty("offset_window", self.offset_window),)

@property
def parent_node(self) -> DataflowPlanNode: # noqa: D102
return self.parent_nodes[0]

def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102
return (
isinstance(other_node, self.__class__)
and other_node.custom_granularity_name == self.custom_granularity_name
)
return isinstance(other_node, self.__class__) and other_node.offset_window == self.offset_window

def with_new_parents( # noqa: D102
self, new_parent_nodes: Sequence[DataflowPlanNode]
) -> CustomGranularityBoundsNode:
assert len(new_parent_nodes) == 1
return CustomGranularityBoundsNode.create(
parent_node=new_parent_nodes[0], custom_granularity_name=self.custom_granularity_name
)
return CustomGranularityBoundsNode.create(parent_node=new_parent_nodes[0], offset_window=self.offset_window)
6 changes: 6 additions & 0 deletions metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand Down Expand Up @@ -472,6 +473,11 @@ def visit_join_to_custom_granularity_node( # noqa: D102
def visit_alias_specs_node(self, node: AliasSpecsNode) -> OptimizeBranchResult: # noqa: D102
raise NotImplementedError

def visit_custom_granularity_bounds_node( # noqa: D102
self, node: CustomGranularityBoundsNode
) -> OptimizeBranchResult:
raise NotImplementedError

def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranchResult:
"""Handles pushdown state propagation for the standard join node type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand Down Expand Up @@ -472,3 +473,9 @@ def visit_min_max_node(self, node: MinMaxNode) -> ComputeMetricsBranchCombinerRe
def visit_alias_specs_node(self, node: AliasSpecsNode) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_custom_granularity_bounds_node( # noqa: D102
self, node: CustomGranularityBoundsNode
) -> ComputeMetricsBranchCombinerResult:
self._log_visit_node_type(node)
return self._default_handler(node)
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand Down Expand Up @@ -356,3 +357,9 @@ def visit_min_max_node(self, node: MinMaxNode) -> OptimizeBranchResult: # noqa:
def visit_alias_specs_node(self, node: AliasSpecsNode) -> OptimizeBranchResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_base_output_handler(node)

def visit_custom_granularity_bounds_node( # noqa: D102
self, node: CustomGranularityBoundsNode
) -> OptimizeBranchResult:
self._log_visit_node_type(node)
return self._default_base_output_handler(node)
5 changes: 5 additions & 0 deletions metricflow/execution/dataflow_to_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand Down Expand Up @@ -205,3 +206,7 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
@override
def visit_alias_specs_node(self, node: AliasSpecsNode) -> ConvertToExecutionPlanResult:
raise NotImplementedError

@override
def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> ConvertToExecutionPlanResult:
raise NotImplementedError
Loading

0 comments on commit c907b88

Please sign in to comment.