Commit

WIP

courtneyholcomb committed Dec 17, 2024
1 parent 9e15172 commit c30730e
Showing 12 changed files with 205 additions and 117 deletions.
15 changes: 3 additions & 12 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -1928,17 +1928,10 @@ def _build_time_spine_node(
custom_granularity_name=custom_grain.name,
)
time_spine_node: DataflowPlanNode = OffsetByCustomGranularityNode.create(
parent_node=bounds_node, offset_window=offset_window
parent_node=bounds_node,
offset_window=offset_window,
required_time_spine_specs=required_time_spine_specs,
)
# if queried_standard_specs:
# # TODO: This is also when we can change the alias names to match the requested specs
# time_spine_node = ApplyStandardGranularityNode.create(
# parent_node=time_spine_node, time_dimension_specs=queried_standard_specs
# )
for custom_spec in queried_custom_specs:
time_spine_node = JoinToCustomGranularityNode.create(
parent_node=time_spine_node, time_dimension_spec=custom_spec
)
else:
# TODO: support multiple time spines here. Build node on the one with the smallest base grain.
# Then, pass custom_granularity_specs into _build_pre_aggregation_plan if they aren't satisfied by smallest time spine.
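A rough sketch of what that TODO might look like once implemented; the names below (time_spine_sources, source.base_granularity) are assumptions for illustration, not part of this commit:

GRAIN_ORDER = {
    "nanosecond": 0, "microsecond": 1, "millisecond": 2, "second": 3, "minute": 4,
    "hour": 5, "day": 6, "week": 7, "month": 8, "quarter": 9, "year": 10,
}

def pick_time_spine_with_smallest_base_grain(time_spine_sources):
    # Finest base grain wins; any custom grains it can't satisfy would then be
    # passed to _build_pre_aggregation_plan, per the TODO above.
    return min(time_spine_sources, key=lambda source: GRAIN_ORDER[source.base_granularity.value])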
@@ -1965,8 +1958,6 @@ def _build_time_spine_node(
spec.time_granularity for spec in queried_time_spine_specs
}

# -- JoinToCustomGranularityNode -- if needed to support another custom grain not covered by initial time spine

return self._build_pre_aggregation_plan(
source_node=time_spine_node,
filter_to_specs=InstanceSpecSet(time_dimension_specs=tuple(queried_time_spine_specs)),
13 changes: 13 additions & 0 deletions metricflow/dataflow/dataflow_plan_visitor.py
@@ -24,6 +24,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
@@ -131,6 +132,12 @@ def visit_alias_specs_node(self, node: AliasSpecsNode) -> VisitorOutputT: # noq
def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> VisitorOutputT:
raise NotImplementedError


class DataflowPlanNodeVisitorWithDefaultHandler(DataflowPlanNodeVisitor[VisitorOutputT], Generic[VisitorOutputT]):
"""Similar to `DataflowPlanNodeVisitor`, but with an abstract default handler that gets called for each node.
@@ -231,3 +238,9 @@ def visit_alias_specs_node(self, node: AliasSpecsNode) -> VisitorOutputT: # noq
@override
def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> VisitorOutputT: # noqa: D102
return self._default_handler(node)

@override
def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> VisitorOutputT:
return self._default_handler(node)
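As a usage note (not part of the commit): a subclass of DataflowPlanNodeVisitorWithDefaultHandler that only overrides the default handler picks up the new node type automatically, whereas a direct DataflowPlanNodeVisitor subclass must now implement visit_offset_by_custom_granularity_node itself. A minimal sketch:

from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitorWithDefaultHandler


class NodeNamePrinter(DataflowPlanNodeVisitorWithDefaultHandler[None]):
    """Every visit_* method, including the new one, falls through to this handler."""

    def _default_handler(self, node: DataflowPlanNode) -> None:
        print(type(node).__name__)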
33 changes: 25 additions & 8 deletions metricflow/dataflow/nodes/offset_by_custom_granularity.py
@@ -4,16 +4,15 @@
from dataclasses import dataclass
from typing import Sequence

from dbt_semantic_interfaces.protocols.metric import MetricTimeWindow
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from dbt_semantic_interfaces.protocols.metric import MetricTimeWindow

stuff


@dataclass(frozen=True, eq=False)
@@ -24,6 +23,7 @@ class OffsetByCustomGranularityNode(DataflowPlanNode, ABC):
"""

offset_window: MetricTimeWindow
required_time_spine_specs: Sequence[TimeDimensionSpec]

def __post_init__(self) -> None: # noqa: D105
super().__post_init__()
@@ -35,9 +35,15 @@ def __post_init__(self) -> None: # noqa: D105

@staticmethod
def create( # noqa: D102
parent_node: CustomGranularityBoundsNode, offset_window: MetricTimeWindow
parent_node: CustomGranularityBoundsNode,
offset_window: MetricTimeWindow,
required_time_spine_specs: Sequence[TimeDimensionSpec],
) -> OffsetByCustomGranularityNode:
return OffsetByCustomGranularityNode(parent_nodes=(parent_node,), offset_window=offset_window)
return OffsetByCustomGranularityNode(
parent_nodes=(parent_node,),
offset_window=offset_window,
required_time_spine_specs=required_time_spine_specs,
)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
@@ -52,17 +58,28 @@ def description(self) -> str: # noqa: D102

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + (DisplayedProperty("offset_window", self.offset_window),)
return tuple(super().displayed_properties) + (
DisplayedProperty("offset_window", self.offset_window),
DisplayedProperty("required_time_spine_specs", self.required_time_spine_specs),
)

@property
def parent_node(self) -> DataflowPlanNode: # noqa: D102
return self.parent_nodes[0]

def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102
return isinstance(other_node, self.__class__) and other_node.offset_window == self.offset_window
return (
isinstance(other_node, self.__class__)
and other_node.offset_window == self.offset_window
and other_node.required_time_spine_specs == self.required_time_spine_specs
)

def with_new_parents( # noqa: D102
self, new_parent_nodes: Sequence[DataflowPlanNode]
) -> OffsetByCustomGranularityNode:
assert len(new_parent_nodes) == 1
return OffsetByCustomGranularityNode.create(parent_node=new_parent_nodes[0], offset_window=self.offset_window)
return OffsetByCustomGranularityNode(
parent_nodes=tuple(new_parent_nodes),
offset_window=self.offset_window,
required_time_spine_specs=self.required_time_spine_specs,
)
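A hypothetical usage sketch; bounds_node, other_bounds_node, offset_window, and specs are assumed to come from the plan builder, as in _build_time_spine_node above:

node = OffsetByCustomGranularityNode.create(
    parent_node=bounds_node,
    offset_window=offset_window,
    required_time_spine_specs=specs,
)

# with_new_parents now constructs the node directly so that
# required_time_spine_specs is carried over along with offset_window.
rebuilt = node.with_new_parents((other_bounds_node,))
assert node.functionally_identical(rebuilt)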
6 changes: 6 additions & 0 deletions metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
@@ -32,6 +32,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
@@ -478,6 +479,11 @@ def visit_custom_granularity_bounds_node( # noqa: D102
) -> OptimizeBranchResult:
raise NotImplementedError

def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> OptimizeBranchResult:
raise NotImplementedError

def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranchResult:
"""Handles pushdown state propagation for the standard join node type.
Original file line number Diff line number Diff line change
@@ -26,6 +26,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
@@ -479,3 +480,9 @@ def visit_custom_granularity_bounds_node( # noqa: D102
) -> ComputeMetricsBranchCombinerResult:
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> ComputeMetricsBranchCombinerResult:
self._log_visit_node_type(node)
return self._default_handler(node)
Original file line number Diff line number Diff line change
@@ -28,6 +28,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
@@ -370,3 +371,9 @@ def visit_custom_granularity_bounds_node( # noqa: D102
) -> OptimizeBranchResult:
self._log_visit_node_type(node)
return self._default_base_output_handler(node)

def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> OptimizeBranchResult:
self._log_visit_node_type(node)
return self._default_base_output_handler(node)
7 changes: 7 additions & 0 deletions metricflow/execution/dataflow_to_execution.py
@@ -25,6 +25,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
@@ -210,3 +211,9 @@ def visit_alias_specs_node(self, node: AliasSpecsNode) -> ConvertToExecutionPlan
@override
def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> ConvertToExecutionPlanResult:
raise NotImplementedError

@override
def visit_offset_by_custom_granularity_node(
self, node: OffsetByCustomGranularityNode
) -> ConvertToExecutionPlanResult:
raise NotImplementedError
66 changes: 54 additions & 12 deletions metricflow/plan_conversion/dataflow_to_sql.py
@@ -1273,9 +1273,9 @@ def visit_metric_time_dimension_transform_node(self, node: MetricTimeDimensionTr
spec=metric_time_dimension_spec,
)
)
output_column_to_input_column[
metric_time_dimension_column_association.column_name
] = matching_time_dimension_instance.associated_column.column_name
output_column_to_input_column[metric_time_dimension_column_association.column_name] = (
matching_time_dimension_instance.associated_column.column_name
)

output_instance_set = InstanceSet(
measure_instances=tuple(output_measure_instances),
@@ -2067,7 +2067,7 @@ def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode
# This will be offset by 1 to represent the number of base grain periods since the start of the custom grain period.
# Ex: "ROW_NUMBER() OVER (PARTITION BY martian_day ORDER BY ds) AS ds__day__row_number"
new_instance = base_grain_instance_from_parent.with_new_spec(
new_spec=base_grain_instance_from_parent.spec.with_window_function(window_func),
new_spec=base_grain_instance_from_parent.spec.with_window_function(SqlWindowFunction.ROW_NUMBER),
column_association_resolver=self._column_association_resolver,
)
window_func_expr = SqlWindowFunctionExpression.create(
@@ -2108,7 +2108,7 @@ def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularit
WHEN DATEADD(day, ds__day__row_number - 1, ds__fiscal_quarter__first_value__offset) <= ds__fiscal_quarter__last_value__offset
THEN DATEADD(day, ds__day__row_number - 1, ds__fiscal_quarter__first_value__offset)
ELSE ds__fiscal_quarter__last_value__offset
END AS date_day__offset
END AS date_day
FROM custom_granularity_bounds_node
INNER JOIN (
SELECT
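A toy illustration in plain Python (not MetricFlow code) of the arithmetic this SQL encodes: a date keeps its 1-based position within its custom-grain period (the ROW_NUMBER column) and is re-anchored to the first date of the offset period, capped at that period's last date:

from datetime import date, timedelta

row_number = 5                        # position of ds within its fiscal quarter (1-based)
offset_first = date(2024, 4, 1)       # first day of the offset fiscal quarter
offset_last = date(2024, 6, 29)       # last day of the offset fiscal quarter

candidate = offset_first + timedelta(days=row_number - 1)
offset_ds = candidate if candidate <= offset_last else offset_last  # CASE ... ELSE last_value
print(offset_ds)  # 2024-04-05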
@@ -2152,7 +2152,7 @@ def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularit
row_number_instance = instance
elif instance.spec.time_granularity.name == custom_grain_name:
custom_grain_instance = instance
elif instance.spec.time_granularity == base_grain:
elif instance.spec.time_granularity == base_grain and instance.spec.date_part is None:
base_grain_instance = instance
if (
custom_grain_instance
@@ -2266,14 +2266,50 @@ def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularit
),
),
)
offset_base_grain_subquery = SqlSelectStatementNode.create(
description=node.description,
select_columns=(new_custom_grain_column, offset_base_column),
from_source=parent_data_set.checked_sql_select_node,
from_source_alias=parent_data_set_alias,
join_descs=(join_desc,),
)
offset_base_grain_subquery_alias = self._next_unique_table_alias()

# Apply standard grains & date parts requested in the query. Use base grain for any custom grains.
standard_grain_instances: Tuple[TimeDimensionInstance, ...] = ()
standard_grain_columns: Tuple[SqlSelectColumn, ...] = ()
base_column = SqlSelectColumn(
expr=SqlColumnReferenceExpression.from_table_and_column_names(
column_name=base_grain_instance.associated_column.column_name,
table_alias=offset_base_grain_subquery_alias,
),
column_alias=base_grain_instance.associated_column.column_name,
)
for spec in node.required_time_spine_specs:
standard_grain_instances += (
base_grain_instance.with_new_spec(
new_spec=spec, column_association_resolver=self._column_association_resolver
),
)
if spec.date_part:
expr: SqlExpressionNode = SqlExtractExpression.create(date_part=spec.date_part, arg=base_column.expr)
elif spec.time_granularity == base_grain:
expr = base_column.expr
else:
expr = SqlDateTruncExpression.create(
time_granularity=spec.time_granularity.base_granularity, arg=base_column.expr
)
standard_grain_columns += (
SqlSelectColumn(expr=expr, column_alias=standard_grain_instances[-1].associated_column.column_name),
)

return SqlDataSet(
instance_set=InstanceSet(time_dimension_instances=(custom_grain_instance, base_grain_instance)),
instance_set=InstanceSet(
time_dimension_instances=(custom_grain_instance, base_grain_instance) + standard_grain_instances
),
sql_select_node=SqlSelectStatementNode.create(
description=node.description,
select_columns=(new_custom_grain_column, offset_base_column),
from_source=parent_data_set.checked_sql_select_node,
from_source_alias=parent_data_set_alias,
join_descs=(join_desc,),
description="Apply Requested Granularities",
select_columns=(base_column,) + standard_grain_columns,
from_source=offset_base_grain_subquery,
from_source_alias=offset_base_grain_subquery_alias,
),
)
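Illustrative only (not part of the commit): the three branches of the loop above, rendered as the SQL each would produce assuming a day base grain and an offset column named date_day; custom grains fall into the DATE_TRUNC branch via their base_granularity.

from typing import Optional


def render_requested_grain(granularity: str, date_part: Optional[str], base_grain: str = "day") -> str:
    if date_part:                        # EXTRACT branch for date parts
        return f"EXTRACT({date_part} FROM date_day)"
    if granularity == base_grain:        # pass-through branch for the base grain
        return "date_day"
    return f"DATE_TRUNC('{granularity}', date_day)"  # DATE_TRUNC branch for coarser grains


print(render_requested_grain("day", None))     # date_day
print(render_requested_grain("month", None))   # DATE_TRUNC('month', date_day)
print(render_requested_grain("day", "dow"))    # EXTRACT(dow FROM date_day)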

@@ -2478,5 +2514,11 @@ def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode
node=node, node_to_select_subquery_function=super().visit_custom_granularity_bounds_node
)

@override
def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularityNode) -> SqlDataSet: # noqa: D102
return self._default_handler(
node=node, node_to_select_subquery_function=super().visit_offset_by_custom_granularity_node
)


DataflowNodeT = TypeVar("DataflowNodeT", bound=DataflowPlanNode)
Original file line number Diff line number Diff line change
@@ -33,6 +33,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
@@ -118,6 +119,9 @@ def visit_alias_specs_node(self, node: AliasSpecsNode) -> int: # noqa: D102
def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> int: # noqa: D102
return self._sum_parents(node)

def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularityNode) -> int: # noqa: D102
return self._sum_parents(node)

def count_source_nodes(self, dataflow_plan: DataflowPlan) -> int: # noqa: D102
return dataflow_plan.sink_node.accept(self)
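For context, this counter appears to work by recursion: source nodes contribute 1 and every other node sums its parents, so accepting the visitor at the sink counts the reachable source nodes. A toy version of that idea (not MetricFlow code):

class ToyNode:
    def __init__(self, parents=(), is_source=False):
        self.parents = parents
        self.is_source = is_source

    def count_sources(self) -> int:
        if self.is_source:
            return 1
        return sum(parent.count_sources() for parent in self.parents)  # the _sum_parents idea


source_a, source_b = ToyNode(is_source=True), ToyNode(is_source=True)
join = ToyNode(parents=(source_a, source_b))
print(join.count_sources())  # 2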
