Skip to content

Commit

Permalink
Merge pull request #768 from dbt-labs/remove-generic-dataset-constructs
Browse files: browse the repository at this point in the history
Remove generic source dataset constructs
  • Loading branch information
tlento authored Sep 8, 2023
2 parents c00d603 + f20fadf commit 87ee477
Show file tree
Hide file tree
Showing 32 changed files with 601 additions and 787 deletions.
48 changes: 22 additions & 26 deletions metricflow/dataflow/builder/costing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Sequence
from typing import Any, Sequence

from metricflow.dataflow.dataflow_plan import (
AggregateMeasuresNode,
Expand All @@ -29,7 +29,6 @@
OrderByLimitNode,
ReadSqlSourceNode,
SemiAdditiveJoinNode,
SourceDataSetT,
WhereConstraintNode,
WriteToResultDataframeNode,
WriteToResultTableNode,
Expand Down Expand Up @@ -71,38 +70,37 @@ def sum(costs: Sequence[DefaultCost]) -> DefaultCost: # noqa: D
)


class DataflowPlanNodeCostFunction(Generic[SourceDataSetT], ABC):
class DataflowPlanNodeCostFunction(ABC):
"""A function that calculates the cost for computing the dataflow up to a given node."""

@abstractmethod
def calculate_cost(self, node: DataflowPlanNode[SourceDataSetT]) -> DataflowPlanNodeCost:
def calculate_cost(self, node: DataflowPlanNode) -> DataflowPlanNodeCost:
"""Return the cost for calculating the given dataflow up to the given node."""
pass


class DefaultCostFunction(
Generic[SourceDataSetT],
DataflowPlanNodeCostFunction[SourceDataSetT],
DataflowPlanNodeVisitor[SourceDataSetT, DefaultCost],
DataflowPlanNodeCostFunction,
DataflowPlanNodeVisitor[DefaultCost],
):
"""Cost function using the default cost."""

def calculate_cost(self, node: DataflowPlanNode[SourceDataSetT]) -> DataflowPlanNodeCost: # noqa: D
def calculate_cost(self, node: DataflowPlanNode) -> DataflowPlanNodeCost: # noqa: D
return node.accept(self)

def visit_source_node(self, node: ReadSqlSourceNode[SourceDataSetT]) -> DefaultCost: # noqa: D
def visit_source_node(self, node: ReadSqlSourceNode) -> DefaultCost: # noqa: D
# Base case.
return DefaultCost(num_joins=0, num_aggregations=0)

def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode[SourceDataSetT]) -> DefaultCost: # noqa: D
def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode) -> DefaultCost: # noqa: D
parent_costs = [x.accept(self) for x in node.parent_nodes]

# Add number of joins to the cost.
node_cost = DefaultCost(num_joins=len(node.join_targets))
return DefaultCost.sum(parent_costs + [node_cost])

def visit_join_aggregated_measures_by_groupby_columns_node( # noqa: D
self, node: JoinAggregatedMeasuresByGroupByColumnsNode[SourceDataSetT]
self, node: JoinAggregatedMeasuresByGroupByColumnsNode
) -> DefaultCost:
parent_costs = [x.accept(self) for x in node.parent_nodes]

Expand All @@ -111,57 +109,55 @@ def visit_join_aggregated_measures_by_groupby_columns_node( # noqa: D
node_cost = DefaultCost(num_joins=num_joins)
return DefaultCost.sum(parent_costs + [node_cost])

def visit_aggregate_measures_node(self, node: AggregateMeasuresNode[SourceDataSetT]) -> DefaultCost: # noqa: D
def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> DefaultCost: # noqa: D
parent_costs = [x.accept(self) for x in node.parent_nodes]

# Add the number of aggregations to the cost
node_cost = DefaultCost(num_aggregations=1)
return DefaultCost.sum(parent_costs + [node_cost])

def visit_compute_metrics_node(self, node: ComputeMetricsNode[SourceDataSetT]) -> DefaultCost: # noqa: D
def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> DefaultCost: # noqa: D
return DefaultCost.sum([x.accept(self) for x in node.parent_nodes])

def visit_order_by_limit_node(self, node: OrderByLimitNode[SourceDataSetT]) -> DefaultCost: # noqa: D
def visit_order_by_limit_node(self, node: OrderByLimitNode) -> DefaultCost: # noqa: D
return DefaultCost.sum([x.accept(self) for x in node.parent_nodes])

def visit_where_constraint_node(self, node: WhereConstraintNode[SourceDataSetT]) -> DefaultCost: # noqa: D
def visit_where_constraint_node(self, node: WhereConstraintNode) -> DefaultCost: # noqa: D
return DefaultCost.sum([x.accept(self) for x in node.parent_nodes])

def visit_write_to_result_dataframe_node( # noqa: D
self, node: WriteToResultDataframeNode[SourceDataSetT]
) -> DefaultCost:
def visit_write_to_result_dataframe_node(self, node: WriteToResultDataframeNode) -> DefaultCost: # noqa: D
return DefaultCost.sum([x.accept(self) for x in node.parent_nodes])

def visit_write_to_result_table_node(self, node: WriteToResultTableNode[SourceDataSetT]) -> DefaultCost: # noqa: D
def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> DefaultCost: # noqa: D
return DefaultCost.sum([x.accept(self) for x in node.parent_nodes])

def visit_pass_elements_filter_node(self, node: FilterElementsNode[SourceDataSetT]) -> DefaultCost: # noqa: D
def visit_pass_elements_filter_node(self, node: FilterElementsNode) -> DefaultCost: # noqa: D
return DefaultCost.sum([x.accept(self) for x in node.parent_nodes])

def visit_combine_metrics_node(self, node: CombineMetricsNode[SourceDataSetT]) -> DefaultCost: # noqa: D
def visit_combine_metrics_node(self, node: CombineMetricsNode) -> DefaultCost: # noqa: D
return DefaultCost.sum([x.accept(self) for x in node.parent_nodes])

def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode[SourceDataSetT]) -> DefaultCost: # noqa: D
def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> DefaultCost: # noqa: D
return DefaultCost.sum([x.accept(self) for x in node.parent_nodes])

def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode[SourceDataSetT]) -> DefaultCost: # noqa: D
def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> DefaultCost: # noqa: D
parent_costs = [x.accept(self) for x in node.parent_nodes]

# Add the number of aggregations to the cost (e.g., 1 per unit time)
node_cost = DefaultCost(num_aggregations=1)
return DefaultCost.sum(parent_costs + [node_cost])

def visit_metric_time_dimension_transform_node( # noqa: D
self, node: MetricTimeDimensionTransformNode[SourceDataSetT]
self, node: MetricTimeDimensionTransformNode
) -> DefaultCost:
return DefaultCost.sum([x.accept(self) for x in node.parent_nodes])

def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode[SourceDataSetT]) -> DefaultCost: # noqa: D
def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> DefaultCost: # noqa: D
parent_costs = [x.accept(self) for x in node.parent_nodes]

# Add number of joins to the cost.
node_cost = DefaultCost(num_joins=1)
return DefaultCost.sum(parent_costs + [node_cost])

def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode[SourceDataSetT]) -> DefaultCost: # noqa: D
def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> DefaultCost: # noqa: D
return DefaultCost.sum([x.accept(self) for x in node.parent_nodes] + [DefaultCost(num_joins=1)])
Loading

0 comments on commit 87ee477

Please sign in to comment.