diff --git a/metricflow/dataflow/builder/costing.py b/metricflow/dataflow/builder/costing.py deleted file mode 100644 index b7c47a0faf..0000000000 --- a/metricflow/dataflow/builder/costing.py +++ /dev/null @@ -1,167 +0,0 @@ -"""This module helps to figure out the computational cost for executing a dataflow plan. - -There may be multiple possible dataflow plans to realize a set of measures and dimensions (or rather any set of metric -definition instances) because data sets could include an overlapping set of measures and dimensions. - -Knowing the cost of a dataflow plan can be used to order the possible plans for optimal execution. -""" - - -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Any, Sequence - -from metricflow.dataflow.dataflow_plan import ( - AggregateMeasuresNode, - CombineMetricsNode, - ComputeMetricsNode, - ConstrainTimeRangeNode, - DataflowPlanNode, - DataflowPlanNodeVisitor, - FilterElementsNode, - JoinAggregatedMeasuresByGroupByColumnsNode, - JoinOverTimeRangeNode, - JoinToBaseOutputNode, - JoinToTimeSpineNode, - MetricTimeDimensionTransformNode, - OrderByLimitNode, - ReadSqlSourceNode, - SemiAdditiveJoinNode, - WhereConstraintNode, - WriteToResultDataframeNode, - WriteToResultTableNode, -) - - -class DataflowPlanNodeCost(ABC): - """Represents the cost to compute the data flow up to a given node.""" - - def __lt__(self, other: Any) -> bool: # type: ignore - """Implement < so that lists with this can be sorted.""" - if not isinstance(other, DataflowPlanNodeCost): - return NotImplemented - return self.as_int < other.as_int - - @property - @abstractmethod - def as_int(self) -> int: - """The cost as an integer for ordering.""" - pass - - -@dataclass(frozen=True) -class DefaultCost(DataflowPlanNodeCost): - """Simple cost model where the cost is the number joins * 10 + the number of aggregations.""" - - num_joins: int = 0 - num_aggregations: int = 0 - - @property - def 
as_int(self) -> int: # noqa: D - return self.num_joins * 10 + self.num_aggregations - - @staticmethod - def sum(costs: Sequence[DefaultCost]) -> DefaultCost: # noqa: D - return DefaultCost( - num_joins=sum([x.num_joins for x in costs]), - num_aggregations=sum([x.num_aggregations for x in costs]), - ) - - -class DataflowPlanNodeCostFunction(ABC): - """A function that calculates the cost for computing the dataflow up to a given node.""" - - @abstractmethod - def calculate_cost(self, node: DataflowPlanNode) -> DataflowPlanNodeCost: - """Return the cost for calculating the given dataflow up to the given node.""" - pass - - -class DefaultCostFunction( - DataflowPlanNodeCostFunction, - DataflowPlanNodeVisitor[DefaultCost], -): - """Cost function using the default cost.""" - - def calculate_cost(self, node: DataflowPlanNode) -> DataflowPlanNodeCost: # noqa: D - return node.accept(self) - - def visit_source_node(self, node: ReadSqlSourceNode) -> DefaultCost: # noqa: D - # Base case. - return DefaultCost(num_joins=0, num_aggregations=0) - - def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode) -> DefaultCost: # noqa: D - parent_costs = [x.accept(self) for x in node.parent_nodes] - - # Add number of joins to the cost. 
- node_cost = DefaultCost(num_joins=len(node.join_targets)) - return DefaultCost.sum(parent_costs + [node_cost]) - - def visit_join_aggregated_measures_by_groupby_columns_node( # noqa: D - self, node: JoinAggregatedMeasuresByGroupByColumnsNode - ) -> DefaultCost: - parent_costs = [x.accept(self) for x in node.parent_nodes] - - # This node does N-1 joins to link its N parents together - num_joins = len(node.parent_nodes) - 1 - node_cost = DefaultCost(num_joins=num_joins) - return DefaultCost.sum(parent_costs + [node_cost]) - - def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> DefaultCost: # noqa: D - parent_costs = [x.accept(self) for x in node.parent_nodes] - - # Add the number of aggregations to the cost - node_cost = DefaultCost(num_aggregations=1) - return DefaultCost.sum(parent_costs + [node_cost]) - - def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> DefaultCost: # noqa: D - return DefaultCost.sum([x.accept(self) for x in node.parent_nodes]) - - def visit_order_by_limit_node(self, node: OrderByLimitNode) -> DefaultCost: # noqa: D - return DefaultCost.sum([x.accept(self) for x in node.parent_nodes]) - - def visit_where_constraint_node(self, node: WhereConstraintNode) -> DefaultCost: # noqa: D - return DefaultCost.sum([x.accept(self) for x in node.parent_nodes]) - - def visit_write_to_result_dataframe_node(self, node: WriteToResultDataframeNode) -> DefaultCost: # noqa: D - return DefaultCost.sum([x.accept(self) for x in node.parent_nodes]) - - def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> DefaultCost: # noqa: D - return DefaultCost.sum([x.accept(self) for x in node.parent_nodes]) - - def visit_pass_elements_filter_node(self, node: FilterElementsNode) -> DefaultCost: # noqa: D - parent_costs = [x.accept(self) for x in node.parent_nodes] - - # 1 aggregation if grouping by distinct values - node_cost = DefaultCost(num_aggregations=1 if node.distinct else 0) - return DefaultCost.sum(parent_costs + 
[node_cost]) - - def visit_combine_metrics_node(self, node: CombineMetricsNode) -> DefaultCost: # noqa: D - return DefaultCost.sum([x.accept(self) for x in node.parent_nodes]) - - def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> DefaultCost: # noqa: D - return DefaultCost.sum([x.accept(self) for x in node.parent_nodes]) - - def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> DefaultCost: # noqa: D - parent_costs = [x.accept(self) for x in node.parent_nodes] - - # Add the number of aggregations to the cost (eg 1 per unit time) - node_cost = DefaultCost(num_aggregations=1) - return DefaultCost.sum(parent_costs + [node_cost]) - - def visit_metric_time_dimension_transform_node( # noqa: D - self, node: MetricTimeDimensionTransformNode - ) -> DefaultCost: - return DefaultCost.sum([x.accept(self) for x in node.parent_nodes]) - - def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> DefaultCost: # noqa: D - parent_costs = [x.accept(self) for x in node.parent_nodes] - - # Add number of joins to the cost. 
- node_cost = DefaultCost(num_joins=1) - return DefaultCost.sum(parent_costs + [node_cost]) - - def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> DefaultCost: # noqa: D - return DefaultCost.sum([x.accept(self) for x in node.parent_nodes] + [DefaultCost(num_joins=1)]) diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index bfaa36a3cc..388c8879d4 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -15,7 +15,6 @@ from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow.dag.id_generation import DATAFLOW_PLAN_PREFIX, IdGeneratorRegistry -from metricflow.dataflow.builder.costing import DataflowPlanNodeCostFunction, DefaultCostFunction from metricflow.dataflow.builder.measure_additiveness import group_measure_specs_by_additiveness from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver from metricflow.dataflow.builder.node_evaluator import ( @@ -147,14 +146,12 @@ def __init__( # noqa: D source_nodes: Sequence[BaseOutput], read_nodes: Sequence[ReadSqlSourceNode], semantic_manifest_lookup: SemanticManifestLookup, - cost_function: DataflowPlanNodeCostFunction = DefaultCostFunction(), node_output_resolver: Optional[DataflowPlanNodeOutputDataSetResolver] = None, column_association_resolver: Optional[ColumnAssociationResolver] = None, ) -> None: self._semantic_model_lookup = semantic_manifest_lookup.semantic_model_lookup self._metric_lookup = semantic_manifest_lookup.metric_lookup self._metric_time_dimension_reference = DataSet.metric_time_dimension_reference() - self._cost_function = cost_function self._source_nodes = source_nodes self._read_nodes = read_nodes self._column_association_resolver = ( @@ -391,15 +388,14 @@ def _get_semantic_model_names_for_measures(self, measure_names: Sequence[Measure return semantic_model_names def 
_sort_by_suitability(self, nodes: Sequence[BaseOutput]) -> Sequence[BaseOutput]: - """Sort nodes by the cost, then by the number of linkable specs. + """Sort nodes by the number of linkable specs. - Lower cost nodes will result in faster queries, and the lower the number of linkable specs means less - aggregation required. + A lower number of linkable specs means less aggregation is required. """ - def sort_function(node: BaseOutput) -> Tuple[int, int]: + def sort_function(node: BaseOutput) -> int: data_set = self._node_data_set_resolver.get_output_data_set(node) - return self._cost_function.calculate_cost(node).as_int, len(data_set.instance_set.spec_set.linkable_specs) + return len(data_set.instance_set.spec_set.linkable_specs) return sorted(nodes, key=sort_function) @@ -591,7 +587,7 @@ def _find_dataflow_recipe( logger.info(f"Found {len(node_to_evaluation)} candidate source nodes.") if len(node_to_evaluation) > 0: - # All source nodes cost the same. Find evaluation with lowest number of joins. + # Find the evaluation with the lowest number of joins. 
node_with_lowest_cost_plan = min( node_to_evaluation, key=lambda node: len(node_to_evaluation[node].join_recipes) ) diff --git a/metricflow/test/dataflow/builder/test_costing.py b/metricflow/test/dataflow/builder/test_costing.py deleted file mode 100644 index fa4a1a1621..0000000000 --- a/metricflow/test/dataflow/builder/test_costing.py +++ /dev/null @@ -1,82 +0,0 @@ -from __future__ import annotations - -import logging - -from metricflow.dataflow.builder.costing import DefaultCost, DefaultCostFunction -from metricflow.dataflow.dataflow_plan import ( - AggregateMeasuresNode, - FilterElementsNode, - JoinDescription, - JoinToBaseOutputNode, -) -from metricflow.specs.specs import ( - DimensionSpec, - EntitySpec, - InstanceSpecSet, - LinklessEntitySpec, - MeasureSpec, - MetricInputMeasureSpec, -) -from metricflow.test.fixtures.model_fixtures import ConsistentIdObjectRepository - -logger = logging.getLogger(__name__) - - -def test_costing(consistent_id_object_repository: ConsistentIdObjectRepository) -> None: # noqa: D - bookings_node = consistent_id_object_repository.simple_model_read_nodes["bookings_source"] - listings_node = consistent_id_object_repository.simple_model_read_nodes["listings_latest"] - - bookings_spec = MeasureSpec( - element_name="bookings", - ) - bookings_filtered = FilterElementsNode( - parent_node=bookings_node, - include_specs=InstanceSpecSet( - measure_specs=(bookings_spec,), - entity_specs=( - EntitySpec( - element_name="listing", - entity_links=(), - ), - ), - ), - ) - - listings_filtered = FilterElementsNode( - parent_node=listings_node, - include_specs=InstanceSpecSet( - dimension_specs=( - DimensionSpec( - element_name="country_latest", - entity_links=(), - ), - ), - entity_specs=( - EntitySpec( - element_name="listing", - entity_links=(), - ), - ), - ), - ) - - join_node = JoinToBaseOutputNode( - left_node=bookings_filtered, - join_targets=[ - JoinDescription( - join_node=listings_filtered, - 
join_on_entity=LinklessEntitySpec.from_element_name("listing"), - join_on_partition_dimensions=(), - join_on_partition_time_dimensions=(), - ) - ], - ) - - bookings_aggregated = AggregateMeasuresNode( - parent_node=join_node, metric_input_measure_specs=(MetricInputMeasureSpec(measure_spec=bookings_spec),) - ) - - cost_function = DefaultCostFunction() - cost = cost_function.calculate_cost(bookings_aggregated) - - assert cost == DefaultCost(num_joins=1, num_aggregations=1) diff --git a/metricflow/test/dataflow/builder/test_cyclic_join.py b/metricflow/test/dataflow/builder/test_cyclic_join.py index d35a2b3987..f33be8a8d1 100644 --- a/metricflow/test/dataflow/builder/test_cyclic_join.py +++ b/metricflow/test/dataflow/builder/test_cyclic_join.py @@ -6,7 +6,6 @@ from _pytest.fixtures import FixtureRequest from dbt_semantic_interfaces.references import EntityReference -from metricflow.dataflow.builder.costing import DefaultCostFunction from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder from metricflow.dataflow.dataflow_plan_to_text import dataflow_plan_as_text from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup @@ -36,7 +35,6 @@ def cyclic_join_manifest_dataflow_plan_builder( # noqa: D source_nodes=consistent_id_object_repository.cyclic_join_source_nodes, read_nodes=list(consistent_id_object_repository.cyclic_join_read_nodes.values()), semantic_manifest_lookup=cyclic_join_semantic_manifest_lookup, - cost_function=DefaultCostFunction(), ) diff --git a/metricflow/test/fixtures/dataflow_fixtures.py b/metricflow/test/fixtures/dataflow_fixtures.py index e1af9103d2..d7bf7f89bc 100644 --- a/metricflow/test/fixtures/dataflow_fixtures.py +++ b/metricflow/test/fixtures/dataflow_fixtures.py @@ -2,7 +2,6 @@ import pytest -from metricflow.dataflow.builder.costing import DefaultCostFunction from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder from metricflow.model.semantic_manifest_lookup import 
SemanticManifestLookup from metricflow.plan_conversion.column_resolver import DunderColumnAssociationResolver @@ -36,7 +35,6 @@ def dataflow_plan_builder( # noqa: D source_nodes=consistent_id_object_repository.simple_model_source_nodes, read_nodes=list(consistent_id_object_repository.simple_model_read_nodes.values()), semantic_manifest_lookup=simple_semantic_manifest_lookup, - cost_function=DefaultCostFunction(), ) @@ -50,7 +48,6 @@ def multihop_dataflow_plan_builder( # noqa: D source_nodes=consistent_id_object_repository.multihop_model_source_nodes, read_nodes=list(consistent_id_object_repository.multihop_model_read_nodes.values()), semantic_manifest_lookup=multi_hop_join_semantic_manifest_lookup, - cost_function=DefaultCostFunction(), ) @@ -72,7 +69,6 @@ def scd_dataflow_plan_builder( # noqa: D source_nodes=consistent_id_object_repository.scd_model_source_nodes, read_nodes=list(consistent_id_object_repository.scd_model_read_nodes.values()), semantic_manifest_lookup=scd_semantic_manifest_lookup, - cost_function=DefaultCostFunction(), column_association_resolver=scd_column_association_resolver, )