From 57377a09e86bd80596f7df8db9883dd9c4e7f187 Mon Sep 17 00:00:00 2001
From: tlento
Date: Thu, 7 Sep 2023 15:34:31 -0700
Subject: [PATCH 1/3] Remove generic source dataset constructs

The original implementation of MetricFlow used a generic type annotation
for the "Source DataSet" of the dataflow plan. This was in place to
support some specific implementations of Transform's proprietary product
logic, and as an interface boundary it proved ineffective. Consequently,
attempts to revisit that functionality will rely on other approaches.

Within MetricFlow, this generic type resolves to one of two concrete
types: SqlDataSet, or SemanticModelDataSet, which inherits from
SqlDataSet. Since it's far simpler to use SqlDataSet as the required type
everywhere than to litter the codebase with generic type hints in some
places and concrete type specs in others, we make that change here.
Concretely, signatures that previously looked like

    def calculate_cost(self, node: DataflowPlanNode[SourceDataSetT]) -> DataflowPlanNodeCost:

now read

    def calculate_cost(self, node: DataflowPlanNode) -> DataflowPlanNodeCost:

as in costing.py below.

Sadly, this must all be done in one massive commit: there's no way to get
the typechecker to pass while removing the generic one stage at a time
without doing considerably more restructuring around the data set
handoffs.

This commit was produced via a mixture of mechanical find/replace within
VSCode and file-by-file manual fixups where leftovers were causing
trouble.
---
 metricflow/dataflow/builder/costing.py        |  48 ++-
 .../dataflow/builder/dataflow_plan_builder.py |  96 +++--
 metricflow/dataflow/builder/node_data_set.py  |  10 +-
 metricflow/dataflow/builder/node_evaluator.py |  13 +-
 metricflow/dataflow/builder/source_node.py    |  10 +-
 metricflow/dataflow/dataflow_plan.py          | 340 ++++++++----------
 .../optimizer/dataflow_plan_optimizer.py      |   7 +-
 .../source_scan/cm_branch_combiner.py         |  89 ++---
 .../source_scan/source_scan_optimizer.py      | 122 +++----
 metricflow/engine/metricflow_engine.py        |  14 +-
 .../model/data_warehouse_model_validator.py   |  11 +-
 .../plan_conversion/dataflow_to_execution.py  |  19 +-
 metricflow/plan_conversion/dataflow_to_sql.py |  27 +-
 metricflow/plan_conversion/node_processor.py  |  44 ++-
 .../plan_conversion/sql_join_builder.py       |   6 +-
 metricflow/query/query_parser.py              |   5 +-
 .../test/dataflow/builder/test_costing.py     |  11 +-
 .../test/dataflow/builder/test_cyclic_join.py |   7 +-
 .../builder/test_dataflow_plan_builder.py     |  49 ++-
 .../dataflow/builder/test_node_evaluator.py   |   5 +-
 .../source_scan/test_cm_branch_combiner.py    |   7 +-
 .../source_scan/test_source_scan_optimizer.py |  65 ++--
 metricflow/test/examples/test_node_sql.py     |   7 +-
 metricflow/test/fixtures/dataflow_fixtures.py |  13 +-
 metricflow/test/fixtures/model_fixtures.py    |  24 +-
 metricflow/test/plan_conversion/conftest.py   |   5 +-
 .../test_metric_time_dimension_to_sql.py      |   7 +-
 .../test_dataflow_to_execution.py             |  15 +-
 .../test_dataflow_to_sql_plan.py              | 279 +++++++-------
 29 files changed, 594 insertions(+), 761 deletions(-)

diff --git a/metricflow/dataflow/builder/costing.py b/metricflow/dataflow/builder/costing.py
index 8f2378afdb..981bf406d5 100644
--- a/metricflow/dataflow/builder/costing.py
+++ b/metricflow/dataflow/builder/costing.py
@@ -11,7 +11,7 @@
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Any, Generic, Sequence
+from typing import Any, Sequence
 
 from metricflow.dataflow.dataflow_plan import (
     AggregateMeasuresNode,
@@ -29,7 +29,6 @@
     OrderByLimitNode,
     ReadSqlSourceNode,
     SemiAdditiveJoinNode,
-    SourceDataSetT,
     WhereConstraintNode,
     WriteToResultDataframeNode,
     WriteToResultTableNode,
@@ -71,30 +70,29 @@ def sum(costs: Sequence[DefaultCost]) -> DefaultCost: # noqa: D
         )
 
 
-class 
DataflowPlanNodeCostFunction(Generic[SourceDataSetT], ABC): +class DataflowPlanNodeCostFunction(ABC): """A function that calculates the cost for computing the dataflow up to a given node.""" @abstractmethod - def calculate_cost(self, node: DataflowPlanNode[SourceDataSetT]) -> DataflowPlanNodeCost: + def calculate_cost(self, node: DataflowPlanNode) -> DataflowPlanNodeCost: """Return the cost for calculating the given dataflow up to the given node.""" pass class DefaultCostFunction( - Generic[SourceDataSetT], - DataflowPlanNodeCostFunction[SourceDataSetT], - DataflowPlanNodeVisitor[SourceDataSetT, DefaultCost], + DataflowPlanNodeCostFunction, + DataflowPlanNodeVisitor[DefaultCost], ): """Cost function using the default cost.""" - def calculate_cost(self, node: DataflowPlanNode[SourceDataSetT]) -> DataflowPlanNodeCost: # noqa: D + def calculate_cost(self, node: DataflowPlanNode) -> DataflowPlanNodeCost: # noqa: D return node.accept(self) - def visit_source_node(self, node: ReadSqlSourceNode[SourceDataSetT]) -> DefaultCost: # noqa: D + def visit_source_node(self, node: ReadSqlSourceNode) -> DefaultCost: # noqa: D # Base case. return DefaultCost(num_joins=0, num_aggregations=0) - def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode[SourceDataSetT]) -> DefaultCost: # noqa: D + def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode) -> DefaultCost: # noqa: D parent_costs = [x.accept(self) for x in node.parent_nodes] # Add number of joins to the cost. @@ -102,7 +100,7 @@ def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode[SourceDataSe return DefaultCost.sum(parent_costs + [node_cost]) def visit_join_aggregated_measures_by_groupby_columns_node( # noqa: D - self, node: JoinAggregatedMeasuresByGroupByColumnsNode[SourceDataSetT] + self, node: JoinAggregatedMeasuresByGroupByColumnsNode ) -> DefaultCost: parent_costs = [x.accept(self) for x in node.parent_nodes] @@ -111,40 +109,38 @@ def visit_join_aggregated_measures_by_groupby_columns_node( # noqa: D node_cost = DefaultCost(num_joins=num_joins) return DefaultCost.sum(parent_costs + [node_cost]) - def visit_aggregate_measures_node(self, node: AggregateMeasuresNode[SourceDataSetT]) -> DefaultCost: # noqa: D + def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> DefaultCost: # noqa: D parent_costs = [x.accept(self) for x in node.parent_nodes] # Add the number of aggregations to the cost node_cost = DefaultCost(num_aggregations=1) return DefaultCost.sum(parent_costs + [node_cost]) - def visit_compute_metrics_node(self, node: ComputeMetricsNode[SourceDataSetT]) -> DefaultCost: # noqa: D + def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> DefaultCost: # noqa: D return DefaultCost.sum([x.accept(self) for x in node.parent_nodes]) - def visit_order_by_limit_node(self, node: OrderByLimitNode[SourceDataSetT]) -> DefaultCost: # noqa: D + def visit_order_by_limit_node(self, node: OrderByLimitNode) -> DefaultCost: # noqa: D return DefaultCost.sum([x.accept(self) for x in node.parent_nodes]) - def visit_where_constraint_node(self, node: WhereConstraintNode[SourceDataSetT]) -> DefaultCost: # noqa: D + def visit_where_constraint_node(self, node: WhereConstraintNode) -> DefaultCost: # noqa: D return DefaultCost.sum([x.accept(self) for x in node.parent_nodes]) - def visit_write_to_result_dataframe_node( # noqa: D - self, node: WriteToResultDataframeNode[SourceDataSetT] - ) -> DefaultCost: + def visit_write_to_result_dataframe_node(self, node: WriteToResultDataframeNode) -> DefaultCost: # 
noqa: D return DefaultCost.sum([x.accept(self) for x in node.parent_nodes]) - def visit_write_to_result_table_node(self, node: WriteToResultTableNode[SourceDataSetT]) -> DefaultCost: # noqa: D + def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> DefaultCost: # noqa: D return DefaultCost.sum([x.accept(self) for x in node.parent_nodes]) - def visit_pass_elements_filter_node(self, node: FilterElementsNode[SourceDataSetT]) -> DefaultCost: # noqa: D + def visit_pass_elements_filter_node(self, node: FilterElementsNode) -> DefaultCost: # noqa: D return DefaultCost.sum([x.accept(self) for x in node.parent_nodes]) - def visit_combine_metrics_node(self, node: CombineMetricsNode[SourceDataSetT]) -> DefaultCost: # noqa: D + def visit_combine_metrics_node(self, node: CombineMetricsNode) -> DefaultCost: # noqa: D return DefaultCost.sum([x.accept(self) for x in node.parent_nodes]) - def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode[SourceDataSetT]) -> DefaultCost: # noqa: D + def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> DefaultCost: # noqa: D return DefaultCost.sum([x.accept(self) for x in node.parent_nodes]) - def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode[SourceDataSetT]) -> DefaultCost: # noqa: D + def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> DefaultCost: # noqa: D parent_costs = [x.accept(self) for x in node.parent_nodes] # Add the number of aggregations to the cost (eg 1 per unit time) @@ -152,16 +148,16 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode[SourceData return DefaultCost.sum(parent_costs + [node_cost]) def visit_metric_time_dimension_transform_node( # noqa: D - self, node: MetricTimeDimensionTransformNode[SourceDataSetT] + self, node: MetricTimeDimensionTransformNode ) -> DefaultCost: return DefaultCost.sum([x.accept(self) for x in node.parent_nodes]) - def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode[SourceDataSetT]) -> DefaultCost: # noqa: D + def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> DefaultCost: # noqa: D parent_costs = [x.accept(self) for x in node.parent_nodes] # Add number of joins to the cost. 
node_cost = DefaultCost(num_joins=1) return DefaultCost.sum(parent_costs + [node_cost]) - def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode[SourceDataSetT]) -> DefaultCost: # noqa: D + def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> DefaultCost: # noqa: D return DefaultCost.sum([x.accept(self) for x in node.parent_nodes] + [DefaultCost(num_joins=1)]) diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index 1445c7918f..004283db1e 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -4,7 +4,7 @@ import logging import time from dataclasses import dataclass -from typing import DefaultDict, Dict, Generic, List, Optional, Sequence, Set, Tuple, TypeVar, Union +from typing import DefaultDict, Dict, List, Optional, Sequence, Set, Tuple, Union from dbt_semantic_interfaces.enum_extension import assert_values_exhausted from dbt_semantic_interfaces.pretty_print import pformat_big_objects @@ -51,7 +51,6 @@ from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow.plan_conversion.column_resolver import DunderColumnAssociationResolver from metricflow.plan_conversion.node_processor import PreDimensionJoinNodeProcessor -from metricflow.plan_conversion.sql_dataset import SqlDataSet from metricflow.specs.column_assoc import ColumnAssociationResolver from metricflow.specs.specs import ( DimensionSpec, @@ -73,12 +72,9 @@ logger = logging.getLogger(__name__) -# The type of data set that at the source nodes. -SqlDataSetT = TypeVar("SqlDataSetT", bound=SqlDataSet) - @dataclass(frozen=True) -class MeasureRecipe(Generic[SqlDataSetT]): +class MeasureRecipe: """Get a recipe for how to build a dataflow plan node that outputs measures and the needed linkable instances. The recipe involves filtering the measure node so that it only outputs the measures and the instances associated with @@ -86,7 +82,7 @@ class MeasureRecipe(Generic[SqlDataSetT]): in join_linkable_instances_recipes. """ - measure_node: BaseOutput[SqlDataSetT] + measure_node: BaseOutput required_local_linkable_specs: Tuple[LinkableInstanceSpec, ...] join_linkable_instances_recipes: Tuple[JoinLinkableInstancesRecipe, ...] 
@@ -101,15 +97,15 @@ class MeasureSpecProperties: non_additive_dimension_spec: Optional[NonAdditiveDimensionSpec] = None -class DataflowPlanBuilder(Generic[SqlDataSetT]): +class DataflowPlanBuilder: """Builds a dataflow plan to satisfy a given query.""" def __init__( # noqa: D self, - source_nodes: Sequence[BaseOutput[SqlDataSetT]], + source_nodes: Sequence[BaseOutput], semantic_manifest_lookup: SemanticManifestLookup, - cost_function: DataflowPlanNodeCostFunction = DefaultCostFunction[SqlDataSetT](), - node_output_resolver: Optional[DataflowPlanNodeOutputDataSetResolver[SqlDataSetT]] = None, + cost_function: DataflowPlanNodeCostFunction = DefaultCostFunction(), + node_output_resolver: Optional[DataflowPlanNodeOutputDataSetResolver] = None, column_association_resolver: Optional[ColumnAssociationResolver] = None, ) -> None: self._semantic_model_lookup = semantic_manifest_lookup.semantic_model_lookup @@ -123,7 +119,7 @@ def __init__( # noqa: D else column_association_resolver ) self._node_data_set_resolver = ( - DataflowPlanNodeOutputDataSetResolver[SqlDataSetT]( + DataflowPlanNodeOutputDataSetResolver( column_association_resolver=( DunderColumnAssociationResolver(semantic_manifest_lookup) if not column_association_resolver @@ -140,8 +136,8 @@ def build_plan( query_spec: MetricFlowQuerySpec, output_sql_table: Optional[SqlTable] = None, output_selection_specs: Optional[InstanceSpecSet] = None, - optimizers: Sequence[DataflowPlanOptimizer[SqlDataSetT]] = (), - ) -> DataflowPlan[SqlDataSetT]: + optimizers: Sequence[DataflowPlanOptimizer] = (), + ) -> DataflowPlan: """Generate a plan for reading the results of a query with the given spec into a dataframe or table.""" metrics_output_node = self._build_metrics_output_node( metric_specs=query_spec.metric_specs, @@ -177,7 +173,7 @@ def _build_metrics_output_node( where_constraint: Optional[WhereFilterSpec] = None, time_range_constraint: Optional[TimeRangeConstraint] = None, combine_metrics_join_type: SqlJoinType = SqlJoinType.FULL_OUTER, - ) -> BaseOutput[SqlDataSetT]: + ) -> BaseOutput: """Builds a computed metrics output node. Args: @@ -187,8 +183,8 @@ def _build_metrics_output_node( time_range_constraint: Time range constraint used to compute the metric. combine_metrics_join_type: The join used when combining the computed metrics. """ - output_nodes: List[BaseOutput[SqlDataSetT]] = [] - compute_metrics_node: Optional[ComputeMetricsNode[SqlDataSetT]] = None + output_nodes: List[BaseOutput] = [] + compute_metrics_node: Optional[ComputeMetricsNode] = None for metric_spec in metric_specs: logger.info(f"Generating compute metrics node for {metric_spec}") @@ -205,7 +201,7 @@ def _build_metrics_output_node( f"{pformat_big_objects(metric_input_specs=metric_input_specs)}" ) - compute_metrics_node = ComputeMetricsNode[SqlDataSetT]( + compute_metrics_node = ComputeMetricsNode( parent_node=self._build_metrics_output_node( metric_specs=metric_input_specs, queried_linkable_specs=queried_linkable_specs, @@ -258,7 +254,7 @@ def _build_metrics_output_node( if len(output_nodes) == 1: return output_nodes[0] - return CombineMetricsNode[SqlDataSetT]( + return CombineMetricsNode( parent_nodes=output_nodes, join_type=combine_metrics_join_type, ) @@ -271,7 +267,7 @@ def build_plan_for_distinct_values( entity_spec: Optional[EntitySpec] = None, time_range_constraint: Optional[TimeRangeConstraint] = None, limit: Optional[int] = None, - ) -> DataflowPlan[SqlDataSetT]: + ) -> DataflowPlan: """Generate a plan that would get the distinct values of a linkable instance. e.g. 
distinct listing__country_latest for bookings by listing__country_latest @@ -325,17 +321,17 @@ def build_plan_for_distinct_values( @staticmethod def build_sink_node_from_metrics_output_node( - computed_metrics_output: BaseOutput[SqlDataSetT], + computed_metrics_output: BaseOutput, order_by_specs: Sequence[OrderBySpec], output_sql_table: Optional[SqlTable] = None, limit: Optional[int] = None, output_selection_specs: Optional[InstanceSpecSet] = None, - ) -> SinkOutput[SqlDataSetT]: + ) -> SinkOutput: """Adds order by / limit / write nodes.""" - pre_result_node: Optional[BaseOutput[SqlDataSetT]] = None + pre_result_node: Optional[BaseOutput] = None if order_by_specs or limit: - pre_result_node = OrderByLimitNode[SqlDataSetT]( + pre_result_node = OrderByLimitNode( order_by_specs=list(order_by_specs), limit=limit, parent_node=computed_metrics_output, @@ -347,13 +343,13 @@ def build_sink_node_from_metrics_output_node( include_specs=output_selection_specs, ) - write_result_node: SinkOutput[SqlDataSetT] + write_result_node: SinkOutput if not output_sql_table: - write_result_node = WriteToResultDataframeNode[SqlDataSetT]( + write_result_node = WriteToResultDataframeNode( parent_node=pre_result_node or computed_metrics_output, ) else: - write_result_node = WriteToResultTableNode[SqlDataSetT]( + write_result_node = WriteToResultTableNode( parent_node=pre_result_node or computed_metrics_output, output_sql_table=output_sql_table, ) @@ -377,22 +373,22 @@ def _get_semantic_model_names_for_measures(self, measure_names: Sequence[Measure ) return semantic_model_names - def _sort_by_suitability(self, nodes: Sequence[BaseOutput[SqlDataSetT]]) -> Sequence[BaseOutput[SqlDataSetT]]: + def _sort_by_suitability(self, nodes: Sequence[BaseOutput]) -> Sequence[BaseOutput]: """Sort nodes by the cost, then by the number of linkable specs. Lower cost nodes will result in faster queries, and the lower the number of linkable specs means less aggregation required. 
""" - def sort_function(node: BaseOutput[SqlDataSetT]) -> Tuple[int, int]: + def sort_function(node: BaseOutput) -> Tuple[int, int]: data_set = self._node_data_set_resolver.get_output_data_set(node) return self._cost_function.calculate_cost(node).as_int, len(data_set.instance_set.spec_set.linkable_specs) return sorted(nodes, key=sort_function) def _select_source_nodes_with_measures( - self, measure_specs: Set[MeasureSpec], source_nodes: Sequence[BaseOutput[SqlDataSetT]] - ) -> Sequence[BaseOutput[SqlDataSetT]]: + self, measure_specs: Set[MeasureSpec], source_nodes: Sequence[BaseOutput] + ) -> Sequence[BaseOutput]: nodes = [] measure_specs_set = set(measure_specs) for source_node in source_nodes: @@ -472,10 +468,10 @@ def _find_measure_recipe( node_data_set_resolver=self._node_data_set_resolver, ) - source_nodes: Sequence[BaseOutput[SqlDataSetT]] = self._source_nodes + source_nodes: Sequence[BaseOutput] = self._source_nodes # We only care about nodes that have all required measures - potential_measure_nodes: Sequence[BaseOutput[SqlDataSetT]] = self._select_source_nodes_with_measures( + potential_measure_nodes: Sequence[BaseOutput] = self._select_source_nodes_with_measures( measure_specs=set(measure_specs), source_nodes=source_nodes ) @@ -563,7 +559,7 @@ def _find_measure_recipe( logger.info(f"Found {len(node_to_evaluation)} candidate measure nodes.") if len(node_to_evaluation) > 0: - cost_function = DefaultCostFunction[SqlDataSetT]() + cost_function = DefaultCostFunction() node_with_lowest_cost = min(node_to_evaluation, key=cost_function.calculate_cost) evaluation = node_to_evaluation[node_with_lowest_cost] @@ -606,10 +602,10 @@ def _find_measure_recipe( def build_computed_metrics_node( self, metric_spec: MetricSpec, - aggregated_measures_node: Union[AggregateMeasuresNode[SqlDataSetT], BaseOutput[SqlDataSetT]], - ) -> ComputeMetricsNode[SqlDataSetT]: + aggregated_measures_node: Union[AggregateMeasuresNode, BaseOutput], + ) -> ComputeMetricsNode: """Builds a ComputeMetricsNode from aggregated measures.""" - return ComputeMetricsNode[SqlDataSetT]( + return ComputeMetricsNode( parent_node=aggregated_measures_node, metric_specs=[metric_spec], ) @@ -624,14 +620,14 @@ def build_aggregated_measures( cumulative: Optional[bool] = False, cumulative_window: Optional[MetricTimeWindow] = None, cumulative_grain_to_date: Optional[TimeGranularity] = None, - ) -> BaseOutput[SqlDataSetT]: + ) -> BaseOutput: """Returns a node where the measures are aggregated by the linkable specs and constrained appropriately. This might be a node representing a single aggregation over one semantic model, or a node representing a composite set of aggregations originating from multiple semantic models, and joined into a single aggregated set of measures. 
""" - output_nodes: List[BaseOutput[SqlDataSetT]] = [] + output_nodes: List[BaseOutput] = [] semantic_models_and_constraints_to_measures: DefaultDict[ tuple[str, Optional[WhereFilterSpec]], List[MetricInputMeasureSpec] ] = collections.defaultdict(list) @@ -713,7 +709,7 @@ def _build_aggregated_measures_from_measure_source_node( cumulative: Optional[bool] = False, cumulative_window: Optional[MetricTimeWindow] = None, cumulative_grain_to_date: Optional[TimeGranularity] = None, - ) -> BaseOutput[SqlDataSetT]: + ) -> BaseOutput: metric_time_dimension_specs = [ time_dimension_spec for time_dimension_spec in queried_linkable_specs.time_dimension_specs @@ -780,7 +776,7 @@ def _build_aggregated_measures_from_measure_source_node( # If a cumulative metric is queried with metric_time, join over time range. # Otherwise, the measure will be aggregated over all time. - time_range_node: Optional[JoinOverTimeRangeNode[SqlDataSetT]] = None + time_range_node: Optional[JoinOverTimeRangeNode] = None if cumulative and metric_time_dimension_requested: time_range_node = JoinOverTimeRangeNode( parent_node=measure_recipe.measure_node, @@ -802,7 +798,7 @@ def _build_aggregated_measures_from_measure_source_node( ) # Only get the required measure and the local linkable instances so that aggregations work correctly. - filtered_measure_source_node = FilterElementsNode[SqlDataSetT]( + filtered_measure_source_node = FilterElementsNode( parent_node=join_to_time_spine_node or time_range_node or measure_recipe.measure_node, include_specs=InstanceSpecSet.merge( ( @@ -841,7 +837,7 @@ def _build_aggregated_measures_from_measure_source_node( # e.g. if the node is used to satisfy "user_id__country", then the node must have the entity # "user_id" and the "country" dimension so that it can be joined to the measure node. 
include_specs.extend([x.without_first_entity_link for x in join_recipe.satisfiable_linkable_specs]) - filtered_node_to_join = FilterElementsNode[SqlDataSetT]( + filtered_node_to_join = FilterElementsNode( parent_node=join_recipe.node_to_join, include_specs=InstanceSpecSet.create_from_linkable_specs(include_specs), ) @@ -855,9 +851,9 @@ def _build_aggregated_measures_from_measure_source_node( ) ) - unaggregated_measure_node: BaseOutput[SqlDataSetT] + unaggregated_measure_node: BaseOutput if len(join_targets) > 0: - filtered_measures_with_joined_elements = JoinToBaseOutputNode[SqlDataSetT]( + filtered_measures_with_joined_elements = JoinToBaseOutputNode( left_node=filtered_measure_source_node, join_targets=join_targets, ) @@ -869,7 +865,7 @@ def _build_aggregated_measures_from_measure_source_node( ) ) - after_join_filtered_node = FilterElementsNode[SqlDataSetT]( + after_join_filtered_node = FilterElementsNode( parent_node=filtered_measures_with_joined_elements, include_specs=specs_to_keep_after_join ) unaggregated_measure_node = after_join_filtered_node @@ -886,7 +882,7 @@ def _build_aggregated_measures_from_measure_source_node( unaggregated_measure_node, time_range_constraint ) - pre_aggregate_node: BaseOutput[SqlDataSetT] = cumulative_metric_constrained_node or unaggregated_measure_node + pre_aggregate_node: BaseOutput = cumulative_metric_constrained_node or unaggregated_measure_node if where_constraint: # Apply where constraint on the node pre_aggregate_node = WhereConstraintNode( @@ -908,7 +904,7 @@ def _build_aggregated_measures_from_measure_source_node( window_groupings = tuple( LinklessEntitySpec.from_element_name(name) for name in non_additive_dimension_spec.window_groupings ) - pre_aggregate_node = SemiAdditiveJoinNode[SqlDataSetT]( + pre_aggregate_node = SemiAdditiveJoinNode( parent_node=pre_aggregate_node, entity_specs=window_groupings, time_dimension_spec=time_dimension_spec, @@ -922,13 +918,13 @@ def _build_aggregated_measures_from_measure_source_node( # show up in the final result. # # e.g. for "bookings" by "ds" where "is_instant", "is_instant" should not be in the results. - pre_aggregate_node = FilterElementsNode[SqlDataSetT]( + pre_aggregate_node = FilterElementsNode( parent_node=pre_aggregate_node, include_specs=InstanceSpecSet.merge( (InstanceSpecSet(measure_specs=measure_specs), queried_linkable_specs.as_spec_set) ), ) - return AggregateMeasuresNode[SqlDataSetT]( + return AggregateMeasuresNode( parent_node=pre_aggregate_node, metric_input_measure_specs=tuple(metric_input_measure_specs), ) diff --git a/metricflow/dataflow/builder/node_data_set.py b/metricflow/dataflow/builder/node_data_set.py index 4c8343b385..2f1b5b5b5c 100644 --- a/metricflow/dataflow/builder/node_data_set.py +++ b/metricflow/dataflow/builder/node_data_set.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, Generic, TypeVar +from typing import Dict from metricflow.dataflow.dataflow_plan import ( DataflowPlanNode, @@ -10,10 +10,8 @@ from metricflow.plan_conversion.sql_dataset import SqlDataSet from metricflow.specs.column_assoc import ColumnAssociationResolver -SourceDataSetT = TypeVar("SourceDataSetT", bound=SqlDataSet) - -class DataflowPlanNodeOutputDataSetResolver(Generic[SourceDataSetT], DataflowToSqlQueryPlanConverter[SourceDataSetT]): +class DataflowPlanNodeOutputDataSetResolver(DataflowToSqlQueryPlanConverter): """Given a node in a dataflow plan, figure out what is the data set output by that node. 
Recall that in the dataflow plan, the nodes represent computation, and the inputs and outputs of the nodes are @@ -61,13 +59,13 @@ def __init__( # noqa: D column_association_resolver: ColumnAssociationResolver, semantic_manifest_lookup: SemanticManifestLookup, ) -> None: - self._node_to_output_data_set: Dict[DataflowPlanNode[SourceDataSetT], SqlDataSet] = {} + self._node_to_output_data_set: Dict[DataflowPlanNode, SqlDataSet] = {} super().__init__( column_association_resolver=column_association_resolver, semantic_manifest_lookup=semantic_manifest_lookup, ) - def get_output_data_set(self, node: DataflowPlanNode[SourceDataSetT]) -> SqlDataSet: # noqa: D + def get_output_data_set(self, node: DataflowPlanNode) -> SqlDataSet: # noqa: D """Cached since this will be called repeatedly during the computation of multiple metrics.""" if node not in self._node_to_output_data_set: self._node_to_output_data_set[node] = node.accept(self) diff --git a/metricflow/dataflow/builder/node_evaluator.py b/metricflow/dataflow/builder/node_evaluator.py index 9e9e23c80c..8e09e30c95 100644 --- a/metricflow/dataflow/builder/node_evaluator.py +++ b/metricflow/dataflow/builder/node_evaluator.py @@ -19,7 +19,7 @@ import itertools import logging from dataclasses import dataclass -from typing import Generic, List, Optional, Sequence, Tuple, TypeVar +from typing import List, Optional, Sequence, Tuple from dbt_semantic_interfaces.pretty_print import pformat_big_objects @@ -95,10 +95,7 @@ class LinkableInstanceSatisfiabilityEvaluation: unjoinable_linkable_specs: Tuple[LinkableInstanceSpec, ...] -SourceDataSetT = TypeVar("SourceDataSetT", bound=SqlDataSet) - - -class NodeEvaluatorForLinkableInstances(Generic[SourceDataSetT]): +class NodeEvaluatorForLinkableInstances: """Helps to evaluate if linkable instances can be obtained using the given node, with joins if necessary. For example, consider a "start_node" containing the "bookings" measure, "is_instant" dimension, and "listing_id" @@ -115,8 +112,8 @@ class NodeEvaluatorForLinkableInstances(Generic[SourceDataSetT]): def __init__( self, semantic_model_lookup: SemanticModelAccessor, - nodes_available_for_joins: Sequence[BaseOutput[SourceDataSetT]], - node_data_set_resolver: DataflowPlanNodeOutputDataSetResolver[SourceDataSetT], + nodes_available_for_joins: Sequence[BaseOutput], + node_data_set_resolver: DataflowPlanNodeOutputDataSetResolver, ) -> None: """Constructor. @@ -307,7 +304,7 @@ def _update_candidates_that_can_satisfy_linkable_specs( def evaluate_node( self, - start_node: BaseOutput[SourceDataSetT], + start_node: BaseOutput, required_linkable_specs: Sequence[LinkableInstanceSpec], ) -> LinkableInstanceSatisfiabilityEvaluation: """Evaluates if the "required_linkable_specs" can be realized by joining the "start_node" with other nodes. 
diff --git a/metricflow/dataflow/builder/source_node.py b/metricflow/dataflow/builder/source_node.py index bfe013fa06..e2f2da7790 100644 --- a/metricflow/dataflow/builder/source_node.py +++ b/metricflow/dataflow/builder/source_node.py @@ -21,13 +21,11 @@ class SourceNodeBuilder: def __init__(self, semantic_manifest_lookup: SemanticManifestLookup) -> None: # noqa: D self._semantic_manifest_lookup = semantic_manifest_lookup - def create_from_data_sets( - self, data_sets: Sequence[SemanticModelDataSet] - ) -> Sequence[BaseOutput[SemanticModelDataSet]]: + def create_from_data_sets(self, data_sets: Sequence[SemanticModelDataSet]) -> Sequence[BaseOutput]: """Creates source nodes from SemanticModelDataSets.""" - source_nodes: List[BaseOutput[SemanticModelDataSet]] = [] + source_nodes: List[BaseOutput] = [] for data_set in data_sets: - read_node = ReadSqlSourceNode[SemanticModelDataSet](data_set) + read_node = ReadSqlSourceNode(data_set) agg_time_dim_to_measures_grouper = ( self._semantic_manifest_lookup.semantic_model_lookup.get_aggregation_time_dimensions_with_measures( data_set.semantic_model_reference @@ -42,7 +40,7 @@ def create_from_data_sets( # Splits the measures by distinct aggregate time dimension. for time_dimension_reference in time_dimension_references: source_nodes.append( - MetricTimeDimensionTransformNode[SemanticModelDataSet]( + MetricTimeDimensionTransformNode( parent_node=read_node, aggregation_time_dimension_reference=time_dimension_reference, ) diff --git a/metricflow/dataflow/dataflow_plan.py b/metricflow/dataflow/dataflow_plan.py index 71dd1dbe81..a7dde8ac62 100644 --- a/metricflow/dataflow/dataflow_plan.py +++ b/metricflow/dataflow/dataflow_plan.py @@ -38,8 +38,8 @@ PartitionTimeDimensionJoinDescription, ) from metricflow.dataflow.sql_table import SqlTable -from metricflow.dataset.dataset import DataSet from metricflow.filters.time_constraint import TimeRangeConstraint +from metricflow.plan_conversion.sql_dataset import SqlDataSet from metricflow.specs.specs import ( InstanceSpecSet, LinklessEntitySpec, @@ -54,12 +54,10 @@ logger = logging.getLogger(__name__) -# The type of data set that is flowing out of the source nodes -SourceDataSetT = TypeVar("SourceDataSetT", bound=DataSet) NodeSelfT = TypeVar("NodeSelfT", bound="DataflowPlanNode") -class DataflowPlanNode(Generic[SourceDataSetT], DagNode, Visitable, ABC): +class DataflowPlanNode(DagNode, Visitable, ABC): """A node in the graph representation of the dataflow. Each node in the graph performs an operation from the data that comes from the parent nodes, and the result is @@ -77,17 +75,17 @@ def __init__(self, node_id: NodeId, parent_nodes: List[DataflowPlanNode]) -> Non super().__init__(node_id=node_id) @property - def parent_nodes(self) -> Sequence[DataflowPlanNode[SourceDataSetT]]: + def parent_nodes(self) -> Sequence[DataflowPlanNode]: """Return the nodes where data for this node comes from.""" return self._parent_nodes @abstractmethod - def accept(self, visitor: DataflowPlanNodeVisitor[SourceDataSetT, VisitorOutputT]) -> VisitorOutputT: + def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: """Called when a visitor needs to visit this node.""" raise NotImplementedError @abstractmethod - def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) -> bool: + def functionally_identical(self, other_node: DataflowPlanNode) -> bool: """Returns true if this node the same functionality as the other node. 
In other words, this returns true if all parameters (aside from parent_nodes) are the same. @@ -95,7 +93,7 @@ def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) - raise NotImplementedError @abstractmethod - def with_new_parents(self: NodeSelfT, new_parent_nodes: Sequence[BaseOutput[SourceDataSetT]]) -> NodeSelfT: + def with_new_parents(self: NodeSelfT, new_parent_nodes: Sequence[BaseOutput]) -> NodeSelfT: """Creates a node with the same behavior as this node, but with a different set of parents. typing.Self would be useful here, but not available in Python 3.8. @@ -108,7 +106,7 @@ def node_type(self) -> Type: # noqa: D return self.__class__ -class DataflowPlanNodeVisitor(Generic[SourceDataSetT, VisitorOutputT], ABC): +class DataflowPlanNodeVisitor(Generic[VisitorOutputT], ABC): """An object that can be used to visit the nodes of a dataflow plan. Follows the visitor pattern: https://en.wikipedia.org/wiki/Visitor_pattern @@ -117,81 +115,75 @@ class DataflowPlanNodeVisitor(Generic[SourceDataSetT, VisitorOutputT], ABC): """ @abstractmethod - def visit_source_node(self, node: ReadSqlSourceNode[SourceDataSetT]) -> VisitorOutputT: # noqa: D + def visit_source_node(self, node: ReadSqlSourceNode) -> VisitorOutputT: # noqa: D pass @abstractmethod - def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode[SourceDataSetT]) -> VisitorOutputT: # noqa: D + def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode) -> VisitorOutputT: # noqa: D pass @abstractmethod def visit_join_aggregated_measures_by_groupby_columns_node( # noqa: D - self, node: JoinAggregatedMeasuresByGroupByColumnsNode[SourceDataSetT] + self, node: JoinAggregatedMeasuresByGroupByColumnsNode ) -> VisitorOutputT: pass @abstractmethod - def visit_aggregate_measures_node(self, node: AggregateMeasuresNode[SourceDataSetT]) -> VisitorOutputT: # noqa: D + def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> VisitorOutputT: # noqa: D pass @abstractmethod - def visit_compute_metrics_node(self, node: ComputeMetricsNode[SourceDataSetT]) -> VisitorOutputT: # noqa: D + def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> VisitorOutputT: # noqa: D pass @abstractmethod - def visit_order_by_limit_node(self, node: OrderByLimitNode[SourceDataSetT]) -> VisitorOutputT: # noqa: D + def visit_order_by_limit_node(self, node: OrderByLimitNode) -> VisitorOutputT: # noqa: D pass @abstractmethod - def visit_where_constraint_node(self, node: WhereConstraintNode[SourceDataSetT]) -> VisitorOutputT: # noqa: D + def visit_where_constraint_node(self, node: WhereConstraintNode) -> VisitorOutputT: # noqa: D pass @abstractmethod - def visit_write_to_result_dataframe_node( # noqa: D - self, node: WriteToResultDataframeNode[SourceDataSetT] - ) -> VisitorOutputT: + def visit_write_to_result_dataframe_node(self, node: WriteToResultDataframeNode) -> VisitorOutputT: # noqa: D pass @abstractmethod - def visit_write_to_result_table_node( # noqa: D - self, node: WriteToResultTableNode[SourceDataSetT] - ) -> VisitorOutputT: + def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> VisitorOutputT: # noqa: D pass @abstractmethod - def visit_pass_elements_filter_node(self, node: FilterElementsNode[SourceDataSetT]) -> VisitorOutputT: # noqa: D + def visit_pass_elements_filter_node(self, node: FilterElementsNode) -> VisitorOutputT: # noqa: D pass @abstractmethod - def visit_combine_metrics_node(self, node: CombineMetricsNode[SourceDataSetT]) -> VisitorOutputT: # noqa: D + def 
visit_combine_metrics_node(self, node: CombineMetricsNode) -> VisitorOutputT: # noqa: D pass @abstractmethod - def visit_constrain_time_range_node( # noqa: D - self, node: ConstrainTimeRangeNode[SourceDataSetT] - ) -> VisitorOutputT: + def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> VisitorOutputT: # noqa: D pass @abstractmethod - def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode[SourceDataSetT]) -> VisitorOutputT: # noqa: D + def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> VisitorOutputT: # noqa: D pass @abstractmethod - def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode[SourceDataSetT]) -> VisitorOutputT: # noqa: D + def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> VisitorOutputT: # noqa: D pass @abstractmethod def visit_metric_time_dimension_transform_node( # noqa: D - self, node: MetricTimeDimensionTransformNode[SourceDataSetT] + self, node: MetricTimeDimensionTransformNode ) -> VisitorOutputT: pass @abstractmethod - def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode[SourceDataSetT]) -> VisitorOutputT: # noqa: D + def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> VisitorOutputT: # noqa: D pass -class BaseOutput(Generic[SourceDataSetT], DataflowPlanNode[SourceDataSetT], ABC): +class BaseOutput(DataflowPlanNode, ABC): """A node that outputs data in a "base" format. The base format is where the columns represent un-aggregated measures, dimensions, and entities. @@ -200,10 +192,10 @@ class BaseOutput(Generic[SourceDataSetT], DataflowPlanNode[SourceDataSetT], ABC) pass -class ReadSqlSourceNode(Generic[SourceDataSetT], BaseOutput[SourceDataSetT]): +class ReadSqlSourceNode(BaseOutput): """A source node where data from an SQL table or SQL query is read and output.""" - def __init__(self, data_set: SourceDataSetT) -> None: + def __init__(self, data_set: SqlDataSet) -> None: """Constructor. 
Args: @@ -216,11 +208,11 @@ def __init__(self, data_set: SourceDataSetT) -> None: def id_prefix(cls) -> str: # noqa: D return DATAFLOW_NODE_READ_SQL_SOURCE_ID_PREFIX - def accept(self, visitor: DataflowPlanNodeVisitor[SourceDataSetT, VisitorOutputT]) -> VisitorOutputT: # noqa: D + def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D return visitor.visit_source_node(self) @property - def data_set(self) -> SourceDataSetT: + def data_set(self) -> SqlDataSet: """Return the data set that this source represents and is passed to the child nodes.""" return self._dataset @@ -243,14 +235,12 @@ def displayed_properties(self) -> List[DisplayedProperty]: # noqa: D DisplayedProperty("data_set", self.data_set), ] - def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) -> bool: # noqa: D + def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D return isinstance(other_node, self.__class__) and other_node.data_set == self.data_set - def with_new_parents( # noqa: D - self, new_parent_nodes: Sequence[BaseOutput[SourceDataSetT]] - ) -> ReadSqlSourceNode[SourceDataSetT]: + def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> ReadSqlSourceNode: # noqa: D assert len(new_parent_nodes) == 0 - return ReadSqlSourceNode[SourceDataSetT](data_set=self.data_set) + return ReadSqlSourceNode(data_set=self.data_set) @dataclass(frozen=True) @@ -262,10 +252,10 @@ class ValidityWindowJoinDescription: @dataclass(frozen=True) -class JoinDescription(Generic[SourceDataSetT]): +class JoinDescription: """Describes how data from a node should be joined to data from another node.""" - join_node: BaseOutput[SourceDataSetT] + join_node: BaseOutput join_on_entity: LinklessEntitySpec join_on_partition_dimensions: Tuple[PartitionDimensionJoinDescription, ...] @@ -274,13 +264,13 @@ class JoinDescription(Generic[SourceDataSetT]): validity_window: Optional[ValidityWindowJoinDescription] = None -class JoinToBaseOutputNode(Generic[SourceDataSetT], BaseOutput[SourceDataSetT]): +class JoinToBaseOutputNode(BaseOutput): """A node that joins data from other nodes to a standard output node, one by one via entity.""" def __init__( self, - left_node: BaseOutput[SourceDataSetT], - join_targets: List[JoinDescription[SourceDataSetT]], + left_node: BaseOutput, + join_targets: List[JoinDescription], node_id: Optional[NodeId] = None, ) -> None: """Constructor. @@ -294,7 +284,7 @@ def __init__( self._join_targets = join_targets # Doing a list comprehension throws a type error, so doing it this way. 
- parent_nodes: List[DataflowPlanNode[SourceDataSetT]] = [self._left_node] + parent_nodes: List[DataflowPlanNode] = [self._left_node] for join_target in self._join_targets: parent_nodes.append(join_target.join_node) super().__init__(node_id=node_id or self.create_unique_id(), parent_nodes=parent_nodes) @@ -303,7 +293,7 @@ def __init__( def id_prefix(cls) -> str: # noqa: D return DATAFLOW_NODE_JOIN_TO_STANDARD_OUTPUT_ID_PREFIX - def accept(self, visitor: DataflowPlanNodeVisitor[SourceDataSetT, VisitorOutputT]) -> VisitorOutputT: # noqa: D + def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D return visitor.visit_join_to_base_output_node(self) @property @@ -311,7 +301,7 @@ def description(self) -> str: # noqa: D return """Join Standard Outputs""" @property - def left_node(self) -> BaseOutput[SourceDataSetT]: # noqa: D + def left_node(self) -> BaseOutput: # noqa: D return self._left_node @property @@ -325,7 +315,7 @@ def displayed_properties(self) -> List[DisplayedProperty]: # noqa: D for i, join_description in enumerate(self._join_targets) ] - def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) -> bool: # noqa: D + def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D if not isinstance(other_node, self.__class__) or len(self.join_targets) != len(other_node.join_targets): return False @@ -341,15 +331,13 @@ def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) - return False return True - def with_new_parents( # noqa: D - self, new_parent_nodes: Sequence[BaseOutput[SourceDataSetT]] - ) -> JoinToBaseOutputNode[SourceDataSetT]: + def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> JoinToBaseOutputNode: # noqa: D assert len(new_parent_nodes) > 1 new_left_node = new_parent_nodes[0] new_join_nodes = new_parent_nodes[1:] assert len(new_join_nodes) == len(self._join_targets) - return JoinToBaseOutputNode[SourceDataSetT]( + return JoinToBaseOutputNode( left_node=new_left_node, join_targets=[ JoinDescription( @@ -364,12 +352,12 @@ def with_new_parents( # noqa: D ) -class JoinOverTimeRangeNode(Generic[SourceDataSetT], BaseOutput[SourceDataSetT]): +class JoinOverTimeRangeNode(BaseOutput): """A node that allows for cumulative metric computation by doing a self join across a cumulative date range.""" def __init__( self, - parent_node: BaseOutput[SourceDataSetT], + parent_node: BaseOutput, window: Optional[MetricTimeWindow], grain_to_date: Optional[TimeGranularity], node_id: Optional[NodeId] = None, @@ -396,14 +384,14 @@ def __init__( self.time_range_constraint = time_range_constraint # Doing a list comprehension throws a type error, so doing it this way. 
- parent_nodes: List[DataflowPlanNode[SourceDataSetT]] = [self._parent_node] + parent_nodes: List[DataflowPlanNode] = [self._parent_node] super().__init__(node_id=node_id or self.create_unique_id(), parent_nodes=parent_nodes) @classmethod def id_prefix(cls) -> str: # noqa: D return DATAFLOW_NODE_JOIN_SELF_OVER_TIME_RANGE_ID_PREFIX - def accept(self, visitor: DataflowPlanNodeVisitor[SourceDataSetT, VisitorOutputT]) -> VisitorOutputT: # noqa: D + def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D return visitor.visit_join_over_time_range_node(self) @property @@ -415,7 +403,7 @@ def description(self) -> str: # noqa: D return """Join Self Over Time Range""" @property - def parent_node(self) -> BaseOutput[SourceDataSetT]: # noqa: D + def parent_node(self) -> BaseOutput: # noqa: D return self._parent_node @property @@ -426,7 +414,7 @@ def window(self) -> Optional[MetricTimeWindow]: # noqa: D def displayed_properties(self) -> List[DisplayedProperty]: # noqa: D return super().displayed_properties - def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) -> bool: # noqa: D + def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D return ( isinstance(other_node, self.__class__) and other_node.grain_to_date == self.grain_to_date @@ -434,11 +422,9 @@ def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) - and other_node.time_range_constraint == self.time_range_constraint ) - def with_new_parents( # noqa: D - self, new_parent_nodes: Sequence[BaseOutput[SourceDataSetT]] - ) -> JoinOverTimeRangeNode[SourceDataSetT]: + def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> JoinOverTimeRangeNode: # noqa: D assert len(new_parent_nodes) == 1 - return JoinOverTimeRangeNode[SourceDataSetT]( + return JoinOverTimeRangeNode( parent_node=new_parent_nodes[0], window=self.window, grain_to_date=self.grain_to_date, @@ -446,7 +432,7 @@ def with_new_parents( # noqa: D ) -class AggregatedMeasuresOutput(Generic[SourceDataSetT], BaseOutput[SourceDataSetT], ABC): +class AggregatedMeasuresOutput(BaseOutput, ABC): """A node that outputs data where the measures are aggregated. The measures are aggregated with respect to the present entities and dimensions. @@ -455,7 +441,7 @@ class AggregatedMeasuresOutput(Generic[SourceDataSetT], BaseOutput[SourceDataSet pass -class AggregateMeasuresNode(Generic[SourceDataSetT], AggregatedMeasuresOutput[SourceDataSetT]): +class AggregateMeasuresNode(AggregatedMeasuresOutput): """A node that aggregates the measures by the associated group by elements. 
In the event that one or more of the aggregated input measures has an alias assigned to it, any output query @@ -480,7 +466,7 @@ def __init__(self, parent_node: BaseOutput, metric_input_measure_specs: Tuple[Me def id_prefix(cls) -> str: # noqa: D return DATAFLOW_NODE_AGGREGATE_MEASURES_ID_PREFIX - def accept(self, visitor: DataflowPlanNodeVisitor[SourceDataSetT, VisitorOutputT]) -> VisitorOutputT: # noqa: D + def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D return visitor.visit_aggregate_measures_node(self) @property @@ -499,23 +485,21 @@ def metric_input_measure_specs(self) -> Tuple[MetricInputMeasureSpec, ...]: """ return self._metric_input_measure_specs - def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) -> bool: # noqa: D + def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D return ( isinstance(other_node, self.__class__) and other_node.metric_input_measure_specs == self.metric_input_measure_specs ) - def with_new_parents( # noqa: D - self, new_parent_nodes: Sequence[BaseOutput[SourceDataSetT]] - ) -> AggregateMeasuresNode[SourceDataSetT]: + def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> AggregateMeasuresNode: # noqa: D assert len(new_parent_nodes) == 1 - return AggregateMeasuresNode[SourceDataSetT]( + return AggregateMeasuresNode( parent_node=new_parent_nodes[0], metric_input_measure_specs=self.metric_input_measure_specs, ) -class JoinAggregatedMeasuresByGroupByColumnsNode(Generic[SourceDataSetT], AggregatedMeasuresOutput[SourceDataSetT]): +class JoinAggregatedMeasuresByGroupByColumnsNode(AggregatedMeasuresOutput): """A node that joins aggregated measures with group by elements. This is designed to link two separate semantic models with measures aggregated by the complete set of group by @@ -527,7 +511,7 @@ class JoinAggregatedMeasuresByGroupByColumnsNode(Generic[SourceDataSetT], Aggreg def __init__( self, - parent_nodes: Sequence[BaseOutput[SourceDataSetT]], + parent_nodes: Sequence[BaseOutput], ): """Constructor. 
@@ -545,7 +529,7 @@ def __init__( def id_prefix(cls) -> str: # noqa: D return DATAFLOW_NODE_JOIN_AGGREGATED_MEASURES_BY_GROUPBY_COLUMNS_PREFIX - def accept(self, visitor: DataflowPlanNodeVisitor[SourceDataSetT, VisitorOutputT]) -> VisitorOutputT: # noqa: D + def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D return visitor.visit_join_aggregated_measures_by_groupby_columns_node(self) @property @@ -558,18 +542,18 @@ def displayed_properties(self) -> List[DisplayedProperty]: # noqa: D DisplayedProperty("Join aggregated measure nodes: ", f"{[node.node_id for node in self.parent_nodes]}") ] - def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) -> bool: # noqa: D + def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D return isinstance(other_node, self.__class__) def with_new_parents( # noqa: D - self, new_parent_nodes: Sequence[BaseOutput[SourceDataSetT]] - ) -> JoinAggregatedMeasuresByGroupByColumnsNode[SourceDataSetT]: - return JoinAggregatedMeasuresByGroupByColumnsNode[SourceDataSetT]( + self, new_parent_nodes: Sequence[BaseOutput] + ) -> JoinAggregatedMeasuresByGroupByColumnsNode: + return JoinAggregatedMeasuresByGroupByColumnsNode( parent_nodes=new_parent_nodes, ) -class SemiAdditiveJoinNode(Generic[SourceDataSetT], BaseOutput[SourceDataSetT]): +class SemiAdditiveJoinNode(BaseOutput): """A node that performs a row filter by aggregating a given non-additive dimension. This is designed to filter a dataset down to singular non-additive time dimension values by aggregating @@ -620,7 +604,7 @@ class SemiAdditiveJoinNode(Generic[SourceDataSetT], BaseOutput[SourceDataSetT]): def __init__( self, - parent_node: BaseOutput[SourceDataSetT], + parent_node: BaseOutput, entity_specs: Sequence[LinklessEntitySpec], time_dimension_spec: TimeDimensionSpec, agg_by_function: AggregationType, @@ -642,14 +626,14 @@ def __init__( self._queried_time_dimension_spec = queried_time_dimension_spec # Doing a list comprehension throws a type error, so doing it this way. 
- parent_nodes: List[DataflowPlanNode[SourceDataSetT]] = [self._parent_node] + parent_nodes: List[DataflowPlanNode] = [self._parent_node] super().__init__(node_id=self.create_unique_id(), parent_nodes=parent_nodes) @classmethod def id_prefix(cls) -> str: # noqa: D return DATAFLOW_NODE_SEMI_ADDITIVE_JOIN_ID_PREFIX - def accept(self, visitor: DataflowPlanNodeVisitor[SourceDataSetT, VisitorOutputT]) -> VisitorOutputT: # noqa: D + def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D return visitor.visit_semi_additive_join_node(self) @property @@ -657,7 +641,7 @@ def description(self) -> str: # noqa: D return f"""Join on {self.agg_by_function.name}({self.time_dimension_spec.element_name}) and {[i.element_name for i in self.entity_specs]} grouping by {self.queried_time_dimension_spec.element_name if self.queried_time_dimension_spec else None}""" @property - def parent_node(self) -> BaseOutput[SourceDataSetT]: # noqa: D + def parent_node(self) -> BaseOutput: # noqa: D return self._parent_node @property @@ -680,7 +664,7 @@ def queried_time_dimension_spec(self) -> Optional[TimeDimensionSpec]: # noqa: D def displayed_properties(self) -> List[DisplayedProperty]: # noqa: D return super().displayed_properties - def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) -> bool: # noqa: D + def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D if not isinstance(other_node, self.__class__): return False @@ -692,12 +676,10 @@ def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) - and other_node.queried_time_dimension_spec == self.queried_time_dimension_spec ) - def with_new_parents( # noqa: D - self, new_parent_nodes: Sequence[BaseOutput[SourceDataSetT]] - ) -> SemiAdditiveJoinNode[SourceDataSetT]: + def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> SemiAdditiveJoinNode: # noqa: D assert len(new_parent_nodes) == 1 - return SemiAdditiveJoinNode[SourceDataSetT]( + return SemiAdditiveJoinNode( parent_node=new_parent_nodes[0], entity_specs=self.entity_specs, time_dimension_spec=self.time_dimension_spec, @@ -706,18 +688,18 @@ def with_new_parents( # noqa: D ) -class ComputedMetricsOutput(Generic[SourceDataSetT], BaseOutput[SourceDataSetT], ABC): +class ComputedMetricsOutput(BaseOutput, ABC): """A node that outputs data that contains metrics computed from measures.""" pass -class JoinToTimeSpineNode(Generic[SourceDataSetT], BaseOutput[SourceDataSetT], ABC): +class JoinToTimeSpineNode(BaseOutput, ABC): """Join parent dataset to time spine dataset.""" def __init__( self, - parent_node: BaseOutput[SourceDataSetT], + parent_node: BaseOutput, metric_time_dimension_specs: List[TimeDimensionSpec], time_range_constraint: Optional[TimeRangeConstraint] = None, offset_window: Optional[MetricTimeWindow] = None, @@ -769,7 +751,7 @@ def offset_to_grain(self) -> Optional[TimeGranularity]: # noqa: D """Time range constraint to apply when querying time spine table.""" return self._offset_to_grain - def accept(self, visitor: DataflowPlanNodeVisitor[SourceDataSetT, VisitorOutputT]) -> VisitorOutputT: # noqa: D + def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D return visitor.visit_join_to_time_spine_node(self) @property @@ -788,7 +770,7 @@ def displayed_properties(self) -> List[DisplayedProperty]: # noqa: D def parent_node(self) -> BaseOutput: # noqa: D return self._parent_node - def functionally_identical(self, other_node: 
DataflowPlanNode[SourceDataSetT]) -> bool: # noqa: D + def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D return ( isinstance(other_node, self.__class__) and other_node.time_range_constraint == self.time_range_constraint @@ -797,11 +779,9 @@ def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) - and other_node.metric_time_dimension_specs == self.metric_time_dimension_specs ) - def with_new_parents( # noqa: D - self, new_parent_nodes: Sequence[BaseOutput[SourceDataSetT]] - ) -> JoinToTimeSpineNode[SourceDataSetT]: + def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> JoinToTimeSpineNode: # noqa: D assert len(new_parent_nodes) == 1 - return JoinToTimeSpineNode[SourceDataSetT]( + return JoinToTimeSpineNode( parent_node=new_parent_nodes[0], metric_time_dimension_specs=self.metric_time_dimension_specs, time_range_constraint=self.time_range_constraint, @@ -810,10 +790,10 @@ def with_new_parents( # noqa: D ) -class ComputeMetricsNode(Generic[SourceDataSetT], ComputedMetricsOutput[SourceDataSetT]): +class ComputeMetricsNode(ComputedMetricsOutput): """A node that computes metrics from input measures. Dimensions / entities are passed through.""" - def __init__(self, parent_node: BaseOutput[SourceDataSetT], metric_specs: List[MetricSpec]) -> None: # noqa: D + def __init__(self, parent_node: BaseOutput, metric_specs: List[MetricSpec]) -> None: # noqa: D """Constructor. Args: @@ -833,7 +813,7 @@ def metric_specs(self) -> List[MetricSpec]: # noqa: D """The metric instances that this node is supposed to compute and should have in the output.""" return self._metric_specs - def accept(self, visitor: DataflowPlanNodeVisitor[SourceDataSetT, VisitorOutputT]) -> VisitorOutputT: # noqa: D + def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D return visitor.visit_compute_metrics_node(self) @property @@ -850,7 +830,7 @@ def displayed_properties(self) -> List[DisplayedProperty]: # noqa: D def parent_node(self) -> BaseOutput: # noqa: D return self._parent_node - def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) -> bool: # noqa: D + def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D if not isinstance(other_node, self.__class__): return False @@ -859,23 +839,21 @@ def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) - return isinstance(other_node, self.__class__) and other_node.metric_specs == self.metric_specs - def with_new_parents( # noqa: D - self, new_parent_nodes: Sequence[BaseOutput[SourceDataSetT]] - ) -> ComputeMetricsNode[SourceDataSetT]: + def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> ComputeMetricsNode: # noqa: D assert len(new_parent_nodes) == 1 - return ComputeMetricsNode[SourceDataSetT]( + return ComputeMetricsNode( parent_node=new_parent_nodes[0], metric_specs=self.metric_specs, ) -class OrderByLimitNode(Generic[SourceDataSetT], ComputedMetricsOutput[SourceDataSetT]): +class OrderByLimitNode(ComputedMetricsOutput): """A node that re-orders the input data with a limit.""" def __init__( self, order_by_specs: List[OrderBySpec], - parent_node: Union[BaseOutput[SourceDataSetT], ComputedMetricsOutput[SourceDataSetT]], + parent_node: Union[BaseOutput, ComputedMetricsOutput], limit: Optional[int] = None, ) -> None: """Constructor. 
@@ -904,7 +882,7 @@ def limit(self) -> Optional[int]: """The number of rows to limit by.""" return self._limit - def accept(self, visitor: DataflowPlanNodeVisitor[SourceDataSetT, VisitorOutputT]) -> VisitorOutputT: # noqa: D + def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D return visitor.visit_order_by_limit_node(self) @property @@ -922,29 +900,27 @@ def displayed_properties(self) -> List[DisplayedProperty]: # noqa: D ) @property - def parent_node(self) -> Union[BaseOutput[SourceDataSetT], ComputedMetricsOutput[SourceDataSetT]]: # noqa: D + def parent_node(self) -> Union[BaseOutput, ComputedMetricsOutput]: # noqa: D return self._parent_node - def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) -> bool: # noqa: D + def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D return ( isinstance(other_node, self.__class__) and other_node.order_by_specs == self.order_by_specs and other_node.limit == self.limit ) - def with_new_parents( # noqa: D - self, new_parent_nodes: Sequence[BaseOutput[SourceDataSetT]] - ) -> OrderByLimitNode[SourceDataSetT]: + def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> OrderByLimitNode: # noqa: D assert len(new_parent_nodes) == 1 - return OrderByLimitNode[SourceDataSetT]( + return OrderByLimitNode( parent_node=new_parent_nodes[0], order_by_specs=self.order_by_specs, limit=self.limit, ) -class MetricTimeDimensionTransformNode(Generic[SourceDataSetT], BaseOutput[SourceDataSetT]): +class MetricTimeDimensionTransformNode(BaseOutput): """A node transforms the input data set so that it contains the metric time dimension and relevant measures. The metric time dimension is used later to aggregate all measures in the data set. 
@@ -957,7 +933,7 @@ class MetricTimeDimensionTransformNode(Generic[SourceDataSetT], BaseOutput[Sourc def __init__( # noqa: D self, - parent_node: BaseOutput[SourceDataSetT], + parent_node: BaseOutput, aggregation_time_dimension_reference: TimeDimensionReference, ) -> None: self._aggregation_time_dimension_reference = aggregation_time_dimension_reference @@ -973,7 +949,7 @@ def aggregation_time_dimension_reference(self) -> TimeDimensionReference: """The time dimension that measures in the input should be aggregated to.""" return self._aggregation_time_dimension_reference - def accept(self, visitor: DataflowPlanNodeVisitor[SourceDataSetT, VisitorOutputT]) -> VisitorOutputT: # noqa: D + def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D return visitor.visit_metric_time_dimension_transform_node(self) @property @@ -987,18 +963,16 @@ def displayed_properties(self) -> List[DisplayedProperty]: # noqa: D ] @property - def parent_node(self) -> BaseOutput[SourceDataSetT]: # noqa: D + def parent_node(self) -> BaseOutput: # noqa: D return self._parent_node - def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) -> bool: # noqa: D + def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D return ( isinstance(other_node, self.__class__) and other_node.aggregation_time_dimension_reference == self.aggregation_time_dimension_reference ) - def with_new_parents( # noqa: D - self, new_parent_nodes: Sequence[BaseOutput[SourceDataSetT]] - ) -> MetricTimeDimensionTransformNode[SourceDataSetT]: + def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> MetricTimeDimensionTransformNode: # noqa: D assert len(new_parent_nodes) == 1 return MetricTimeDimensionTransformNode( parent_node=new_parent_nodes[0], @@ -1006,41 +980,35 @@ def with_new_parents( # noqa: D ) -class SinkNodeVisitor(Generic[SourceDataSetT, VisitorOutputT], ABC): +class SinkNodeVisitor(Generic[VisitorOutputT], ABC): """Similar to DataflowPlanNodeVisitor, but only for sink nodes.""" @abstractmethod - def visit_write_to_result_dataframe_node( # noqa: D - self, node: WriteToResultDataframeNode[SourceDataSetT] - ) -> VisitorOutputT: + def visit_write_to_result_dataframe_node(self, node: WriteToResultDataframeNode) -> VisitorOutputT: # noqa: D pass @abstractmethod - def visit_write_to_result_table_node( # noqa: D - self, node: WriteToResultTableNode[SourceDataSetT] - ) -> VisitorOutputT: + def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> VisitorOutputT: # noqa: D pass -class SinkOutput(Generic[SourceDataSetT], DataflowPlanNode[SourceDataSetT], ABC): +class SinkOutput(DataflowPlanNode, ABC): """A node where incoming data goes out of the graph.""" @abstractmethod - def accept_sink_node_visitor( # noqa: D - self, visitor: SinkNodeVisitor[SourceDataSetT, VisitorOutputT] - ) -> VisitorOutputT: + def accept_sink_node_visitor(self, visitor: SinkNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D pass @property @abstractmethod - def parent_node(self) -> BaseOutput[SourceDataSetT]: # noqa: D + def parent_node(self) -> BaseOutput: # noqa: D pass -class WriteToResultDataframeNode(Generic[SourceDataSetT], SinkOutput[SourceDataSetT]): +class WriteToResultDataframeNode(SinkOutput): """A node where incoming data gets written to a dataframe.""" - def __init__(self, parent_node: BaseOutput[SourceDataSetT]) -> None: # noqa: D + def __init__(self, parent_node: BaseOutput) -> None: # noqa: D self._parent_node = parent_node 
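The with_new_parents implementations above all follow one convention: nodes are immutable, so a rewrite produces a copy of the node pointing at the new parents while every other attribute is carried over. A toy sketch of the convention under that assumption (ToyFilterNode is a hypothetical stand-in for single-parent nodes such as FilterElementsNode):

    from typing import Sequence, Tuple

    class ToyFilterNode:
        def __init__(self, parent_node: object, include_specs: Sequence[str]) -> None:
            self.parent_node = parent_node
            self.include_specs: Tuple[str, ...] = tuple(include_specs)

        def with_new_parents(self, new_parent_nodes: Sequence[object]) -> "ToyFilterNode":
            # Single-parent nodes assert on the parent count, then copy their own
            # configuration onto the new parent.
            assert len(new_parent_nodes) == 1
            return ToyFilterNode(parent_node=new_parent_nodes[0], include_specs=self.include_specs)

    original = ToyFilterNode(parent_node="source_a", include_specs=["ds", "listing"])
    rewritten = original.with_new_parents(["optimized_source_a"])
    assert rewritten.parent_node == "optimized_source_a"
    assert rewritten.include_specs == original.include_specs

This is what lets the optimizers further down splice a combined parent branch into an existing node without mutating the original plan.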
super().__init__(node_id=self.create_unique_id(), parent_nodes=[parent_node]) @@ -1048,7 +1016,7 @@ def __init__(self, parent_node: BaseOutput[SourceDataSetT]) -> None: # noqa: D def id_prefix(cls) -> str: # noqa: D return DATAFLOW_NODE_WRITE_TO_RESULT_DATAFRAME_ID_PREFIX - def accept(self, visitor: DataflowPlanNodeVisitor[SourceDataSetT, VisitorOutputT]) -> VisitorOutputT: # noqa: D + def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D return visitor.visit_write_to_result_dataframe_node(self) @property @@ -1056,31 +1024,27 @@ def description(self) -> str: # noqa: D return """Write to Dataframe""" @property - def parent_node(self) -> BaseOutput[SourceDataSetT]: # noqa: D + def parent_node(self) -> BaseOutput: # noqa: D assert len(self.parent_nodes) == 1 return self._parent_node - def accept_sink_node_visitor( # noqa: D - self, visitor: SinkNodeVisitor[SourceDataSetT, VisitorOutputT] - ) -> VisitorOutputT: + def accept_sink_node_visitor(self, visitor: SinkNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D return visitor.visit_write_to_result_dataframe_node(self) - def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) -> bool: # noqa: D + def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D return isinstance(other_node, self.__class__) - def with_new_parents( # noqa: D - self, new_parent_nodes: Sequence[BaseOutput[SourceDataSetT]] - ) -> WriteToResultDataframeNode[SourceDataSetT]: + def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> WriteToResultDataframeNode: # noqa: D assert len(new_parent_nodes) == 1 - return WriteToResultDataframeNode[SourceDataSetT](parent_node=new_parent_nodes[0]) + return WriteToResultDataframeNode(parent_node=new_parent_nodes[0]) -class WriteToResultTableNode(Generic[SourceDataSetT], SinkOutput[SourceDataSetT]): +class WriteToResultTableNode(SinkOutput): """A node where incoming data gets written to a table.""" def __init__( # noqa: D self, - parent_node: BaseOutput[SourceDataSetT], + parent_node: BaseOutput, output_sql_table: SqlTable, ) -> None: """Constructor. 
@@ -1097,7 +1061,7 @@ def __init__( # noqa: D def id_prefix(cls) -> str: # noqa: D return DATAFLOW_NODE_WRITE_TO_RESULT_DATAFRAME_ID_PREFIX - def accept(self, visitor: DataflowPlanNodeVisitor[SourceDataSetT, VisitorOutputT]) -> VisitorOutputT: # noqa: D + def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D return visitor.visit_write_to_result_table_node(self) @property @@ -1105,37 +1069,33 @@ def description(self) -> str: # noqa: D return """Write to Table""" @property - def parent_node(self) -> BaseOutput[SourceDataSetT]: # noqa: D + def parent_node(self) -> BaseOutput: # noqa: D assert len(self.parent_nodes) == 1 return self._parent_node - def accept_sink_node_visitor( # noqa: D - self, visitor: SinkNodeVisitor[SourceDataSetT, VisitorOutputT] - ) -> VisitorOutputT: + def accept_sink_node_visitor(self, visitor: SinkNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D return visitor.visit_write_to_result_table_node(self) @property def output_sql_table(self) -> SqlTable: # noqa: D return self._output_sql_table - def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) -> bool: # noqa: D + def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D return isinstance(other_node, self.__class__) and other_node.output_sql_table == self.output_sql_table - def with_new_parents( # noqa: D - self, new_parent_nodes: Sequence[BaseOutput[SourceDataSetT]] - ) -> WriteToResultTableNode: - return WriteToResultTableNode[SourceDataSetT]( + def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> WriteToResultTableNode: # noqa: D + return WriteToResultTableNode( parent_node=new_parent_nodes[0], output_sql_table=self.output_sql_table, ) -class FilterElementsNode(Generic[SourceDataSetT], BaseOutput[SourceDataSetT]): +class FilterElementsNode(BaseOutput): """Only passes the listed elements.""" def __init__( # noqa: D self, - parent_node: BaseOutput[SourceDataSetT], + parent_node: BaseOutput, include_specs: InstanceSpecSet, replace_description: Optional[str] = None, ) -> None: @@ -1153,7 +1113,7 @@ def include_specs(self) -> InstanceSpecSet: """Returns the specs for the elements that it should pass.""" return self._include_specs - def accept(self, visitor: DataflowPlanNodeVisitor[SourceDataSetT, VisitorOutputT]) -> VisitorOutputT: # noqa: D + def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D return visitor.visit_pass_elements_filter_node(self) @property @@ -1176,29 +1136,27 @@ def displayed_properties(self) -> List[DisplayedProperty]: # noqa: D return super().displayed_properties + additional_properties @property - def parent_node(self) -> BaseOutput[SourceDataSetT]: # noqa: D + def parent_node(self) -> BaseOutput: # noqa: D return self._parent_node - def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) -> bool: # noqa: D + def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D return isinstance(other_node, self.__class__) and other_node.include_specs == self.include_specs - def with_new_parents( # noqa: D - self, new_parent_nodes: Sequence[BaseOutput[SourceDataSetT]] - ) -> FilterElementsNode[SourceDataSetT]: + def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> FilterElementsNode: # noqa: D assert len(new_parent_nodes) == 1 - return FilterElementsNode[SourceDataSetT]( + return FilterElementsNode( parent_node=new_parent_nodes[0], include_specs=self.include_specs, 
replace_description=self._replace_description, ) -class WhereConstraintNode(AggregatedMeasuresOutput[SourceDataSetT]): +class WhereConstraintNode(AggregatedMeasuresOutput): """Remove rows using a WHERE clause.""" def __init__( # noqa: D self, - parent_node: BaseOutput[SourceDataSetT], + parent_node: BaseOutput, where_constraint: WhereFilterSpec, ) -> None: self._where = where_constraint @@ -1214,7 +1172,7 @@ def where(self) -> WhereFilterSpec: """Returns the where filter spec used to remove rows.""" return self._where - def accept(self, visitor: DataflowPlanNodeVisitor[SourceDataSetT, VisitorOutputT]) -> VisitorOutputT: # noqa: D + def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D return visitor.visit_where_constraint_node(self) @property @@ -1227,12 +1185,10 @@ def description(self) -> str: # noqa: D def displayed_properties(self) -> List[DisplayedProperty]: # noqa: D return super().displayed_properties + [DisplayedProperty("where_condition", self.where)] - def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) -> bool: # noqa: D + def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D return isinstance(other_node, self.__class__) and other_node.where == self.where - def with_new_parents( # noqa: D - self, new_parent_nodes: Sequence[BaseOutput[SourceDataSetT]] - ) -> WhereConstraintNode[SourceDataSetT]: + def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> WhereConstraintNode: # noqa: D assert len(new_parent_nodes) == 1 return WhereConstraintNode( parent_node=new_parent_nodes[0], @@ -1240,12 +1196,12 @@ def with_new_parents( # noqa: D ) -class CombineMetricsNode(Generic[SourceDataSetT], ComputedMetricsOutput[SourceDataSetT]): +class CombineMetricsNode(ComputedMetricsOutput): """Combines metrics from different nodes into a single output.""" def __init__( # noqa: D self, - parent_nodes: Sequence[Union[BaseOutput, ComputedMetricsOutput[SourceDataSetT]]], + parent_nodes: Sequence[Union[BaseOutput, ComputedMetricsOutput]], join_type: SqlJoinType = SqlJoinType.FULL_OUTER, ) -> None: self._join_type = join_type @@ -1255,7 +1211,7 @@ def __init__( # noqa: D def id_prefix(cls) -> str: # noqa: D return DATAFLOW_NODE_COMBINE_METRICS_ID_PREFIX - def accept(self, visitor: DataflowPlanNodeVisitor[SourceDataSetT, VisitorOutputT]) -> VisitorOutputT: # noqa: D + def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D return visitor.visit_combine_metrics_node(self) @property @@ -1278,12 +1234,10 @@ def join_type(self) -> SqlJoinType: """The type of join used for combining metrics.""" return self._join_type - def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) -> bool: # noqa: D + def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D return isinstance(other_node, self.__class__) and other_node.join_type == self.join_type - def with_new_parents( # noqa: D - self, new_parent_nodes: Sequence[BaseOutput[SourceDataSetT]] - ) -> CombineMetricsNode[SourceDataSetT]: + def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> CombineMetricsNode: # noqa: D assert len(new_parent_nodes) == 1 return CombineMetricsNode( parent_nodes=new_parent_nodes, @@ -1291,7 +1245,7 @@ def with_new_parents( # noqa: D ) -class ConstrainTimeRangeNode(AggregatedMeasuresOutput[SourceDataSetT], BaseOutput[SourceDataSetT]): +class ConstrainTimeRangeNode(AggregatedMeasuresOutput, BaseOutput): """Constrains the time
range of the input data set. For example, if the input data set had "sales by date", then this would restrict the data set so that it only @@ -1300,7 +1254,7 @@ class ConstrainTimeRangeNode(AggregatedMeasuresOutput[SourceDataSetT], BaseOutpu def __init__( # noqa: D self, - parent_node: BaseOutput[SourceDataSetT], + parent_node: BaseOutput, time_range_constraint: TimeRangeConstraint, ) -> None: self._time_range_constraint = time_range_constraint @@ -1310,7 +1264,7 @@ def __init__( # noqa: D def id_prefix(cls) -> str: # noqa: D return DATAFLOW_NODE_CONSTRAIN_TIME_RANGE_ID_PREFIX - def accept(self, visitor: DataflowPlanNodeVisitor[SourceDataSetT, VisitorOutputT]) -> VisitorOutputT: # noqa: D + def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D return visitor.visit_constrain_time_range_node(self) @property @@ -1336,33 +1290,31 @@ def displayed_properties(self) -> List[DisplayedProperty]: # noqa: D DisplayedProperty("time_range_end", self.time_range_constraint.end_time.isoformat()), ] - def functionally_identical(self, other_node: DataflowPlanNode[SourceDataSetT]) -> bool: # noqa: D + def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D return isinstance(other_node, self.__class__) and self.time_range_constraint == other_node.time_range_constraint - def with_new_parents( # noqa: D - self, new_parent_nodes: Sequence[BaseOutput[SourceDataSetT]] - ) -> ConstrainTimeRangeNode[SourceDataSetT]: + def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> ConstrainTimeRangeNode: # noqa: D assert len(new_parent_nodes) == 1 - return ConstrainTimeRangeNode[SourceDataSetT]( + return ConstrainTimeRangeNode( parent_node=new_parent_nodes[0], time_range_constraint=self.time_range_constraint, ) -class DataflowPlan(Generic[SourceDataSetT], MetricFlowDag[SinkOutput[SourceDataSetT]]): +class DataflowPlan(MetricFlowDag[SinkOutput]): """Describes the flow of metric data as it goes from source nodes to sink nodes in the graph.""" - def __init__(self, plan_id: str, sink_output_nodes: List[SinkOutput[SourceDataSetT]]) -> None: # noqa: D + def __init__(self, plan_id: str, sink_output_nodes: List[SinkOutput]) -> None: # noqa: D if len(sink_output_nodes) == 0: raise RuntimeError("Can't create a dataflow plan without sink node(s).") self._sink_output_nodes = sink_output_nodes super().__init__(dag_id=plan_id, sink_nodes=sink_output_nodes) @property - def sink_output_nodes(self) -> List[SinkOutput[SourceDataSetT]]: # noqa: D + def sink_output_nodes(self) -> List[SinkOutput]: # noqa: D return self._sink_output_nodes @property - def sink_output_node(self) -> SinkOutput[SourceDataSetT]: # noqa: D + def sink_output_node(self) -> SinkOutput: # noqa: D assert len(self._sink_output_nodes) == 1, f"Only 1 sink node supported. 
Got: {self._sink_output_nodes}" return self._sink_output_nodes[0] diff --git a/metricflow/dataflow/optimizer/dataflow_plan_optimizer.py b/metricflow/dataflow/optimizer/dataflow_plan_optimizer.py index 3cb32807d1..8d03f20c52 100644 --- a/metricflow/dataflow/optimizer/dataflow_plan_optimizer.py +++ b/metricflow/dataflow/optimizer/dataflow_plan_optimizer.py @@ -1,14 +1,13 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Generic -from metricflow.dataflow.dataflow_plan import DataflowPlan, SourceDataSetT +from metricflow.dataflow.dataflow_plan import DataflowPlan -class DataflowPlanOptimizer(Generic[SourceDataSetT], ABC): +class DataflowPlanOptimizer(ABC): """Converts one dataflow plan into another dataflow plan that is more optimal in some way (e.g. performance).""" @abstractmethod - def optimize(self, dataflow_plan: DataflowPlan[SourceDataSetT]) -> DataflowPlan[SourceDataSetT]: # noqa: D + def optimize(self, dataflow_plan: DataflowPlan) -> DataflowPlan: # noqa: D raise NotImplementedError diff --git a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py index 32ce7e16f0..398f3c559f 100644 --- a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py +++ b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py @@ -2,7 +2,7 @@ import logging from dataclasses import dataclass -from typing import Generic, List, Optional, Sequence +from typing import List, Optional, Sequence from metricflow.dataflow.dataflow_plan import ( AggregateMeasuresNode, @@ -21,7 +21,6 @@ OrderByLimitNode, ReadSqlSourceNode, SemiAdditiveJoinNode, - SourceDataSetT, WhereConstraintNode, WriteToResultDataframeNode, WriteToResultTableNode, @@ -33,10 +32,10 @@ @dataclass(frozen=True) -class ComputeMetricsBranchCombinerResult(Generic[SourceDataSetT]): # noqa: D +class ComputeMetricsBranchCombinerResult: # noqa: D # Perhaps adding more metadata about how nodes got combined would be useful. # If combined_branch is None, it means combination could not occur. - combined_branch: Optional[BaseOutput[SourceDataSetT]] = None + combined_branch: Optional[BaseOutput] = None @property def combined(self) -> bool: @@ -44,14 +43,12 @@ def combined(self) -> bool: return self.combined_branch is not None @property - def checked_combined_branch(self) -> BaseOutput[SourceDataSetT]: # noqa: D + def checked_combined_branch(self) -> BaseOutput: # noqa: D assert self.combined_branch is not None return self.combined_branch -class ComputeMetricsBranchCombiner( - Generic[SourceDataSetT], DataflowPlanNodeVisitor[SourceDataSetT, ComputeMetricsBranchCombinerResult] -): +class ComputeMetricsBranchCombiner(DataflowPlanNodeVisitor[ComputeMetricsBranchCombinerResult]): """Combines branches where the leaf node is a ComputeMetricsNode. This considers two branches, a left branch and a right branch. The left branch is supplied via the argument in the @@ -126,17 +123,17 @@ class ComputeMetricsBranchCombiner( is propagated up to the result at the root node. 
""" - def __init__(self, left_branch_node: BaseOutput[SourceDataSetT]) -> None: # noqa: D - self._current_left_node: DataflowPlanNode[SourceDataSetT] = left_branch_node + def __init__(self, left_branch_node: BaseOutput) -> None: # noqa: D + self._current_left_node: DataflowPlanNode = left_branch_node self._log_level = logging.DEBUG - def _log_visit_node_type(self, node: DataflowPlanNode[SourceDataSetT]) -> None: + def _log_visit_node_type(self, node: DataflowPlanNode) -> None: logger.log(level=self._log_level, msg=f"Visiting {node}") def _log_combine_failure( self, - left_node: DataflowPlanNode[SourceDataSetT], - right_node: DataflowPlanNode[SourceDataSetT], + left_node: DataflowPlanNode, + right_node: DataflowPlanNode, combine_failure_reason: str, ) -> None: logger.log( @@ -147,18 +144,16 @@ def _log_combine_failure( def _log_combine_success( self, - left_node: DataflowPlanNode[SourceDataSetT], - right_node: DataflowPlanNode[SourceDataSetT], - combined_node: DataflowPlanNode[SourceDataSetT], + left_node: DataflowPlanNode, + right_node: DataflowPlanNode, + combined_node: DataflowPlanNode, ) -> None: logger.log( level=self._log_level, msg=f"Combined left_node={left_node} right_node={right_node} combined_node: {combined_node}", ) - def _combine_parent_branches( - self, current_right_node: BaseOutput[SourceDataSetT] - ) -> Optional[Sequence[BaseOutput[SourceDataSetT]]]: + def _combine_parent_branches(self, current_right_node: BaseOutput) -> Optional[Sequence[BaseOutput]]: if len(self._current_left_node.parent_nodes) != len(current_right_node.parent_nodes): self._log_combine_failure( left_node=self._current_left_node, @@ -188,9 +183,7 @@ def _combine_parent_branches( return combined_parents - def _default_handler( # noqa: D - self, current_right_node: BaseOutput[SourceDataSetT] - ) -> ComputeMetricsBranchCombinerResult[SourceDataSetT]: + def _default_handler(self, current_right_node: BaseOutput) -> ComputeMetricsBranchCombinerResult: # noqa: D combined_parent_nodes = self._combine_parent_branches(current_right_node) if combined_parent_nodes is None: return ComputeMetricsBranchCombinerResult() @@ -204,7 +197,7 @@ def _default_handler( # noqa: D self._log_combine_success( left_node=self._current_left_node, right_node=current_right_node, combined_node=combined_node ) - return ComputeMetricsBranchCombinerResult[SourceDataSetT](combined_node) + return ComputeMetricsBranchCombinerResult(combined_node) self._log_combine_failure( left_node=self._current_left_node, @@ -213,26 +206,24 @@ def _default_handler( # noqa: D ) return ComputeMetricsBranchCombinerResult() - def visit_source_node( # noqa: D - self, node: ReadSqlSourceNode[SourceDataSetT] - ) -> ComputeMetricsBranchCombinerResult: + def visit_source_node(self, node: ReadSqlSourceNode) -> ComputeMetricsBranchCombinerResult: # noqa: D self._log_visit_node_type(node) return self._default_handler(node) def visit_join_to_base_output_node( # noqa: D - self, node: JoinToBaseOutputNode[SourceDataSetT] + self, node: JoinToBaseOutputNode ) -> ComputeMetricsBranchCombinerResult: self._log_visit_node_type(node) return self._default_handler(node) def visit_join_aggregated_measures_by_groupby_columns_node( # noqa: D - self, node: JoinAggregatedMeasuresByGroupByColumnsNode[SourceDataSetT] + self, node: JoinAggregatedMeasuresByGroupByColumnsNode ) -> ComputeMetricsBranchCombinerResult: self._log_visit_node_type(node) return self._default_handler(node) def visit_aggregate_measures_node( # noqa: D - self, node: AggregateMeasuresNode[SourceDataSetT] + self, node: 
AggregateMeasuresNode ) -> ComputeMetricsBranchCombinerResult: self._log_visit_node_type(node) current_right_node = node @@ -269,7 +260,7 @@ def visit_aggregate_measures_node( # noqa: D ) return ComputeMetricsBranchCombinerResult() - combined_node = AggregateMeasuresNode[SourceDataSetT]( + combined_node = AggregateMeasuresNode( parent_node=combined_parent_node, metric_input_measure_specs=combined_metric_input_measure_specs, ) @@ -280,9 +271,7 @@ def visit_aggregate_measures_node( # noqa: D ) return ComputeMetricsBranchCombinerResult(combined_node) - def visit_compute_metrics_node( # noqa: D - self, node: ComputeMetricsNode[SourceDataSetT] - ) -> ComputeMetricsBranchCombinerResult: + def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> ComputeMetricsBranchCombinerResult: # noqa: D current_right_node = node self._log_visit_node_type(current_right_node) combined_parent_nodes = self._combine_parent_branches(current_right_node) @@ -300,7 +289,7 @@ def visit_compute_metrics_node( # noqa: D assert len(combined_parent_nodes) == 1 combined_parent_node = combined_parent_nodes[0] assert combined_parent_node is not None - combined_node = ComputeMetricsNode[SourceDataSetT]( + combined_node = ComputeMetricsNode( parent_node=combined_parent_node, metric_specs=self._current_left_node.metric_specs + current_right_node.metric_specs, ) @@ -321,32 +310,28 @@ def _handle_unsupported_node(self, current_right_node: DataflowPlanNode) -> Comp ) return ComputeMetricsBranchCombinerResult() - def visit_order_by_limit_node( # noqa: D - self, node: OrderByLimitNode[SourceDataSetT] - ) -> ComputeMetricsBranchCombinerResult: + def visit_order_by_limit_node(self, node: OrderByLimitNode) -> ComputeMetricsBranchCombinerResult: # noqa: D self._log_visit_node_type(node) return self._handle_unsupported_node(node) - def visit_where_constraint_node( # noqa: D - self, node: WhereConstraintNode[SourceDataSetT] - ) -> ComputeMetricsBranchCombinerResult: + def visit_where_constraint_node(self, node: WhereConstraintNode) -> ComputeMetricsBranchCombinerResult: # noqa: D self._log_visit_node_type(node) return self._default_handler(node) def visit_write_to_result_dataframe_node( # noqa: D - self, node: WriteToResultDataframeNode[SourceDataSetT] + self, node: WriteToResultDataframeNode ) -> ComputeMetricsBranchCombinerResult: self._log_visit_node_type(node) return self._handle_unsupported_node(node) def visit_write_to_result_table_node( # noqa: D - self, node: WriteToResultTableNode[SourceDataSetT] + self, node: WriteToResultTableNode ) -> ComputeMetricsBranchCombinerResult: self._log_visit_node_type(node) return self._handle_unsupported_node(node) def visit_pass_elements_filter_node( # noqa: D - self, node: FilterElementsNode[SourceDataSetT] + self, node: FilterElementsNode ) -> ComputeMetricsBranchCombinerResult: self._log_visit_node_type(node) @@ -380,7 +365,7 @@ def visit_pass_elements_filter_node( # noqa: D # De-dupe so that we don't see the same spec twice in include specs. For example, this can happen with dimension # specs since any branch that is merged together needs to output the same set of dimensions. 
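To make the combination rules in this visitor concrete: two branches merge only when everything beneath their leaf nodes is functionally identical; metric specs from the two ComputeMetricsNode leaves are concatenated, and include specs from matching FilterElementsNodes are merged with de-duplication, for the reason the comment above gives. A self-contained sketch under those assumptions (ToyBranch, combine, and merge_include_specs are hypothetical names, with a branch's "source" standing in for everything below the leaf):

    from dataclasses import dataclass
    from typing import Optional, Sequence, Tuple

    @dataclass(frozen=True)
    class ToyBranch:
        source: str  # identifies everything below the leaf node
        metric_specs: Tuple[str, ...]
        include_specs: Tuple[str, ...]

    def merge_include_specs(left: Sequence[str], right: Sequence[str]) -> Tuple[str, ...]:
        # De-dupe while preserving order: both branches must output the same
        # dimensions, so overlapping specs would otherwise appear twice.
        seen = set()
        merged = []
        for spec in (*left, *right):
            if spec not in seen:
                seen.add(spec)
                merged.append(spec)
        return tuple(merged)

    def combine(left: ToyBranch, right: ToyBranch) -> Optional[ToyBranch]:
        if left.source != right.source:
            return None  # not functionally identical below the leaf; combined_branch stays None
        return ToyBranch(
            source=left.source,
            metric_specs=left.metric_specs + right.metric_specs,
            include_specs=merge_include_specs(left.include_specs, right.include_specs),
        )

    combined = combine(
        ToyBranch("bookings_source", ("bookings",), ("metric_time", "listing")),
        ToyBranch("bookings_source", ("booking_value",), ("metric_time", "is_instant")),
    )
    assert combined is not None
    assert combined.metric_specs == ("bookings", "booking_value")
    assert combined.include_specs == ("metric_time", "listing", "is_instant")
    assert combine(ToyBranch("a", ("m1",), ()), ToyBranch("b", ("m2",), ())) is None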
- combined_node = FilterElementsNode[SourceDataSetT]( + combined_node = FilterElementsNode( parent_node=combined_parent_node, include_specs=InstanceSpecSet.merge( (self._current_left_node.include_specs, current_right_node.include_specs) @@ -393,38 +378,34 @@ def visit_pass_elements_filter_node( # noqa: D ) return ComputeMetricsBranchCombinerResult(combined_node) - def visit_combine_metrics_node( # noqa: D - self, node: CombineMetricsNode[SourceDataSetT] - ) -> ComputeMetricsBranchCombinerResult: + def visit_combine_metrics_node(self, node: CombineMetricsNode) -> ComputeMetricsBranchCombinerResult: # noqa: D self._log_visit_node_type(node) return self._handle_unsupported_node(node) def visit_constrain_time_range_node( # noqa: D - self, node: ConstrainTimeRangeNode[SourceDataSetT] + self, node: ConstrainTimeRangeNode ) -> ComputeMetricsBranchCombinerResult: self._log_visit_node_type(node) return self._default_handler(node) def visit_join_over_time_range_node( # noqa: D - self, node: JoinOverTimeRangeNode[SourceDataSetT] + self, node: JoinOverTimeRangeNode ) -> ComputeMetricsBranchCombinerResult: self._log_visit_node_type(node) return self._default_handler(node) def visit_semi_additive_join_node( # noqa: D - self, node: SemiAdditiveJoinNode[SourceDataSetT] + self, node: SemiAdditiveJoinNode ) -> ComputeMetricsBranchCombinerResult: self._log_visit_node_type(node) return self._default_handler(node) def visit_metric_time_dimension_transform_node( # noqa: D - self, node: MetricTimeDimensionTransformNode[SourceDataSetT] + self, node: MetricTimeDimensionTransformNode ) -> ComputeMetricsBranchCombinerResult: self._log_visit_node_type(node) return self._default_handler(node) - def visit_join_to_time_spine_node( # noqa: D - self, node: JoinToTimeSpineNode[SourceDataSetT] - ) -> ComputeMetricsBranchCombinerResult[SourceDataSetT]: + def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> ComputeMetricsBranchCombinerResult: # noqa: D self._log_visit_node_type(node) return self._default_handler(node) diff --git a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py index 361e404235..5938c63836 100644 --- a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py +++ b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py @@ -2,7 +2,7 @@ import logging from dataclasses import dataclass -from typing import Generic, List, Optional, Sequence +from typing import List, Optional, Sequence from metricflow.dag.id_generation import OPTIMIZED_DATAFLOW_PLAN_PREFIX, IdGeneratorRegistry from metricflow.dataflow.dataflow_plan import ( @@ -24,7 +24,6 @@ ReadSqlSourceNode, SemiAdditiveJoinNode, SinkOutput, - SourceDataSetT, WhereConstraintNode, WriteToResultDataframeNode, WriteToResultTableNode, @@ -40,34 +39,33 @@ @dataclass(frozen=True) -class OptimizeBranchResult(Generic[SourceDataSetT]): # noqa: D - base_output_node: Optional[BaseOutput[SourceDataSetT]] = None - sink_node: Optional[SinkOutput[SourceDataSetT]] = None +class OptimizeBranchResult: # noqa: D + base_output_node: Optional[BaseOutput] = None + sink_node: Optional[SinkOutput] = None @property - def checked_base_output(self) -> BaseOutput[SourceDataSetT]: # noqa: D + def checked_base_output(self) -> BaseOutput: # noqa: D assert self.base_output_node, f"Expected the result of traversal to produce a {BaseOutput}" return self.base_output_node @property - def checked_sink_node(self) -> SinkOutput[SourceDataSetT]: # noqa: D + def 
checked_sink_node(self) -> SinkOutput: # noqa: D assert self.sink_node, f"Expected the result of traversal to produce a {SinkOutput}" return self.sink_node @dataclass(frozen=True) -class BranchCombinationResult(Generic[SourceDataSetT]): +class BranchCombinationResult: """Holds the results of combining a branch (right_branch) with one of the branches in a list (left_branch).""" - left_branch: BaseOutput[SourceDataSetT] - right_branch: BaseOutput[SourceDataSetT] - combined_branch: Optional[BaseOutput[SourceDataSetT]] = None + left_branch: BaseOutput + right_branch: BaseOutput + combined_branch: Optional[BaseOutput] = None class SourceScanOptimizer( - Generic[SourceDataSetT], - DataflowPlanNodeVisitor[SourceDataSetT, OptimizeBranchResult[SourceDataSetT]], - DataflowPlanOptimizer[SourceDataSetT], + DataflowPlanNodeVisitor[OptimizeBranchResult], + DataflowPlanOptimizer, ): """Reduces the number of scans (ReadSqlSourceNodes) in a dataflow plan. @@ -122,14 +120,14 @@ class SourceScanOptimizer( def __init__(self) -> None: # noqa: D self._log_level = logging.DEBUG - def _log_visit_node_type(self, node: DataflowPlanNode[SourceDataSetT]) -> None: + def _log_visit_node_type(self, node: DataflowPlanNode) -> None: logger.log(level=self._log_level, msg=f"Visiting {node}") def _default_base_output_handler( self, - node: BaseOutput[SourceDataSetT], - ) -> OptimizeBranchResult[SourceDataSetT]: - optimized_parents: Sequence[OptimizeBranchResult[SourceDataSetT]] = tuple( + node: BaseOutput, + ) -> OptimizeBranchResult: + optimized_parents: Sequence[OptimizeBranchResult] = tuple( parent_node.accept(self) for parent_node in node.parent_nodes ) # Parents should always be BaseOutput @@ -139,9 +137,9 @@ def _default_base_output_handler( def _default_sink_node_handler( self, - node: SinkOutput[SourceDataSetT], - ) -> OptimizeBranchResult[SourceDataSetT]: - optimized_parents: Sequence[OptimizeBranchResult[SourceDataSetT]] = tuple( + node: SinkOutput, + ) -> OptimizeBranchResult: + optimized_parents: Sequence[OptimizeBranchResult] = tuple( parent_node.accept(self) for parent_node in node.parent_nodes ) # Parents should always be BaseOutput @@ -149,38 +147,30 @@ def _default_sink_node_handler( sink_node=node.with_new_parents(tuple(x.checked_base_output for x in optimized_parents)) ) - def visit_source_node( # noqa: D - self, node: ReadSqlSourceNode[SourceDataSetT] - ) -> OptimizeBranchResult[SourceDataSetT]: + def visit_source_node(self, node: ReadSqlSourceNode) -> OptimizeBranchResult: # noqa: D self._log_visit_node_type(node) return self._default_base_output_handler(node) - def visit_join_to_base_output_node( # noqa: D - self, node: JoinToBaseOutputNode[SourceDataSetT] - ) -> OptimizeBranchResult[SourceDataSetT]: + def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode) -> OptimizeBranchResult: # noqa: D self._log_visit_node_type(node) return self._default_base_output_handler(node) def visit_join_aggregated_measures_by_groupby_columns_node( # noqa: D - self, node: JoinAggregatedMeasuresByGroupByColumnsNode[SourceDataSetT] - ) -> OptimizeBranchResult[SourceDataSetT]: + self, node: JoinAggregatedMeasuresByGroupByColumnsNode + ) -> OptimizeBranchResult: self._log_visit_node_type(node) return self._default_base_output_handler(node) - def visit_aggregate_measures_node( # noqa: D - self, node: AggregateMeasuresNode[SourceDataSetT] - ) -> OptimizeBranchResult[SourceDataSetT]: + def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> OptimizeBranchResult: # noqa: D 
self._log_visit_node_type(node) return self._default_base_output_handler(node) - def visit_compute_metrics_node( # noqa: D - self, node: ComputeMetricsNode[SourceDataSetT] - ) -> OptimizeBranchResult[SourceDataSetT]: + def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> OptimizeBranchResult: # noqa: D self._log_visit_node_type(node) # Run the optimizer on the parent branch to handle derived metrics, which are defined recursively in the DAG. optimized_parent_result: OptimizeBranchResult = node.parent_node.accept(self) if optimized_parent_result.base_output_node is not None: - return OptimizeBranchResult[SourceDataSetT]( + return OptimizeBranchResult( base_output_node=ComputeMetricsNode( parent_node=optimized_parent_result.base_output_node, metric_specs=node.metric_specs, @@ -189,39 +179,29 @@ def visit_compute_metrics_node( # noqa: D return OptimizeBranchResult(base_output_node=node) - def visit_order_by_limit_node( # noqa: D - self, node: OrderByLimitNode[SourceDataSetT] - ) -> OptimizeBranchResult[SourceDataSetT]: + def visit_order_by_limit_node(self, node: OrderByLimitNode) -> OptimizeBranchResult: # noqa: D self._log_visit_node_type(node) return self._default_base_output_handler(node) - def visit_where_constraint_node( # noqa: D - self, node: WhereConstraintNode[SourceDataSetT] - ) -> OptimizeBranchResult[SourceDataSetT]: + def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBranchResult: # noqa: D self._log_visit_node_type(node) return self._default_base_output_handler(node) - def visit_write_to_result_dataframe_node( # noqa: D - self, node: WriteToResultDataframeNode[SourceDataSetT] - ) -> OptimizeBranchResult[SourceDataSetT]: + def visit_write_to_result_dataframe_node(self, node: WriteToResultDataframeNode) -> OptimizeBranchResult: # noqa: D self._log_visit_node_type(node) return self._default_sink_node_handler(node) - def visit_write_to_result_table_node( # noqa: D - self, node: WriteToResultTableNode[SourceDataSetT] - ) -> OptimizeBranchResult[SourceDataSetT]: + def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> OptimizeBranchResult: # noqa: D self._log_visit_node_type(node) return self._default_sink_node_handler(node) - def visit_pass_elements_filter_node( # noqa: D - self, node: FilterElementsNode[SourceDataSetT] - ) -> OptimizeBranchResult[SourceDataSetT]: + def visit_pass_elements_filter_node(self, node: FilterElementsNode) -> OptimizeBranchResult: # noqa: D self._log_visit_node_type(node) return self._default_base_output_handler(node) @staticmethod def _combine_branches( - left_branches: Sequence[BaseOutput[SourceDataSetT]], right_branch: BaseOutput[SourceDataSetT] + left_branches: Sequence[BaseOutput], right_branch: BaseOutput ) -> Sequence[BranchCombinationResult]: """Combine the right branch with one of the left branches. @@ -256,9 +236,7 @@ def _combine_branches( ) return results - def visit_combine_metrics_node( # noqa: D - self, node: CombineMetricsNode[SourceDataSetT] - ) -> OptimizeBranchResult[SourceDataSetT]: + def visit_combine_metrics_node(self, node: CombineMetricsNode) -> OptimizeBranchResult: # noqa: D self._log_visit_node_type(node) # The parent node of the CombineMetricsNode can be either ComputeMetricsNodes or CombineMetricsNodes @@ -283,7 +261,7 @@ def visit_combine_metrics_node( # noqa: D # Try to combine (using ComputeMetricsBranchCombiner) as many parent branches as possible in a # greedy N^2 approach. 
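The greedy N^2 pass described in the comment above can be sketched on its own (combine_branches_greedily and same_source are hypothetical stand-ins for the real node-based logic): each optimized parent branch is folded into the first already-kept branch that can absorb it, so n combinable branches collapse into a single scan.

    from typing import Callable, List, Optional, Sequence

    def combine_branches_greedily(
        branches: Sequence[str],
        try_combine: Callable[[str, str], Optional[str]],
    ) -> List[str]:
        """Greedy N^2 pass: fold each branch into the first kept branch it combines with."""
        combined: List[str] = []
        for right_branch in branches:
            for i, left_branch in enumerate(combined):
                merged = try_combine(left_branch, right_branch)
                if merged is not None:
                    combined[i] = merged
                    break
            else:
                # No existing branch could absorb it, so it stays as its own scan.
                combined.append(right_branch)
        return combined

    # Toy combination rule: branches reading the same source table can be merged.
    def same_source(left: str, right: str) -> Optional[str]:
        return left if left.split(":")[0] == right.split(":")[0] else None

    assert combine_branches_greedily(
        ["bookings:m1", "bookings:m2", "listings:m3"], same_source
    ) == ["bookings:m1", "listings:m3"]

With this toy rule, three parent branches reduce to two, mirroring how the CombineMetricsNode's parent list shrinks before the single-parent shortcut below is checked.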
The optimality of this approach needs more thought to prove conclusively, but given # the seemingly transitive properties of the combination operation, this seems reasonable. - combined_parent_branches: List[BaseOutput[SourceDataSetT]] = [] + combined_parent_branches: List[BaseOutput] = [] for optimized_parent_branch in optimized_parent_branches: combination_results = SourceScanOptimizer._combine_branches( left_branches=combined_parent_branches, right_branch=optimized_parent_branch @@ -307,38 +285,32 @@ def visit_combine_metrics_node( # noqa: D # If we were able to reduce the parent branches of the CombineMetricsNode into a single one, there's no need # for a CombineMetricsNode. if len(combined_parent_branches) == 1: - return OptimizeBranchResult[SourceDataSetT](base_output_node=combined_parent_branches[0]) + return OptimizeBranchResult(base_output_node=combined_parent_branches[0]) - return OptimizeBranchResult[SourceDataSetT]( + return OptimizeBranchResult( base_output_node=CombineMetricsNode(parent_nodes=combined_parent_branches, join_type=node.join_type) ) - def visit_constrain_time_range_node( # noqa: D - self, node: ConstrainTimeRangeNode[SourceDataSetT] - ) -> OptimizeBranchResult[SourceDataSetT]: + def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> OptimizeBranchResult: # noqa: D self._log_visit_node_type(node) return self._default_base_output_handler(node) - def visit_join_over_time_range_node( # noqa: D - self, node: JoinOverTimeRangeNode[SourceDataSetT] - ) -> OptimizeBranchResult[SourceDataSetT]: + def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> OptimizeBranchResult: # noqa: D self._log_visit_node_type(node) return self._default_base_output_handler(node) - def visit_semi_additive_join_node( # noqa: D - self, node: SemiAdditiveJoinNode[SourceDataSetT] - ) -> OptimizeBranchResult[SourceDataSetT]: + def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> OptimizeBranchResult: # noqa: D self._log_visit_node_type(node) return self._default_base_output_handler(node) def visit_metric_time_dimension_transform_node( # noqa: D - self, node: MetricTimeDimensionTransformNode[SourceDataSetT] - ) -> OptimizeBranchResult[SourceDataSetT]: + self, node: MetricTimeDimensionTransformNode + ) -> OptimizeBranchResult: self._log_visit_node_type(node) return self._default_base_output_handler(node) - def optimize(self, dataflow_plan: DataflowPlan[SourceDataSetT]) -> DataflowPlan[SourceDataSetT]: # noqa: D - optimized_result: OptimizeBranchResult[SourceDataSetT] = dataflow_plan.sink_output_node.accept(self) + def optimize(self, dataflow_plan: DataflowPlan) -> DataflowPlan: # noqa: D + optimized_result: OptimizeBranchResult = dataflow_plan.sink_output_node.accept(self) logger.log( level=self._log_level, @@ -351,18 +323,16 @@ def optimize(self, dataflow_plan: DataflowPlan[SourceDataSetT]) -> DataflowPlan[ plan_id = IdGeneratorRegistry.for_class(self.__class__).create_id(OPTIMIZED_DATAFLOW_PLAN_PREFIX) logger.log(level=self._log_level, msg=f"Optimized plan ID is {plan_id}") if optimized_result.sink_node: - return DataflowPlan[SourceDataSetT]( + return DataflowPlan( plan_id=plan_id, sink_output_nodes=[optimized_result.sink_node], ) logger.log(level=self._log_level, msg="Optimizer didn't produce a result, so returning the same plan") - return DataflowPlan[SourceDataSetT]( + return DataflowPlan( plan_id=plan_id, sink_output_nodes=[dataflow_plan.sink_output_node], ) - def visit_join_to_time_spine_node( # noqa: D - self, node: 
JoinToTimeSpineNode[SourceDataSetT] - ) -> OptimizeBranchResult[SourceDataSetT]: + def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> OptimizeBranchResult: # noqa: D self._log_visit_node_type(node) return self._default_base_output_handler(node) diff --git a/metricflow/engine/metricflow_engine.py b/metricflow/engine/metricflow_engine.py index d08163c5d4..1d459bca78 100644 --- a/metricflow/engine/metricflow_engine.py +++ b/metricflow/engine/metricflow_engine.py @@ -157,7 +157,7 @@ class MetricFlowQueryResult: # noqa: D """The result of a query and context on how it was generated.""" query_spec: MetricFlowQuerySpec - dataflow_plan: DataflowPlan[SemanticModelDataSet] + dataflow_plan: DataflowPlan sql: str result_df: Optional[pd.DataFrame] = None result_table: Optional[SqlTable] = None @@ -168,7 +168,7 @@ class MetricFlowExplainResult: """Returns plans for resolving a query.""" query_spec: MetricFlowQuerySpec - dataflow_plan: DataflowPlan[SemanticModelDataSet] + dataflow_plan: DataflowPlan execution_plan: ExecutionPlan output_table: Optional[SqlTable] = None @@ -348,20 +348,20 @@ def __init__( source_node_builder = SourceNodeBuilder(self._semantic_manifest_lookup) source_nodes = source_node_builder.create_from_data_sets(self._source_data_sets) - node_output_resolver = DataflowPlanNodeOutputDataSetResolver[SemanticModelDataSet]( + node_output_resolver = DataflowPlanNodeOutputDataSetResolver( column_association_resolver=DunderColumnAssociationResolver(semantic_manifest_lookup), semantic_manifest_lookup=semantic_manifest_lookup, ) - self._dataflow_plan_builder = DataflowPlanBuilder[SemanticModelDataSet]( + self._dataflow_plan_builder = DataflowPlanBuilder( source_nodes=source_nodes, semantic_manifest_lookup=self._semantic_manifest_lookup, ) - self._to_sql_query_plan_converter = DataflowToSqlQueryPlanConverter[SemanticModelDataSet]( + self._to_sql_query_plan_converter = DataflowToSqlQueryPlanConverter( column_association_resolver=self._column_association_resolver, semantic_manifest_lookup=self._semantic_manifest_lookup, ) - self._to_execution_plan_converter = DataflowToExecutionPlanConverter[SemanticModelDataSet]( + self._to_execution_plan_converter = DataflowToExecutionPlanConverter( sql_plan_converter=self._to_sql_query_plan_converter, sql_plan_renderer=self._sql_client.sql_query_plan_renderer, sql_client=sql_client, @@ -482,7 +482,7 @@ def _create_execution_plan(self, mf_query_request: MetricFlowQueryRequest) -> Me query_spec=query_spec, output_sql_table=output_table, output_selection_specs=output_selection_specs, - optimizers=(SourceScanOptimizer[SemanticModelDataSet](),), + optimizers=(SourceScanOptimizer(),), ) if len(dataflow_plan.sink_output_nodes) > 1: diff --git a/metricflow/model/data_warehouse_model_validator.py b/metricflow/model/data_warehouse_model_validator.py index c1c37f3a49..b33f9715e8 100644 --- a/metricflow/model/data_warehouse_model_validator.py +++ b/metricflow/model/data_warehouse_model_validator.py @@ -34,7 +34,6 @@ from metricflow.dataflow.dataflow_plan import BaseOutput, FilterElementsNode from metricflow.dataset.convert_semantic_model import SemanticModelToDataSetConverter from metricflow.dataset.dataset import DataSet -from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet from metricflow.engine.metricflow_engine import MetricFlowEngine, MetricFlowExplainResult, MetricFlowQueryRequest from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow.plan_conversion.column_resolver import 
DunderColumnAssociationResolver @@ -72,7 +71,7 @@ def __init__(self, manifest: SemanticManifest) -> None: # noqa: D column_association_resolver=DunderColumnAssociationResolver(self.semantic_manifest_lookup), semantic_manifest_lookup=self.semantic_manifest_lookup, ) - self.node_resolver = DataflowPlanNodeOutputDataSetResolver[SemanticModelDataSet]( + self.node_resolver = DataflowPlanNodeOutputDataSetResolver( column_association_resolver=DunderColumnAssociationResolver(self.semantic_manifest_lookup), semantic_manifest_lookup=self.semantic_manifest_lookup, ) @@ -100,9 +99,7 @@ def _remove_entity_link_specs(specs: Tuple[LinkableInstanceSpecT, ...]) -> Tuple return tuple(spec for spec in specs if not spec.entity_links) @staticmethod - def _semantic_model_nodes( - render_tools: QueryRenderingTools, semantic_model: SemanticModel - ) -> Sequence[BaseOutput[SemanticModelDataSet]]: + def _semantic_model_nodes(render_tools: QueryRenderingTools, semantic_model: SemanticModel) -> Sequence[BaseOutput]: """Builds and returns the SemanticModelDataSet node for the given semantic model.""" fetched_semantic_model = render_tools.semantic_manifest_lookup.semantic_model_lookup.get_by_reference( SemanticModelReference(semantic_model_name=semantic_model.name) @@ -345,7 +342,7 @@ def gen_measure_tasks(cls, manifest: SemanticManifest, sql_client: SqlClient) -> dataset = render_tools.converter.create_sql_source_data_set(semantic_model) semantic_model_specs = dataset.instance_set.spec_set.measure_specs - source_node_by_measure_spec: Dict[MeasureSpec, BaseOutput[SemanticModelDataSet]] = {} + source_node_by_measure_spec: Dict[MeasureSpec, BaseOutput] = {} measure_specs_source_node_pair = [] for source_node in source_nodes: measure_specs = render_tools.node_resolver.get_output_data_set( @@ -355,7 +352,7 @@ def gen_measure_tasks(cls, manifest: SemanticManifest, sql_client: SqlClient) -> measure_specs_source_node_pair.append((measure_specs, source_node)) source_node_to_sub_task: DefaultDict[ - BaseOutput[SemanticModelDataSet], List[DataWarehouseValidationTask] + BaseOutput, List[DataWarehouseValidationTask] ] = collections.defaultdict(list) for spec in semantic_model_specs: obtained_source_node = source_node_by_measure_spec.get(spec) diff --git a/metricflow/plan_conversion/dataflow_to_execution.py b/metricflow/plan_conversion/dataflow_to_execution.py index 9594dfb627..849040588c 100644 --- a/metricflow/plan_conversion/dataflow_to_execution.py +++ b/metricflow/plan_conversion/dataflow_to_execution.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import Generic, Optional, Union +from typing import Optional, Union from metricflow.dag.id_generation import EXEC_PLAN_PREFIX, SQL_QUERY_PLAN_PREFIX, IdGeneratorRegistry from metricflow.dataflow.dataflow_plan import ( @@ -9,7 +9,6 @@ ComputedMetricsOutput, DataflowPlan, SinkNodeVisitor, - SourceDataSetT, WriteToResultDataframeNode, WriteToResultTableNode, ) @@ -20,7 +19,7 @@ SelectSqlQueryToDataFrameTask, SelectSqlQueryToTableTask, ) -from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter, SqlDataSetT +from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter from metricflow.protocols.sql_client import SqlClient from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer from metricflow.sql.sql_plan import SqlQueryPlan @@ -30,12 +29,12 @@ logger = logging.getLogger(__name__) -class DataflowToExecutionPlanConverter(Generic[SqlDataSetT], SinkNodeVisitor[SqlDataSetT, 
ExecutionPlan]): +class DataflowToExecutionPlanConverter(SinkNodeVisitor[ExecutionPlan]): """Converts a dataflow plan to an execution plan.""" def __init__( self, - sql_plan_converter: DataflowToSqlQueryPlanConverter[SqlDataSetT], + sql_plan_converter: DataflowToSqlQueryPlanConverter, sql_plan_renderer: SqlQueryPlanRenderer, sql_client: SqlClient, extra_sql_tags: SqlJsonTag = SqlJsonTag(), @@ -55,7 +54,7 @@ def __init__( def _build_execution_plan( # noqa: D self, - node: Union[BaseOutput[SourceDataSetT], ComputedMetricsOutput[SourceDataSetT]], + node: Union[BaseOutput, ComputedMetricsOutput], output_table: Optional[SqlTable] = None, ) -> ExecutionPlan: sql_plan = self._sql_plan_converter.convert_to_sql_query_plan( @@ -90,15 +89,11 @@ def _build_execution_plan( # noqa: D plan_id=IdGeneratorRegistry.for_class(self.__class__).create_id(EXEC_PLAN_PREFIX), leaf_tasks=[leaf_task] ) - def visit_write_to_result_dataframe_node( # noqa: D - self, node: WriteToResultDataframeNode[SourceDataSetT] - ) -> ExecutionPlan: + def visit_write_to_result_dataframe_node(self, node: WriteToResultDataframeNode) -> ExecutionPlan: # noqa: D logger.info(f"Generating SQL query plan from {node.node_id} -> {node.parent_node.node_id}") return self._build_execution_plan(node.parent_node) - def visit_write_to_result_table_node( # noqa: D - self, node: WriteToResultTableNode[SourceDataSetT] - ) -> ExecutionPlan: + def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> ExecutionPlan: # noqa: D logger.info(f"Generating SQL query plan from {node.node_id} -> {node.parent_node.node_id}") return self._build_execution_plan(node.parent_node, node.output_sql_table) diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index d350ff17e8..74d0249f18 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -2,7 +2,7 @@ import logging from collections import OrderedDict -from typing import Generic, List, Optional, Sequence, TypeVar, Union +from typing import List, Optional, Sequence, Union from dbt_semantic_interfaces.enum_extension import assert_values_exhausted from dbt_semantic_interfaces.protocols.metric import MetricType @@ -28,7 +28,6 @@ OrderByLimitNode, ReadSqlSourceNode, SemiAdditiveJoinNode, - SourceDataSetT, WhereConstraintNode, WriteToResultDataframeNode, WriteToResultTableNode, @@ -111,10 +110,6 @@ logger = logging.getLogger(__name__) -# The type of data set that present at a source node. 
-SqlDataSetT = TypeVar("SqlDataSetT", bound=SqlDataSet) - - def _make_time_range_comparison_expr( table_alias: str, column_alias: str, time_range_constraint: TimeRangeConstraint ) -> SqlExpressionNode: @@ -136,7 +131,7 @@ def _make_time_range_comparison_expr( ) -class DataflowToSqlQueryPlanConverter(Generic[SqlDataSetT], DataflowPlanNodeVisitor[SqlDataSetT, SqlDataSet]): +class DataflowToSqlQueryPlanConverter(DataflowPlanNodeVisitor[SqlDataSet]): """Generates an SQL query plan from a node in the metric dataflow plan.""" def __init__( @@ -263,14 +258,14 @@ def _make_time_spine_data_set( ), ) - def visit_source_node(self, node: ReadSqlSourceNode[SqlDataSetT]) -> SqlDataSet: + def visit_source_node(self, node: ReadSqlSourceNode) -> SqlDataSet: """Generate the SQL to read from the source.""" return SqlDataSet( sql_select_node=node.data_set.sql_select_node, instance_set=node.data_set.instance_set, ) - def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode[SqlDataSetT]) -> SqlDataSet: + def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDataSet: """Generate time range join SQL.""" table_alias_to_instance_set: OrderedDict[str, InstanceSet] = OrderedDict() @@ -364,7 +359,7 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode[SqlDataSet ), ) - def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode[SqlDataSetT]) -> SqlDataSet: + def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode) -> SqlDataSet: """Generates the query that realizes the behavior of the JoinToStandardOutputNode.""" # Keep a mapping between the table aliases that would be used in the query and the MDO instances in that source. # e.g. when building "FROM from_table a JOIN right_table b", the value for key "a" would be the instances in @@ -785,7 +780,7 @@ def visit_write_to_result_dataframe_node(self, node: WriteToResultDataframeNode) """ raise RuntimeError("This node type is not supported.") - def visit_write_to_result_table_node(self, node: WriteToResultTableNode[SourceDataSetT]) -> SqlDataSet: + def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> SqlDataSet: """Similar to visit_write_to_result_dataframe_node().""" raise RuntimeError("This node type is not supported.") @@ -888,7 +883,7 @@ def _make_select_columns_for_metrics( ) return select_columns - def visit_combine_metrics_node(self, node: CombineMetricsNode[SqlDataSetT]) -> SqlDataSet: + def visit_combine_metrics_node(self, node: CombineMetricsNode) -> SqlDataSet: """Join computed metric datasets together to return a single dataset containing all metrics. This node may exist in one of two situations: when metrics need to be combined in order to produce a single @@ -1000,7 +995,7 @@ def visit_combine_metrics_node(self, node: CombineMetricsNode[SqlDataSetT]) -> S ), ) - def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode[SourceDataSetT]) -> SqlDataSet: + def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> SqlDataSet: """Convert ConstrainTimeRangeNode to a SqlDataSet by building the time constraint comparison.
Use the smallest time granularity to build the comparison since that's what was used in the semantic model @@ -1074,9 +1069,7 @@ def convert_to_sql_query_plan( return SqlQueryPlan(plan_id=sql_query_plan_id, render_node=sql_select_node) - def visit_metric_time_dimension_transform_node( - self, node: MetricTimeDimensionTransformNode[SqlDataSetT] - ) -> SqlDataSet: + def visit_metric_time_dimension_transform_node(self, node: MetricTimeDimensionTransformNode) -> SqlDataSet: """Implement the behavior of the MetricTimeDimensionTransformNode. This node will create an output data set that is similar to the input data set, but the measure instances it @@ -1294,7 +1287,7 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe ), ) - def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode[SourceDataSetT]) -> SqlDataSet: # noqa: D + def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet: # noqa: D parent_data_set = node.parent_node.accept(self) parent_alias = self._next_unique_table_alias() diff --git a/metricflow/plan_conversion/node_processor.py b/metricflow/plan_conversion/node_processor.py index 177b366ff0..febcb3d5ac 100644 --- a/metricflow/plan_conversion/node_processor.py +++ b/metricflow/plan_conversion/node_processor.py @@ -2,7 +2,7 @@ import logging from dataclasses import dataclass -from typing import Generic, List, Optional, Sequence, Set, TypeVar +from typing import List, Optional, Sequence, Set from dbt_semantic_interfaces.pretty_print import pformat_big_objects from dbt_semantic_interfaces.references import EntityReference, TimeDimensionReference @@ -18,19 +18,15 @@ ) from metricflow.filters.time_constraint import TimeRangeConstraint from metricflow.model.semantics.semantic_model_join_evaluator import MAX_JOIN_HOPS, SemanticModelJoinEvaluator -from metricflow.plan_conversion.sql_dataset import SqlDataSet from metricflow.protocols.semantics import SemanticModelAccessor from metricflow.specs.spec_set_transforms import ToElementNameSet from metricflow.specs.specs import InstanceSpecSet, LinkableInstanceSpec, LinklessEntitySpec -SqlDataSetT = TypeVar("SqlDataSetT", bound=SqlDataSet) - - logger = logging.getLogger(__name__) @dataclass(frozen=True) -class MultiHopJoinCandidateLineage(Generic[SqlDataSetT]): +class MultiHopJoinCandidateLineage: """Describes how the multi-hop join candidate was formed. For example, if @@ -44,23 +40,23 @@ class MultiHopJoinCandidateLineage(Generic[SqlDataSetT]): to get the country dimension. """ - first_node_to_join: BaseOutput[SqlDataSetT] - second_node_to_join: BaseOutput[SqlDataSetT] + first_node_to_join: BaseOutput + second_node_to_join: BaseOutput join_second_node_by_entity: LinklessEntitySpec @dataclass(frozen=True) -class MultiHopJoinCandidate(Generic[SqlDataSetT]): +class MultiHopJoinCandidate: """A candidate node containing linkable specs that is a join of other nodes. It's used to resolve multi-hop queries. Also see MultiHopJoinCandidateLineage. """ - node_with_multi_hop_elements: BaseOutput[SqlDataSetT] - lineage: MultiHopJoinCandidateLineage[SqlDataSetT] + node_with_multi_hop_elements: BaseOutput + lineage: MultiHopJoinCandidateLineage -class PreDimensionJoinNodeProcessor(Generic[SqlDataSetT]): +class PreDimensionJoinNodeProcessor: """Processes source nodes before measures are joined to dimensions.
Generally, the source nodes will be combined with other dataflow plan nodes to produce a new set of nodes to realize @@ -82,7 +78,7 @@ class PreDimensionJoinNodeProcessor(Generic[SqlDataSetT]): def __init__( # noqa: D self, semantic_model_lookup: SemanticModelAccessor, - node_data_set_resolver: DataflowPlanNodeOutputDataSetResolver[SqlDataSetT], + node_data_set_resolver: DataflowPlanNodeOutputDataSetResolver, ): self._node_data_set_resolver = node_data_set_resolver self._partition_resolver = PartitionJoinResolver(semantic_model_lookup) @@ -91,12 +87,12 @@ def __init__( # noqa: D def add_time_range_constraint( self, - source_nodes: Sequence[BaseOutput[SqlDataSetT]], + source_nodes: Sequence[BaseOutput], metric_time_dimension_reference: TimeDimensionReference, time_range_constraint: Optional[TimeRangeConstraint] = None, - ) -> Sequence[BaseOutput[SqlDataSetT]]: + ) -> Sequence[BaseOutput]: """Adds a time range constraint node to the input nodes.""" - processed_nodes: List[BaseOutput[SqlDataSetT]] = [] + processed_nodes: List[BaseOutput] = [] for source_node in source_nodes: # Constrain the time range if specified. if time_range_constraint: @@ -121,7 +117,7 @@ def add_time_range_constraint( def _node_contains_entity( self, - node: BaseOutput[SqlDataSetT], + node: BaseOutput, entity_reference: EntityReference, ) -> bool: """Returns true if the output of the node contains an entity of the given types.""" @@ -155,7 +151,7 @@ def _node_contains_entity( def _get_candidates_nodes_for_multi_hop( self, desired_linkable_spec: LinkableInstanceSpec, - nodes: Sequence[BaseOutput[SqlDataSetT]], + nodes: Sequence[BaseOutput], ) -> Sequence[MultiHopJoinCandidate]: """Assemble nodes representing all possible one-hop joins.""" if len(desired_linkable_spec.entity_links) > MAX_JOIN_HOPS: @@ -280,11 +276,11 @@ def _get_candidates_nodes_for_multi_hop( return multi_hop_join_candidates def add_multi_hop_joins( - self, desired_linkable_specs: Sequence[LinkableInstanceSpec], nodes: Sequence[BaseOutput[SqlDataSetT]] - ) -> Sequence[BaseOutput[SqlDataSetT]]: + self, desired_linkable_specs: Sequence[LinkableInstanceSpec], nodes: Sequence[BaseOutput] + ) -> Sequence[BaseOutput]: """Assemble nodes representing all possible one-hop joins.""" - all_multi_hop_join_candidates: List[MultiHopJoinCandidate[SqlDataSetT]] = [] - lineage_for_all_multi_hop_join_candidates: Set[MultiHopJoinCandidateLineage[SqlDataSetT]] = set() + all_multi_hop_join_candidates: List[MultiHopJoinCandidate] = [] + lineage_for_all_multi_hop_join_candidates: Set[MultiHopJoinCandidateLineage] = set() for desired_linkable_spec in desired_linkable_specs: for multi_hop_join_candidate in self._get_candidates_nodes_for_multi_hop( @@ -301,9 +297,9 @@ def add_multi_hop_joins( def remove_unnecessary_nodes( self, desired_linkable_specs: Sequence[LinkableInstanceSpec], - nodes: Sequence[BaseOutput[SqlDataSetT]], + nodes: Sequence[BaseOutput], metric_time_dimension_reference: TimeDimensionReference, - ) -> Sequence[BaseOutput[SqlDataSetT]]: + ) -> Sequence[BaseOutput]: """Filters out many of the nodes that can't possibly be useful for joins to obtain the desired linkable specs. A simple filter is to remove any nodes that don't share a common element with the query. 
Having a common element diff --git a/metricflow/plan_conversion/sql_join_builder.py b/metricflow/plan_conversion/sql_join_builder.py index 079c0d3a1e..39d06142b5 100644 --- a/metricflow/plan_conversion/sql_join_builder.py +++ b/metricflow/plan_conversion/sql_join_builder.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import List, Optional, Sequence, Tuple, TypeVar +from typing import List, Optional, Sequence, Tuple from metricflow.dataflow.dataflow_plan import JoinDescription, JoinOverTimeRangeNode, JoinToTimeSpineNode from metricflow.plan_conversion.sql_dataset import SqlDataSet @@ -19,8 +19,6 @@ ) from metricflow.sql.sql_plan import SqlExpressionNode, SqlJoinDescription, SqlJoinType, SqlSelectStatementNode -SqlDataSetT = TypeVar("SqlDataSetT", bound=SqlDataSet) - @dataclass(frozen=True) class ColumnEqualityDescription: @@ -406,7 +404,7 @@ def _make_equality_expression_for_full_outer_join( @staticmethod def make_cumulative_metric_time_range_join_description( - node: JoinOverTimeRangeNode[SqlDataSetT], + node: JoinOverTimeRangeNode, metric_data_set: AnnotatedSqlDataSet, time_spine_data_set: AnnotatedSqlDataSet, ) -> SqlJoinDescription: diff --git a/metricflow/query/query_parser.py b/metricflow/query/query_parser.py index b3e79b4e45..0027e68322 100644 --- a/metricflow/query/query_parser.py +++ b/metricflow/query/query_parser.py @@ -24,7 +24,6 @@ from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver from metricflow.dataflow.dataflow_plan import BaseOutput from metricflow.dataset.dataset import DataSet -from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet from metricflow.errors.errors import UnableToSatisfyQueryError from metricflow.filters.time_constraint import TimeRangeConstraint from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup @@ -113,8 +112,8 @@ def __init__( # noqa: D self, column_association_resolver: ColumnAssociationResolver, model: SemanticManifestLookup, - source_nodes: Sequence[BaseOutput[SemanticModelDataSet]], - node_output_resolver: DataflowPlanNodeOutputDataSetResolver[SemanticModelDataSet], + source_nodes: Sequence[BaseOutput], + node_output_resolver: DataflowPlanNodeOutputDataSetResolver, ) -> None: self._column_association_resolver = column_association_resolver self._model = model diff --git a/metricflow/test/dataflow/builder/test_costing.py b/metricflow/test/dataflow/builder/test_costing.py index c5f7a22829..fa4a1a1621 100644 --- a/metricflow/test/dataflow/builder/test_costing.py +++ b/metricflow/test/dataflow/builder/test_costing.py @@ -9,7 +9,6 @@ JoinDescription, JoinToBaseOutputNode, ) -from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet from metricflow.specs.specs import ( DimensionSpec, EntitySpec, @@ -30,7 +29,7 @@ def test_costing(consistent_id_object_repository: ConsistentIdObjectRepository) bookings_spec = MeasureSpec( element_name="bookings", ) - bookings_filtered = FilterElementsNode[SemanticModelDataSet]( + bookings_filtered = FilterElementsNode( parent_node=bookings_node, include_specs=InstanceSpecSet( measure_specs=(bookings_spec,), @@ -43,7 +42,7 @@ def test_costing(consistent_id_object_repository: ConsistentIdObjectRepository) ), ) - listings_filtered = FilterElementsNode[SemanticModelDataSet]( + listings_filtered = FilterElementsNode( parent_node=listings_node, include_specs=InstanceSpecSet( dimension_specs=( @@ -61,7 +60,7 @@ def test_costing(consistent_id_object_repository: 
ConsistentIdObjectRepository) ), ) - join_node = JoinToBaseOutputNode[SemanticModelDataSet]( + join_node = JoinToBaseOutputNode( left_node=bookings_filtered, join_targets=[ JoinDescription( @@ -73,11 +72,11 @@ def test_costing(consistent_id_object_repository: ConsistentIdObjectRepository) ], ) - bookings_aggregated = AggregateMeasuresNode[SemanticModelDataSet]( + bookings_aggregated = AggregateMeasuresNode( parent_node=join_node, metric_input_measure_specs=(MetricInputMeasureSpec(measure_spec=bookings_spec),) ) - cost_function = DefaultCostFunction[SemanticModelDataSet]() + cost_function = DefaultCostFunction() cost = cost_function.calculate_cost(bookings_aggregated) assert cost == DefaultCost(num_joins=1, num_aggregations=1) diff --git a/metricflow/test/dataflow/builder/test_cyclic_join.py b/metricflow/test/dataflow/builder/test_cyclic_join.py index 5c141c962c..8762e45ce3 100644 --- a/metricflow/test/dataflow/builder/test_cyclic_join.py +++ b/metricflow/test/dataflow/builder/test_cyclic_join.py @@ -9,7 +9,6 @@ from metricflow.dataflow.builder.costing import DefaultCostFunction from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder from metricflow.dataflow.dataflow_plan_to_text import dataflow_plan_as_text -from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow.specs.specs import ( DimensionSpec, @@ -29,21 +28,21 @@ def cyclic_join_manifest_dataflow_plan_builder( # noqa: D cyclic_join_semantic_manifest_lookup: SemanticManifestLookup, consistent_id_object_repository: ConsistentIdObjectRepository, -) -> DataflowPlanBuilder[SemanticModelDataSet]: +) -> DataflowPlanBuilder: for source_node in consistent_id_object_repository.cyclic_join_source_nodes: logger.error(f"Source node is: {source_node}") return DataflowPlanBuilder( source_nodes=consistent_id_object_repository.cyclic_join_source_nodes, semantic_manifest_lookup=cyclic_join_semantic_manifest_lookup, - cost_function=DefaultCostFunction[SemanticModelDataSet](), + cost_function=DefaultCostFunction(), ) def test_cyclic_join( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - cyclic_join_manifest_dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], + cyclic_join_manifest_dataflow_plan_builder: DataflowPlanBuilder, ) -> None: """Tests that sources with the same joinable keys don't cause cycle issues.""" dataflow_plan = cyclic_join_manifest_dataflow_plan_builder.build_plan( diff --git a/metricflow/test/dataflow/builder/test_dataflow_plan_builder.py b/metricflow/test/dataflow/builder/test_dataflow_plan_builder.py index 4c6d7dbcfb..625226cddf 100644 --- a/metricflow/test/dataflow/builder/test_dataflow_plan_builder.py +++ b/metricflow/test/dataflow/builder/test_dataflow_plan_builder.py @@ -10,7 +10,6 @@ from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder from metricflow.dataflow.dataflow_plan_to_text import dataflow_plan_as_text from metricflow.dataset.dataset import DataSet -from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet from metricflow.errors.errors import UnableToSatisfyQueryError from metricflow.specs.column_assoc import ColumnAssociationResolver from metricflow.specs.specs import ( @@ -32,7 +31,7 @@ def test_simple_plan( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], + dataflow_plan_builder: 
DataflowPlanBuilder,
 ) -> None:
     """Tests a simple plan getting a metric and a local dimension."""
     dataflow_plan = dataflow_plan_builder.build_plan(
@@ -64,7 +63,7 @@ def test_simple_plan(  # noqa: D
 def test_primary_entity_dimension(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests a simple plan getting a metric and a primary entity dimension."""
     dataflow_plan = dataflow_plan_builder.build_plan(
@@ -97,7 +96,7 @@ def test_joined_plan(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
     column_association_resolver: ColumnAssociationResolver,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests a plan getting a measure and a joined dimension."""
     dataflow_plan = dataflow_plan_builder.build_plan(
@@ -133,7 +132,7 @@ def test_joined_plan(  # noqa: D
 def test_order_by_plan(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests a plan with an order by."""
     dataflow_plan = dataflow_plan_builder.build_plan(
@@ -170,7 +169,7 @@ def test_order_by_plan(  # noqa: D
 def test_limit_rows_plan(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests a plan with a limit to the number of rows returned."""
     dataflow_plan = dataflow_plan_builder.build_plan(
@@ -198,7 +197,7 @@ def test_limit_rows_plan(  # noqa: D
 def test_multiple_metrics_plan(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests a plan to retrieve multiple metrics."""
     dataflow_plan = dataflow_plan_builder.build_plan(
@@ -231,7 +230,7 @@ def test_multiple_metrics_plan(  # noqa: D
 def test_single_semantic_model_ratio_metrics_plan(
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests a plan to retrieve a ratio where both measures come from one semantic model."""
     dataflow_plan = dataflow_plan_builder.build_plan(
@@ -264,7 +263,7 @@ def test_single_semantic_model_ratio_metrics_plan(
 def test_multi_semantic_model_ratio_metrics_plan(
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests a plan to retrieve a ratio where the measures come from different semantic models."""
     dataflow_plan = dataflow_plan_builder.build_plan(
@@ -297,7 +296,7 @@ def test_multi_semantic_model_ratio_metrics_plan(
 def test_multihop_join_plan(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    multihop_dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    multihop_dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests a plan with a multi-hop join."""
     dataflow_plan = multihop_dataflow_plan_builder.build_plan(
@@ -333,7 +332,7 @@ def test_where_constrained_plan(  # noqa: D
    request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
     column_association_resolver: ColumnAssociationResolver,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests a plan with a where constraint."""
     dataflow_plan = dataflow_plan_builder.build_plan(
@@ -374,7 +373,7 @@ def test_where_constrained_plan(  # noqa: D
 def test_where_constrained_plan_time_dimension(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
     column_association_resolver: ColumnAssociationResolver,
 ) -> None:
     """Tests a plan with a where constraint on a time dimension."""
@@ -417,7 +416,7 @@ def test_where_constrained_with_common_linkable_plan(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
     column_association_resolver: ColumnAssociationResolver,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests a dataflow plan where the where clause has a common linkable with the query."""
     dataflow_plan = dataflow_plan_builder.build_plan(
@@ -457,7 +456,7 @@ def test_where_constrained_with_common_linkable_plan(  # noqa: D
 def test_multihop_join_plan_ambiguous_dim(  # noqa: D
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Checks that an exception is thrown when trying to build a plan with an ambiguous dimension."""
     with pytest.raises(UnableToSatisfyQueryError):
@@ -480,7 +479,7 @@ def test_multihop_join_plan_ambiguous_dim(  # noqa: D
 def test_cumulative_metric_with_window(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests a plan to compute a cumulative metric."""
     dataflow_plan = dataflow_plan_builder.build_plan(
@@ -508,7 +507,7 @@ def test_cumulative_metric_with_window(  # noqa: D
 def test_cumulative_metric_no_window_or_grain_with_metric_time(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     dataflow_plan = dataflow_plan_builder.build_plan(
         MetricFlowQuerySpec(
@@ -535,7 +534,7 @@ def test_cumulative_metric_no_window_or_grain_with_metric_time(  # noqa: D
 def test_cumulative_metric_no_window_or_grain_without_metric_time(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     dataflow_plan = dataflow_plan_builder.build_plan(
         MetricFlowQuerySpec(
@@ -562,7 +561,7 @@ def test_cumulative_metric_no_window_or_grain_without_metric_time(  # noqa: D
 def test_distinct_values_plan(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests a plan to get distinct values of a dimension."""
     dataflow_plan = dataflow_plan_builder.build_plan_for_distinct_values(
@@ -591,7 +590,7 @@ def test_distinct_values_plan(  # noqa: D
def test_measure_constraint_plan(
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests a plan for querying a metric with a constraint on one or more of its input measures."""
     dataflow_plan = dataflow_plan_builder.build_plan(
@@ -619,7 +618,7 @@ def test_measure_constraint_plan(
 def test_measure_constraint_with_reused_measure_plan(
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests a plan for querying a metric with a constraint on one or more of its input measures."""
     dataflow_plan = dataflow_plan_builder.build_plan(
@@ -647,7 +646,7 @@ def test_measure_constraint_with_reused_measure_plan(
 def test_common_semantic_model(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests a plan getting two metrics that share a common semantic model."""
     dataflow_plan = dataflow_plan_builder.build_plan(
@@ -677,7 +676,7 @@ def test_common_semantic_model(  # noqa: D
 def test_derived_metric_offset_window(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests a plan getting a derived metric with an offset window."""
     dataflow_plan = dataflow_plan_builder.build_plan(
@@ -704,7 +703,7 @@ def test_derived_metric_offset_window(  # noqa: D
 def test_derived_metric_offset_to_grain(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests a plan getting a derived metric with an offset to grain."""
     dataflow_plan = dataflow_plan_builder.build_plan(
@@ -731,7 +730,7 @@ def test_derived_metric_offset_to_grain(  # noqa: D
 def test_derived_metric_offset_with_granularity(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     dataflow_plan = dataflow_plan_builder.build_plan(
         MetricFlowQuerySpec(
@@ -757,7 +756,7 @@ def test_derived_metric_offset_with_granularity(  # noqa: D
 def test_derived_offset_cumulative_metric(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     dataflow_plan = dataflow_plan_builder.build_plan(
         MetricFlowQuerySpec(
diff --git a/metricflow/test/dataflow/builder/test_node_evaluator.py b/metricflow/test/dataflow/builder/test_node_evaluator.py
index 08e5af3028..523cb416ed 100644
--- a/metricflow/test/dataflow/builder/test_node_evaluator.py
+++ b/metricflow/test/dataflow/builder/test_node_evaluator.py
@@ -16,7 +16,6 @@
 from metricflow.dataflow.builder.partitions import PartitionTimeDimensionJoinDescription
 from metricflow.dataflow.dataflow_plan import BaseOutput, ValidityWindowJoinDescription
 from metricflow.dataset.dataset import DataSet
-from
metricflow.dataset.semantic_model_adapter import SemanticModelDataSet from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow.plan_conversion.column_resolver import DunderColumnAssociationResolver from metricflow.plan_conversion.node_processor import PreDimensionJoinNodeProcessor @@ -37,7 +36,7 @@ def node_evaluator( consistent_id_object_repository: ConsistentIdObjectRepository, simple_semantic_manifest_lookup: SemanticManifestLookup, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, ) -> NodeEvaluatorForLinkableInstances: # noqa: D """Return a node evaluator using the nodes in semantic_model_name_to_nodes.""" node_data_set_resolver: DataflowPlanNodeOutputDataSetResolver = DataflowPlanNodeOutputDataSetResolver( @@ -56,7 +55,7 @@ def node_evaluator( def make_multihop_node_evaluator( - model_source_nodes: Sequence[BaseOutput[SemanticModelDataSet]], + model_source_nodes: Sequence[BaseOutput], semantic_manifest_lookup_with_multihop_links: SemanticManifestLookup, desired_linkable_specs: Sequence[LinkableInstanceSpec], ) -> NodeEvaluatorForLinkableInstances: # noqa: D diff --git a/metricflow/test/dataflow/optimizer/source_scan/test_cm_branch_combiner.py b/metricflow/test/dataflow/optimizer/source_scan/test_cm_branch_combiner.py index e6a36d8e3b..4ff31dd136 100644 --- a/metricflow/test/dataflow/optimizer/source_scan/test_cm_branch_combiner.py +++ b/metricflow/test/dataflow/optimizer/source_scan/test_cm_branch_combiner.py @@ -7,7 +7,6 @@ BaseOutput, DataflowPlan, FilterElementsNode, - SourceDataSetT, WriteToResultDataframeNode, ) from metricflow.dataflow.dataflow_plan_to_text import dataflow_plan_as_text @@ -22,10 +21,10 @@ from metricflow.test.snapshot_utils import assert_plan_snapshot_text_equal -def make_dataflow_plan(node: BaseOutput[SourceDataSetT]) -> DataflowPlan[SourceDataSetT]: # noqa: D - return DataflowPlan[SourceDataSetT]( +def make_dataflow_plan(node: BaseOutput) -> DataflowPlan: # noqa: D + return DataflowPlan( plan_id=IdGeneratorRegistry.for_class(ComputeMetricsBranchCombiner).create_id(OPTIMIZED_DATAFLOW_PLAN_PREFIX), - sink_output_nodes=[WriteToResultDataframeNode[SourceDataSetT](node)], + sink_output_nodes=[WriteToResultDataframeNode(node)], ) diff --git a/metricflow/test/dataflow/optimizer/source_scan/test_source_scan_optimizer.py b/metricflow/test/dataflow/optimizer/source_scan/test_source_scan_optimizer.py index a9370701ac..7ebf68edea 100644 --- a/metricflow/test/dataflow/optimizer/source_scan/test_source_scan_optimizer.py +++ b/metricflow/test/dataflow/optimizer/source_scan/test_source_scan_optimizer.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -from typing import Generic from _pytest.fixtures import FixtureRequest from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter @@ -25,7 +24,6 @@ OrderByLimitNode, ReadSqlSourceNode, SemiAdditiveJoinNode, - SourceDataSetT, WhereConstraintNode, WriteToResultDataframeNode, WriteToResultTableNode, @@ -33,7 +31,6 @@ from metricflow.dataflow.dataflow_plan_to_text import dataflow_plan_as_text from metricflow.dataflow.optimizer.source_scan.source_scan_optimizer import SourceScanOptimizer from metricflow.dataset.dataset import DataSet -from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet from metricflow.specs.column_assoc import ColumnAssociationResolver from metricflow.specs.specs import ( DimensionSpec, @@ -49,72 +46,70 @@ logger = logging.getLogger(__name__) 
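The ReadSqlSourceNodeCounter change below is the clearest example of the mechanical pattern this commit applies to every visitor: drop the source-data-set type parameter and keep only the visitor's output type. As a minimal, self-contained sketch of the resulting shape (hypothetical Node/NodeVisitor/SourceCounter names, not MetricFlow's actual classes):

from abc import ABC, abstractmethod
from typing import Generic, TypeVar

OutputT = TypeVar("OutputT")


class NodeVisitor(Generic[OutputT], ABC):
    # Parameterized only by the visitor's output type; before this change, the
    # analogous class also carried a SourceDataSetT parameter.
    @abstractmethod
    def visit_source_node(self, node: "SourceNode") -> OutputT:
        raise NotImplementedError


class SourceNode:
    # accept() dispatches to the visit_* method matching this node type.
    def accept(self, visitor: NodeVisitor[OutputT]) -> OutputT:
        return visitor.visit_source_node(self)


class SourceCounter(NodeVisitor[int]):
    # Concrete visitors now take a single type argument, mirroring
    # DataflowPlanNodeVisitor[int] in the diff below.
    def visit_source_node(self, node: SourceNode) -> int:
        return 1


assert SourceNode().accept(SourceCounter()) == 1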
-class ReadSqlSourceNodeCounter(Generic[SourceDataSetT], DataflowPlanNodeVisitor[SourceDataSetT, int]): +class ReadSqlSourceNodeCounter(DataflowPlanNodeVisitor[int]): """Counts the number of ReadSqlSourceNodes in the dataflow plan.""" - def _sum_parents(self, node: DataflowPlanNode[SourceDataSetT]) -> int: + def _sum_parents(self, node: DataflowPlanNode) -> int: return sum(parent_node.accept(self) for parent_node in node.parent_nodes) - def visit_source_node(self, node: ReadSqlSourceNode[SourceDataSetT]) -> int: # noqa: D + def visit_source_node(self, node: ReadSqlSourceNode) -> int: # noqa: D return 1 - def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode[SourceDataSetT]) -> int: # noqa: D + def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode) -> int: # noqa: D return self._sum_parents(node) def visit_join_aggregated_measures_by_groupby_columns_node( # noqa: D - self, node: JoinAggregatedMeasuresByGroupByColumnsNode[SourceDataSetT] + self, node: JoinAggregatedMeasuresByGroupByColumnsNode ) -> int: return self._sum_parents(node) - def visit_aggregate_measures_node(self, node: AggregateMeasuresNode[SourceDataSetT]) -> int: # noqa: D + def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> int: # noqa: D return self._sum_parents(node) - def visit_compute_metrics_node(self, node: ComputeMetricsNode[SourceDataSetT]) -> int: # noqa: D + def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> int: # noqa: D return self._sum_parents(node) - def visit_order_by_limit_node(self, node: OrderByLimitNode[SourceDataSetT]) -> int: # noqa: D + def visit_order_by_limit_node(self, node: OrderByLimitNode) -> int: # noqa: D return self._sum_parents(node) - def visit_where_constraint_node(self, node: WhereConstraintNode[SourceDataSetT]) -> int: # noqa: D + def visit_where_constraint_node(self, node: WhereConstraintNode) -> int: # noqa: D return self._sum_parents(node) - def visit_write_to_result_dataframe_node(self, node: WriteToResultDataframeNode[SourceDataSetT]) -> int: # noqa: D + def visit_write_to_result_dataframe_node(self, node: WriteToResultDataframeNode) -> int: # noqa: D return self._sum_parents(node) - def visit_write_to_result_table_node(self, node: WriteToResultTableNode[SourceDataSetT]) -> int: # noqa: D + def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> int: # noqa: D return self._sum_parents(node) - def visit_pass_elements_filter_node(self, node: FilterElementsNode[SourceDataSetT]) -> int: # noqa: D + def visit_pass_elements_filter_node(self, node: FilterElementsNode) -> int: # noqa: D return self._sum_parents(node) - def visit_combine_metrics_node(self, node: CombineMetricsNode[SourceDataSetT]) -> int: # noqa: D + def visit_combine_metrics_node(self, node: CombineMetricsNode) -> int: # noqa: D return self._sum_parents(node) - def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode[SourceDataSetT]) -> int: # noqa: D + def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> int: # noqa: D return self._sum_parents(node) - def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode[SourceDataSetT]) -> int: # noqa: D + def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> int: # noqa: D return self._sum_parents(node) - def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode[SourceDataSetT]) -> int: # noqa: D + def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> int: # noqa: D return self._sum_parents(node) - def 
visit_metric_time_dimension_transform_node(  # noqa: D
-        self, node: MetricTimeDimensionTransformNode[SourceDataSetT]
-    ) -> int:
+    def visit_metric_time_dimension_transform_node(self, node: MetricTimeDimensionTransformNode) -> int:  # noqa: D
         return self._sum_parents(node)
-    def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode[SourceDataSetT]) -> int:  # noqa: D
+    def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> int:  # noqa: D
         return self._sum_parents(node)
-    def count_source_nodes(self, dataflow_plan: DataflowPlan[SourceDataSetT]) -> int:  # noqa: D
+    def count_source_nodes(self, dataflow_plan: DataflowPlan) -> int:  # noqa: D
         return dataflow_plan.sink_output_node.accept(self)
 def check_optimization(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
     query_spec: MetricFlowQuerySpec,
     expected_num_sources_in_unoptimized: int,
     expected_num_sources_in_optimized: int,
@@ -134,10 +129,10 @@ def check_optimization(  # noqa: D
         dag_graph=dataflow_plan,
     )
-    source_counter = ReadSqlSourceNodeCounter[SemanticModelDataSet]()
+    source_counter = ReadSqlSourceNodeCounter()
     assert source_counter.count_source_nodes(dataflow_plan) == expected_num_sources_in_unoptimized
-    optimizer = SourceScanOptimizer[SemanticModelDataSet]()
+    optimizer = SourceScanOptimizer()
     optimized_dataflow_plan = optimizer.optimize(dataflow_plan)
     assert_plan_snapshot_text_equal(
@@ -158,7 +153,7 @@ def check_optimization(  # noqa: D
 def test_2_metrics_from_1_semantic_model(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests that optimizing the plan for 2 metrics from 1 semantic model results in half the number of scans.
@@ -183,7 +178,7 @@ def test_2_metrics_from_1_semantic_model(  # noqa: D
 def test_2_metrics_from_2_semantic_models(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests that 2 metrics from 2 semantic models result in 2 scans."""
     check_optimization(
@@ -202,7 +197,7 @@ def test_2_metrics_from_2_semantic_models(  # noqa: D
 def test_3_metrics_from_2_semantic_models(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests that 3 metrics from 2 semantic models result in 2 scans."""
     check_optimization(
@@ -226,7 +221,7 @@ def test_constrained_metric_not_combined(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
     column_association_resolver: ColumnAssociationResolver,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests that 2 metrics from the same semantic model, where 1 is constrained, result in 2 scans.
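Every test in this file funnels through the check_optimization helper above; condensed, the flow each one asserts is a before/after count of ReadSqlSourceNodes around a single optimizer pass. A minimal sketch, reusing the names from the surrounding code:

dataflow_plan = dataflow_plan_builder.build_plan(query_spec)
assert ReadSqlSourceNodeCounter().count_source_nodes(dataflow_plan) == expected_num_sources_in_unoptimized

optimized_dataflow_plan = SourceScanOptimizer().optimize(dataflow_plan)
assert ReadSqlSourceNodeCounter().count_source_nodes(optimized_dataflow_plan) == expected_num_sources_in_optimized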
@@ -260,7 +255,7 @@ def test_constrained_metric_not_combined(  # noqa: D
 def test_derived_metric(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests optimization of a query that uses a derived metric with measures coming from a single semantic model.
@@ -282,7 +277,7 @@ def test_derived_metric(  # noqa: D
 def test_nested_derived_metric(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests optimization of a query that uses a nested derived metric from a single semantic model.
@@ -305,7 +300,7 @@ def test_nested_derived_metric(  # noqa: D
 def test_derived_metric_with_non_derived_metric(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests optimization of queries that use derived metrics and non-derived metrics.
@@ -336,7 +331,7 @@ def test_derived_metric_with_non_derived_metric(  # noqa: D
 def test_2_ratio_metrics_from_1_semantic_model(  # noqa: D
     request: FixtureRequest,
     mf_test_session_state: MetricFlowTestSessionState,
-    dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet],
+    dataflow_plan_builder: DataflowPlanBuilder,
 ) -> None:
     """Tests that 2 ratio metrics with measures from 1 semantic model result in 1 scan."""
     check_optimization(
diff --git a/metricflow/test/examples/test_node_sql.py b/metricflow/test/examples/test_node_sql.py
index 9ee0a49a00..ea929b8286 100644
--- a/metricflow/test/examples/test_node_sql.py
+++ b/metricflow/test/examples/test_node_sql.py
@@ -13,7 +13,6 @@
 from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
 from metricflow.plan_conversion.column_resolver import DunderColumnAssociationResolver
 from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
-from metricflow.plan_conversion.sql_dataset import SqlDataSet
 from metricflow.protocols.sql_client import SqlClient
 from metricflow.specs.specs import InstanceSpecSet, TimeDimensionReference, TimeDimensionSpec
 from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel
@@ -37,19 +36,19 @@ def test_view_sql_generated_at_a_node(
     )
     to_data_set_converter = SemanticModelToDataSetConverter(column_association_resolver)
-    to_sql_plan_converter = DataflowToSqlQueryPlanConverter[SqlDataSet](
+    to_sql_plan_converter = DataflowToSqlQueryPlanConverter(
         column_association_resolver=DunderColumnAssociationResolver(simple_semantic_manifest_lookup),
         semantic_manifest_lookup=simple_semantic_manifest_lookup,
     )
     sql_renderer: SqlQueryPlanRenderer = sql_client.sql_query_plan_renderer
-    node_output_resolver = DataflowPlanNodeOutputDataSetResolver[SqlDataSet](
+    node_output_resolver = DataflowPlanNodeOutputDataSetResolver(
         column_association_resolver=column_association_resolver,
         semantic_manifest_lookup=simple_semantic_manifest_lookup,
     )
     # Show SQL and spec set at a source node.
bookings_source_data_set = to_data_set_converter.create_sql_source_data_set(bookings_semantic_model) - read_source_node = ReadSqlSourceNode[SqlDataSet](bookings_source_data_set) + read_source_node = ReadSqlSourceNode(bookings_source_data_set) sql_plan_at_read_node = to_sql_plan_converter.convert_to_sql_query_plan( sql_engine_type=sql_client.sql_engine_type, sql_query_plan_id="example_sql_plan", diff --git a/metricflow/test/fixtures/dataflow_fixtures.py b/metricflow/test/fixtures/dataflow_fixtures.py index 3f2c99927f..06616aa7f1 100644 --- a/metricflow/test/fixtures/dataflow_fixtures.py +++ b/metricflow/test/fixtures/dataflow_fixtures.py @@ -4,7 +4,6 @@ from metricflow.dataflow.builder.costing import DefaultCostFunction from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder -from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow.plan_conversion.column_resolver import DunderColumnAssociationResolver from metricflow.plan_conversion.time_spine import TimeSpineSource @@ -32,11 +31,11 @@ def column_association_resolver( # noqa: D def dataflow_plan_builder( # noqa: D simple_semantic_manifest_lookup: SemanticManifestLookup, consistent_id_object_repository: ConsistentIdObjectRepository, -) -> DataflowPlanBuilder[SemanticModelDataSet]: +) -> DataflowPlanBuilder: return DataflowPlanBuilder( source_nodes=consistent_id_object_repository.simple_model_source_nodes, semantic_manifest_lookup=simple_semantic_manifest_lookup, - cost_function=DefaultCostFunction[SemanticModelDataSet](), + cost_function=DefaultCostFunction(), ) @@ -45,11 +44,11 @@ def multihop_dataflow_plan_builder( # noqa: D multi_hop_join_semantic_manifest_lookup: SemanticManifestLookup, consistent_id_object_repository: ConsistentIdObjectRepository, time_spine_source: TimeSpineSource, -) -> DataflowPlanBuilder[SemanticModelDataSet]: +) -> DataflowPlanBuilder: return DataflowPlanBuilder( source_nodes=consistent_id_object_repository.multihop_model_source_nodes, semantic_manifest_lookup=multi_hop_join_semantic_manifest_lookup, - cost_function=DefaultCostFunction[SemanticModelDataSet](), + cost_function=DefaultCostFunction(), ) @@ -66,11 +65,11 @@ def scd_dataflow_plan_builder( # noqa: D scd_column_association_resolver: ColumnAssociationResolver, consistent_id_object_repository: ConsistentIdObjectRepository, time_spine_source: TimeSpineSource, -) -> DataflowPlanBuilder[SemanticModelDataSet]: +) -> DataflowPlanBuilder: return DataflowPlanBuilder( source_nodes=consistent_id_object_repository.scd_model_source_nodes, semantic_manifest_lookup=scd_semantic_manifest_lookup, - cost_function=DefaultCostFunction[SemanticModelDataSet](), + cost_function=DefaultCostFunction(), column_association_resolver=scd_column_association_resolver, ) diff --git a/metricflow/test/fixtures/model_fixtures.py b/metricflow/test/fixtures/model_fixtures.py index b6a82b4a05..6d7fcea2ce 100644 --- a/metricflow/test/fixtures/model_fixtures.py +++ b/metricflow/test/fixtures/model_fixtures.py @@ -32,13 +32,11 @@ logger = logging.getLogger(__name__) -def _data_set_to_read_nodes( - data_sets: OrderedDict[str, SemanticModelDataSet] -) -> OrderedDict[str, ReadSqlSourceNode[SemanticModelDataSet]]: +def _data_set_to_read_nodes(data_sets: OrderedDict[str, SemanticModelDataSet]) -> OrderedDict[str, ReadSqlSourceNode]: """Return a mapping from the name of the semantic model to the dataflow plan node that reads from it.""" - return_dict: 
OrderedDict[str, ReadSqlSourceNode[SemanticModelDataSet]] = OrderedDict() + return_dict: OrderedDict[str, ReadSqlSourceNode] = OrderedDict() for semantic_model_name, data_set in data_sets.items(): - return_dict[semantic_model_name] = ReadSqlSourceNode[SemanticModelDataSet](data_set) + return_dict[semantic_model_name] = ReadSqlSourceNode(data_set) logger.debug( f"For semantic model {semantic_model_name}, creating node_id {return_dict[semantic_model_name].node_id}" ) @@ -48,7 +46,7 @@ def _data_set_to_read_nodes( def _data_set_to_source_nodes( semantic_manifest_lookup: SemanticManifestLookup, data_sets: OrderedDict[str, SemanticModelDataSet] -) -> Sequence[BaseOutput[SemanticModelDataSet]]: +) -> Sequence[BaseOutput]: source_node_builder = SourceNodeBuilder(semantic_manifest_lookup) return source_node_builder.create_from_data_sets(list(data_sets.values())) @@ -78,17 +76,17 @@ class ConsistentIdObjectRepository: """Stores all objects that should have consistent IDs in tests.""" simple_model_data_sets: OrderedDict[str, SemanticModelDataSet] - simple_model_read_nodes: OrderedDict[str, ReadSqlSourceNode[SemanticModelDataSet]] - simple_model_source_nodes: Sequence[BaseOutput[SemanticModelDataSet]] + simple_model_read_nodes: OrderedDict[str, ReadSqlSourceNode] + simple_model_source_nodes: Sequence[BaseOutput] - multihop_model_read_nodes: OrderedDict[str, ReadSqlSourceNode[SemanticModelDataSet]] - multihop_model_source_nodes: Sequence[BaseOutput[SemanticModelDataSet]] + multihop_model_read_nodes: OrderedDict[str, ReadSqlSourceNode] + multihop_model_source_nodes: Sequence[BaseOutput] scd_model_data_sets: OrderedDict[str, SemanticModelDataSet] - scd_model_read_nodes: OrderedDict[str, ReadSqlSourceNode[SemanticModelDataSet]] - scd_model_source_nodes: Sequence[BaseOutput[SemanticModelDataSet]] + scd_model_read_nodes: OrderedDict[str, ReadSqlSourceNode] + scd_model_source_nodes: Sequence[BaseOutput] - cyclic_join_source_nodes: Sequence[BaseOutput[SemanticModelDataSet]] + cyclic_join_source_nodes: Sequence[BaseOutput] @pytest.fixture(scope="session") diff --git a/metricflow/test/plan_conversion/conftest.py b/metricflow/test/plan_conversion/conftest.py index c87f8aaf51..b917d0caae 100644 --- a/metricflow/test/plan_conversion/conftest.py +++ b/metricflow/test/plan_conversion/conftest.py @@ -2,7 +2,6 @@ import pytest -from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow.plan_conversion.column_resolver import DunderColumnAssociationResolver from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter @@ -11,8 +10,8 @@ @pytest.fixture(scope="session") def dataflow_to_sql_converter( # noqa: D simple_semantic_manifest_lookup: SemanticManifestLookup, -) -> DataflowToSqlQueryPlanConverter[SemanticModelDataSet]: - return DataflowToSqlQueryPlanConverter[SemanticModelDataSet]( +) -> DataflowToSqlQueryPlanConverter: + return DataflowToSqlQueryPlanConverter( column_association_resolver=DunderColumnAssociationResolver(simple_semantic_manifest_lookup), semantic_manifest_lookup=simple_semantic_manifest_lookup, ) diff --git a/metricflow/test/plan_conversion/dataflow_to_sql/test_metric_time_dimension_to_sql.py b/metricflow/test/plan_conversion/dataflow_to_sql/test_metric_time_dimension_to_sql.py index 37af737035..f9710ee87d 100644 --- a/metricflow/test/plan_conversion/dataflow_to_sql/test_metric_time_dimension_to_sql.py +++ 
b/metricflow/test/plan_conversion/dataflow_to_sql/test_metric_time_dimension_to_sql.py @@ -5,7 +5,6 @@ from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder from metricflow.dataflow.dataflow_plan import MetricTimeDimensionTransformNode -from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter from metricflow.protocols.sql_client import SqlClient from metricflow.specs.specs import MetricFlowQuerySpec, MetricSpec @@ -18,7 +17,7 @@ def test_metric_time_dimension_transform_node_using_primary_time( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, ) -> None: @@ -39,7 +38,7 @@ def test_metric_time_dimension_transform_node_using_primary_time( # noqa: D def test_metric_time_dimension_transform_node_using_non_primary_time( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, ) -> None: @@ -61,7 +60,7 @@ def test_metric_time_dimension_transform_node_using_non_primary_time( # noqa: D def test_simple_query_with_metric_time_dimension( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, dataflow_plan_builder: DataflowPlanBuilder, diff --git a/metricflow/test/plan_conversion/test_dataflow_to_execution.py b/metricflow/test/plan_conversion/test_dataflow_to_execution.py index cf98bed775..3a7c489ed8 100644 --- a/metricflow/test/plan_conversion/test_dataflow_to_execution.py +++ b/metricflow/test/plan_conversion/test_dataflow_to_execution.py @@ -3,7 +3,6 @@ from _pytest.fixtures import FixtureRequest from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder -from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow.plan_conversion.column_resolver import DunderColumnAssociationResolver from metricflow.plan_conversion.dataflow_to_execution import DataflowToExecutionPlanConverter @@ -25,8 +24,8 @@ def make_execution_plan_converter( # noqa: D semantic_manifest_lookup: SemanticManifestLookup, sql_client: SqlClient, ) -> DataflowToExecutionPlanConverter: - return DataflowToExecutionPlanConverter[SemanticModelDataSet]( - sql_plan_converter=DataflowToSqlQueryPlanConverter[SemanticModelDataSet]( + return DataflowToExecutionPlanConverter( + sql_plan_converter=DataflowToSqlQueryPlanConverter( column_association_resolver=DunderColumnAssociationResolver(semantic_manifest_lookup), semantic_manifest_lookup=semantic_manifest_lookup, ), @@ -38,7 +37,7 @@ def make_execution_plan_converter( # noqa: D def test_joined_plan( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], + 
dataflow_plan_builder: DataflowPlanBuilder, simple_semantic_manifest_lookup: SemanticManifestLookup, sql_client: SqlClient, ) -> None: @@ -77,7 +76,7 @@ def test_small_combined_metrics_plan( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, sql_client: SqlClient, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, simple_semantic_manifest_lookup: SemanticManifestLookup, ) -> None: dataflow_plan = dataflow_plan_builder.build_plan( @@ -113,7 +112,7 @@ def test_combined_metrics_plan( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, sql_client: SqlClient, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, simple_semantic_manifest_lookup: SemanticManifestLookup, ) -> None: dataflow_plan = dataflow_plan_builder.build_plan( @@ -150,7 +149,7 @@ def test_combined_metrics_plan( # noqa: D def test_multihop_joined_plan( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - multihop_dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], + multihop_dataflow_plan_builder: DataflowPlanBuilder, multi_hop_join_semantic_manifest_lookup: SemanticManifestLookup, sql_client: SqlClient, ) -> None: @@ -171,7 +170,7 @@ def test_multihop_joined_plan( # noqa: D ) to_execution_plan_converter = DataflowToExecutionPlanConverter( - sql_plan_converter=DataflowToSqlQueryPlanConverter[SemanticModelDataSet]( + sql_plan_converter=DataflowToSqlQueryPlanConverter( column_association_resolver=DunderColumnAssociationResolver(multi_hop_join_semantic_manifest_lookup), semantic_manifest_lookup=multi_hop_join_semantic_manifest_lookup, ), diff --git a/metricflow/test/plan_conversion/test_dataflow_to_sql_plan.py b/metricflow/test/plan_conversion/test_dataflow_to_sql_plan.py index 0bb3ded01d..a387521655 100644 --- a/metricflow/test/plan_conversion/test_dataflow_to_sql_plan.py +++ b/metricflow/test/plan_conversion/test_dataflow_to_sql_plan.py @@ -30,7 +30,6 @@ ) from metricflow.dataflow.dataflow_plan_to_text import dataflow_plan_as_text from metricflow.dataset.dataset import DataSet -from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet from metricflow.filters.time_constraint import TimeRangeConstraint from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow.plan_conversion.column_resolver import DunderColumnAssociationResolver @@ -62,8 +61,8 @@ @pytest.fixture(scope="session") def multihop_dataflow_to_sql_converter( # noqa: D multi_hop_join_semantic_manifest_lookup: SemanticManifestLookup, -) -> DataflowToSqlQueryPlanConverter[SemanticModelDataSet]: - return DataflowToSqlQueryPlanConverter[SemanticModelDataSet]( +) -> DataflowToSqlQueryPlanConverter: + return DataflowToSqlQueryPlanConverter( column_association_resolver=DunderColumnAssociationResolver(multi_hop_join_semantic_manifest_lookup), semantic_manifest_lookup=multi_hop_join_semantic_manifest_lookup, ) @@ -72,8 +71,8 @@ def multihop_dataflow_to_sql_converter( # noqa: D @pytest.fixture(scope="session") def scd_dataflow_to_sql_converter( # noqa: D scd_semantic_manifest_lookup: SemanticManifestLookup, -) -> DataflowToSqlQueryPlanConverter[SemanticModelDataSet]: - return DataflowToSqlQueryPlanConverter[SemanticModelDataSet]( +) -> DataflowToSqlQueryPlanConverter: + return DataflowToSqlQueryPlanConverter( 
column_association_resolver=DunderColumnAssociationResolver(scd_semantic_manifest_lookup), semantic_manifest_lookup=scd_semantic_manifest_lookup, ) @@ -82,9 +81,9 @@ def scd_dataflow_to_sql_converter( # noqa: D def convert_and_check( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, - node: BaseOutput[SemanticModelDataSet], + node: BaseOutput, ) -> None: """Convert the dataflow plan to SQL and compare with snapshots.""" # Generate plans w/o optimizers @@ -139,7 +138,7 @@ def convert_and_check( def test_source_node( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, ) -> None: @@ -158,7 +157,7 @@ def test_source_node( # noqa: D def test_filter_node( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, ) -> None: @@ -167,7 +166,7 @@ def test_filter_node( # noqa: D element_name="bookings", ) source_node = consistent_id_object_repository.simple_model_read_nodes["bookings_source"] - filter_node = FilterElementsNode[SemanticModelDataSet]( + filter_node = FilterElementsNode( parent_node=source_node, include_specs=InstanceSpecSet(measure_specs=(measure_spec,)) ) @@ -184,7 +183,7 @@ def test_filter_with_where_constraint_node( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, column_association_resolver: ColumnAssociationResolver, - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, ) -> None: @@ -195,11 +194,11 @@ def test_filter_with_where_constraint_node( # noqa: D source_node = consistent_id_object_repository.simple_model_read_nodes["bookings_source"] ds_spec = TimeDimensionSpec(element_name="ds", entity_links=(), time_granularity=TimeGranularity.DAY) - filter_node = FilterElementsNode[SemanticModelDataSet]( + filter_node = FilterElementsNode( parent_node=source_node, include_specs=InstanceSpecSet(measure_specs=(measure_spec,), time_dimension_specs=(ds_spec,)), ) # need to include ds_spec because where constraint operates on ds - where_constraint_node = WhereConstraintNode[SemanticModelDataSet]( + where_constraint_node = WhereConstraintNode( parent_node=filter_node, where_constraint=( WhereSpecFactory( @@ -224,7 +223,7 @@ def test_filter_with_where_constraint_node( # noqa: D def test_measure_aggregation_node( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, ) -> None: @@ -248,12 +247,12 @@ def test_measure_aggregation_node( # noqa: D metric_input_measure_specs = tuple(MetricInputMeasureSpec(measure_spec=x) 
for x in measure_specs) measure_source_node = consistent_id_object_repository.simple_model_read_nodes["bookings_source"] - filtered_measure_node = FilterElementsNode[SemanticModelDataSet]( + filtered_measure_node = FilterElementsNode( parent_node=measure_source_node, include_specs=InstanceSpecSet(measure_specs=tuple(measure_specs)), ) - aggregated_measure_node = AggregateMeasuresNode[SemanticModelDataSet]( + aggregated_measure_node = AggregateMeasuresNode( parent_node=filtered_measure_node, metric_input_measure_specs=metric_input_measure_specs ) @@ -269,7 +268,7 @@ def test_measure_aggregation_node( # noqa: D def test_single_join_node( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, ) -> None: @@ -279,7 +278,7 @@ def test_single_join_node( # noqa: D ) entity_spec = LinklessEntitySpec.from_element_name(element_name="listing") measure_source_node = consistent_id_object_repository.simple_model_read_nodes["bookings_source"] - filtered_measure_node = FilterElementsNode[SemanticModelDataSet]( + filtered_measure_node = FilterElementsNode( parent_node=measure_source_node, include_specs=InstanceSpecSet( measure_specs=(measure_spec,), @@ -292,7 +291,7 @@ def test_single_join_node( # noqa: D entity_links=(EntityReference("listing"),), ) dimension_source_node = consistent_id_object_repository.simple_model_read_nodes["listings_latest"] - filtered_dimension_node = FilterElementsNode[SemanticModelDataSet]( + filtered_dimension_node = FilterElementsNode( parent_node=dimension_source_node, include_specs=InstanceSpecSet( entity_specs=(entity_spec,), @@ -300,7 +299,7 @@ def test_single_join_node( # noqa: D ), ) - join_node = JoinToBaseOutputNode[SemanticModelDataSet]( + join_node = JoinToBaseOutputNode( left_node=filtered_measure_node, join_targets=[ JoinDescription( @@ -324,7 +323,7 @@ def test_single_join_node( # noqa: D def test_multi_join_node( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, ) -> None: @@ -334,7 +333,7 @@ def test_multi_join_node( ) entity_spec = LinklessEntitySpec.from_element_name(element_name="listing") measure_source_node = consistent_id_object_repository.simple_model_read_nodes["bookings_source"] - filtered_measure_node = FilterElementsNode[SemanticModelDataSet]( + filtered_measure_node = FilterElementsNode( parent_node=measure_source_node, include_specs=InstanceSpecSet(measure_specs=(measure_spec,), entity_specs=(entity_spec,)), ) @@ -344,7 +343,7 @@ def test_multi_join_node( entity_links=(), ) dimension_source_node = consistent_id_object_repository.simple_model_read_nodes["listings_latest"] - filtered_dimension_node = FilterElementsNode[SemanticModelDataSet]( + filtered_dimension_node = FilterElementsNode( parent_node=dimension_source_node, include_specs=InstanceSpecSet( entity_specs=(entity_spec,), @@ -352,7 +351,7 @@ def test_multi_join_node( ), ) - join_node = JoinToBaseOutputNode[SemanticModelDataSet]( + join_node = JoinToBaseOutputNode( left_node=filtered_measure_node, join_targets=[ JoinDescription( @@ -382,7 +381,7 @@ def test_multi_join_node( def 
test_compute_metrics_node( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, ) -> None: @@ -393,7 +392,7 @@ def test_compute_metrics_node( entity_spec = LinklessEntitySpec.from_element_name(element_name="listing") metric_input_measure_specs = (MetricInputMeasureSpec(measure_spec=measure_spec),) measure_source_node = consistent_id_object_repository.simple_model_read_nodes["bookings_source"] - filtered_measure_node = FilterElementsNode[SemanticModelDataSet]( + filtered_measure_node = FilterElementsNode( parent_node=measure_source_node, include_specs=InstanceSpecSet( measure_specs=(measure_spec,), @@ -406,7 +405,7 @@ def test_compute_metrics_node( entity_links=(), ) dimension_source_node = consistent_id_object_repository.simple_model_read_nodes["listings_latest"] - filtered_dimension_node = FilterElementsNode[SemanticModelDataSet]( + filtered_dimension_node = FilterElementsNode( parent_node=dimension_source_node, include_specs=InstanceSpecSet( entity_specs=(entity_spec,), @@ -414,7 +413,7 @@ def test_compute_metrics_node( ), ) - join_node = JoinToBaseOutputNode[SemanticModelDataSet]( + join_node = JoinToBaseOutputNode( left_node=filtered_measure_node, join_targets=[ JoinDescription( @@ -426,14 +425,12 @@ def test_compute_metrics_node( ], ) - aggregated_measure_node = AggregateMeasuresNode[SemanticModelDataSet]( + aggregated_measure_node = AggregateMeasuresNode( parent_node=join_node, metric_input_measure_specs=metric_input_measure_specs ) metric_spec = MetricSpec(element_name="bookings") - compute_metrics_node = ComputeMetricsNode[SemanticModelDataSet]( - parent_node=aggregated_measure_node, metric_specs=[metric_spec] - ) + compute_metrics_node = ComputeMetricsNode(parent_node=aggregated_measure_node, metric_specs=[metric_spec]) convert_and_check( request=request, @@ -447,7 +444,7 @@ def test_compute_metrics_node( def test_compute_metrics_node_simple_expr( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, ) -> None: @@ -458,7 +455,7 @@ def test_compute_metrics_node_simple_expr( entity_spec = LinklessEntitySpec.from_element_name(element_name="listing") metric_input_measure_specs = (MetricInputMeasureSpec(measure_spec=measure_spec),) measure_source_node = consistent_id_object_repository.simple_model_read_nodes["bookings_source"] - filtered_measure_node = FilterElementsNode[SemanticModelDataSet]( + filtered_measure_node = FilterElementsNode( parent_node=measure_source_node, include_specs=InstanceSpecSet(measure_specs=(measure_spec,), entity_specs=(entity_spec,)), ) @@ -468,7 +465,7 @@ def test_compute_metrics_node_simple_expr( entity_links=(), ) dimension_source_node = consistent_id_object_repository.simple_model_read_nodes["listings_latest"] - filtered_dimension_node = FilterElementsNode[SemanticModelDataSet]( + filtered_dimension_node = FilterElementsNode( parent_node=dimension_source_node, include_specs=InstanceSpecSet( entity_specs=(entity_spec,), @@ -476,7 +473,7 @@ def test_compute_metrics_node_simple_expr( ), ) - join_node = JoinToBaseOutputNode[SemanticModelDataSet]( + join_node = 
JoinToBaseOutputNode( left_node=filtered_measure_node, join_targets=[ JoinDescription( @@ -488,15 +485,13 @@ def test_compute_metrics_node_simple_expr( ], ) - aggregated_measures_node = AggregateMeasuresNode[SemanticModelDataSet]( + aggregated_measures_node = AggregateMeasuresNode( parent_node=join_node, metric_input_measure_specs=metric_input_measure_specs ) metric_spec = MetricSpec(element_name="booking_fees") - compute_metrics_node = ComputeMetricsNode[SemanticModelDataSet]( - parent_node=aggregated_measures_node, metric_specs=[metric_spec] - ) + compute_metrics_node = ComputeMetricsNode(parent_node=aggregated_measures_node, metric_specs=[metric_spec]) - sink_node = WriteToResultDataframeNode[SemanticModelDataSet](compute_metrics_node) + sink_node = WriteToResultDataframeNode(compute_metrics_node) dataflow_plan = DataflowPlan("plan0", sink_output_nodes=[sink_node]) assert_plan_snapshot_text_equal( @@ -524,7 +519,7 @@ def test_compute_metrics_node_simple_expr( def test_join_to_time_spine_node_without_offset( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, ) -> None: @@ -541,19 +536,17 @@ def test_join_to_time_spine_node_without_offset( # noqa: D aggregation_time_dimension_reference=TimeDimensionReference(element_name="ds"), ) - filtered_measure_node = FilterElementsNode[SemanticModelDataSet]( + filtered_measure_node = FilterElementsNode( parent_node=metric_time_node, include_specs=InstanceSpecSet( measure_specs=(measure_spec,), entity_specs=(entity_spec,), dimension_specs=(metric_time_spec,) ), ) - aggregated_measures_node = AggregateMeasuresNode[SemanticModelDataSet]( + aggregated_measures_node = AggregateMeasuresNode( parent_node=filtered_measure_node, metric_input_measure_specs=metric_input_measure_specs ) metric_spec = MetricSpec(element_name="booking_fees") - compute_metrics_node = ComputeMetricsNode[SemanticModelDataSet]( - parent_node=aggregated_measures_node, metric_specs=[metric_spec] - ) + compute_metrics_node = ComputeMetricsNode(parent_node=aggregated_measures_node, metric_specs=[metric_spec]) join_to_time_spine_node = JoinToTimeSpineNode( parent_node=compute_metrics_node, metric_time_dimension_specs=[MTD_SPEC_DAY], @@ -561,7 +554,7 @@ def test_join_to_time_spine_node_without_offset( # noqa: D start_time=as_datetime("2020-01-01"), end_time=as_datetime("2021-01-01") ), ) - sink_node = WriteToResultDataframeNode[SemanticModelDataSet](join_to_time_spine_node) + sink_node = WriteToResultDataframeNode(join_to_time_spine_node) dataflow_plan = DataflowPlan("plan0", sink_output_nodes=[sink_node]) assert_plan_snapshot_text_equal( @@ -589,7 +582,7 @@ def test_join_to_time_spine_node_without_offset( # noqa: D def test_join_to_time_spine_node_with_offset_window( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, ) -> None: @@ -605,19 +598,17 @@ def test_join_to_time_spine_node_with_offset_window( # noqa: D parent_node=measure_source_node, aggregation_time_dimension_reference=TimeDimensionReference(element_name="ds"), ) - filtered_measure_node = 
FilterElementsNode[SemanticModelDataSet]( + filtered_measure_node = FilterElementsNode( parent_node=metric_time_node, include_specs=InstanceSpecSet( measure_specs=(measure_spec,), entity_specs=(entity_spec,), dimension_specs=(metric_time_spec,) ), ) - aggregated_measures_node = AggregateMeasuresNode[SemanticModelDataSet]( + aggregated_measures_node = AggregateMeasuresNode( parent_node=filtered_measure_node, metric_input_measure_specs=metric_input_measure_specs ) metric_spec = MetricSpec(element_name="booking_fees") - compute_metrics_node = ComputeMetricsNode[SemanticModelDataSet]( - parent_node=aggregated_measures_node, metric_specs=[metric_spec] - ) + compute_metrics_node = ComputeMetricsNode(parent_node=aggregated_measures_node, metric_specs=[metric_spec]) join_to_time_spine_node = JoinToTimeSpineNode( parent_node=compute_metrics_node, metric_time_dimension_specs=[MTD_SPEC_DAY], @@ -627,7 +618,7 @@ def test_join_to_time_spine_node_with_offset_window( # noqa: D offset_window=PydanticMetricTimeWindow(count=10, granularity=TimeGranularity.DAY), ) - sink_node = WriteToResultDataframeNode[SemanticModelDataSet](join_to_time_spine_node) + sink_node = WriteToResultDataframeNode(join_to_time_spine_node) dataflow_plan = DataflowPlan("plan0", sink_output_nodes=[sink_node]) assert_plan_snapshot_text_equal( @@ -655,7 +646,7 @@ def test_join_to_time_spine_node_with_offset_window( # noqa: D def test_join_to_time_spine_node_with_offset_to_grain( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, ) -> None: @@ -671,19 +662,17 @@ def test_join_to_time_spine_node_with_offset_to_grain( parent_node=measure_source_node, aggregation_time_dimension_reference=TimeDimensionReference(element_name="ds"), ) - filtered_measure_node = FilterElementsNode[SemanticModelDataSet]( + filtered_measure_node = FilterElementsNode( parent_node=metric_time_node, include_specs=InstanceSpecSet( measure_specs=(measure_spec,), entity_specs=(entity_spec,), dimension_specs=(metric_time_spec,) ), ) - aggregated_measures_node = AggregateMeasuresNode[SemanticModelDataSet]( + aggregated_measures_node = AggregateMeasuresNode( parent_node=filtered_measure_node, metric_input_measure_specs=metric_input_measure_specs ) metric_spec = MetricSpec(element_name="booking_fees") - compute_metrics_node = ComputeMetricsNode[SemanticModelDataSet]( - parent_node=aggregated_measures_node, metric_specs=[metric_spec] - ) + compute_metrics_node = ComputeMetricsNode(parent_node=aggregated_measures_node, metric_specs=[metric_spec]) join_to_time_spine_node = JoinToTimeSpineNode( parent_node=compute_metrics_node, metric_time_dimension_specs=[MTD_SPEC_DAY], @@ -694,7 +683,7 @@ def test_join_to_time_spine_node_with_offset_to_grain( offset_to_grain=TimeGranularity.MONTH, ) - sink_node = WriteToResultDataframeNode[SemanticModelDataSet](join_to_time_spine_node) + sink_node = WriteToResultDataframeNode(join_to_time_spine_node) dataflow_plan = DataflowPlan("plan0", sink_output_nodes=[sink_node]) assert_plan_snapshot_text_equal( @@ -722,7 +711,7 @@ def test_join_to_time_spine_node_with_offset_to_grain( def test_compute_metrics_node_ratio_from_single_semantic_model( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_to_sql_converter: 
DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, ) -> None: @@ -739,7 +728,7 @@ def test_compute_metrics_node_ratio_from_single_semantic_model( MetricInputMeasureSpec(measure_spec=denominator_spec), ) measure_source_node = consistent_id_object_repository.simple_model_read_nodes["bookings_source"] - filtered_measures_node = FilterElementsNode[SemanticModelDataSet]( + filtered_measures_node = FilterElementsNode( parent_node=measure_source_node, include_specs=InstanceSpecSet(measure_specs=(numerator_spec, denominator_spec), entity_specs=(entity_spec,)), ) @@ -749,7 +738,7 @@ def test_compute_metrics_node_ratio_from_single_semantic_model( entity_links=(), ) dimension_source_node = consistent_id_object_repository.simple_model_read_nodes["listings_latest"] - filtered_dimension_node = FilterElementsNode[SemanticModelDataSet]( + filtered_dimension_node = FilterElementsNode( parent_node=dimension_source_node, include_specs=InstanceSpecSet( entity_specs=(entity_spec,), @@ -757,7 +746,7 @@ def test_compute_metrics_node_ratio_from_single_semantic_model( ), ) - join_node = JoinToBaseOutputNode[SemanticModelDataSet]( + join_node = JoinToBaseOutputNode( left_node=filtered_measures_node, join_targets=[ JoinDescription( @@ -769,13 +758,11 @@ def test_compute_metrics_node_ratio_from_single_semantic_model( ], ) - aggregated_measures_node = AggregateMeasuresNode[SemanticModelDataSet]( + aggregated_measures_node = AggregateMeasuresNode( parent_node=join_node, metric_input_measure_specs=metric_input_measure_specs ) metric_spec = MetricSpec(element_name="bookings_per_booker") - compute_metrics_node = ComputeMetricsNode[SemanticModelDataSet]( - parent_node=aggregated_measures_node, metric_specs=[metric_spec] - ) + compute_metrics_node = ComputeMetricsNode(parent_node=aggregated_measures_node, metric_specs=[metric_spec]) convert_and_check( request=request, @@ -790,7 +777,7 @@ def test_compute_metrics_node_ratio_from_multiple_semantic_models( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, dataflow_plan_builder: DataflowPlanBuilder, - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: """Tests the compute metrics node for ratio type metrics. 
@@ -829,7 +816,7 @@ def test_order_by_node( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, consistent_id_object_repository: ConsistentIdObjectRepository, - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: """Tests converting a dataflow plan to a SQL query plan where there is a leaf compute metrics node.""" @@ -849,7 +836,7 @@ def test_order_by_node( ) measure_source_node = consistent_id_object_repository.simple_model_read_nodes["bookings_source"] - filtered_measure_node = FilterElementsNode[SemanticModelDataSet]( + filtered_measure_node = FilterElementsNode( parent_node=measure_source_node, include_specs=InstanceSpecSet( measure_specs=(measure_spec,), @@ -858,14 +845,12 @@ def test_order_by_node( ), ) - aggregated_measure_node = AggregateMeasuresNode[SemanticModelDataSet]( + aggregated_measure_node = AggregateMeasuresNode( parent_node=filtered_measure_node, metric_input_measure_specs=metric_input_measure_specs ) metric_spec = MetricSpec(element_name="bookings") - compute_metrics_node = ComputeMetricsNode[SemanticModelDataSet]( - parent_node=aggregated_measure_node, metric_specs=[metric_spec] - ) + compute_metrics_node = ComputeMetricsNode(parent_node=aggregated_measure_node, metric_specs=[metric_spec]) order_by_node = OrderByLimitNode( order_by_specs=[ @@ -893,8 +878,8 @@ def test_order_by_node( def test_multihop_node( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - multihop_dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - multihop_dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + multihop_dataflow_plan_builder: DataflowPlanBuilder, + multihop_dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: """Tests converting a dataflow plan to a SQL query plan where there is a join between 1 measure and 2 dimensions.""" @@ -926,8 +911,8 @@ def test_filter_with_where_constraint_on_join_dim( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, column_association_resolver: ColumnAssociationResolver, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, ) -> None: @@ -966,7 +951,7 @@ def test_constrain_time_range_node( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, consistent_id_object_repository: ConsistentIdObjectRepository, - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: """Tests converting the ConstrainTimeRangeNode to SQL.""" @@ -989,7 +974,7 @@ def test_constrain_time_range_node( aggregation_time_dimension_reference=TimeDimensionReference(element_name="ds"), ) - constrain_time_node = ConstrainTimeRangeNode[SemanticModelDataSet]( + constrain_time_node = ConstrainTimeRangeNode( parent_node=metric_time_node, time_range_constraint=TimeRangeConstraint( start_time=as_datetime("2020-01-01"), @@ -1009,8 +994,8 @@ def test_constrain_time_range_node( def test_cumulative_metric( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - 
dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, ) -> None: @@ -1041,8 +1026,8 @@ def test_cumulative_metric( def test_cumulative_metric_with_time_constraint( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, ) -> None: @@ -1076,8 +1061,8 @@ def test_cumulative_metric_with_time_constraint( def test_cumulative_metric_no_ds( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, ) -> None: @@ -1102,8 +1087,8 @@ def test_cumulative_metric_no_ds( def test_cumulative_metric_no_window( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, ) -> None: @@ -1134,8 +1119,8 @@ def test_cumulative_metric_no_window( def test_cumulative_metric_no_window_with_time_constraint( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, ) -> None: @@ -1169,8 +1154,8 @@ def test_cumulative_metric_no_window_with_time_constraint( def test_cumulative_metric_grain_to_date( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, sql_client: SqlClient, ) -> None: @@ -1201,8 +1186,8 @@ def test_cumulative_metric_grain_to_date( def test_partitioned_join( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, consistent_id_object_repository: ConsistentIdObjectRepository, 
sql_client: SqlClient, ) -> None: @@ -1231,8 +1216,8 @@ def test_partitioned_join( def test_limit_rows( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: """Tests a plan with a limit to the number of rows returned.""" @@ -1261,8 +1246,8 @@ def test_limit_rows( # noqa: D def test_distinct_values( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: """Tests a plan to get distinct values for a dimension.""" @@ -1287,8 +1272,8 @@ def test_distinct_values( # noqa: D def test_local_dimension_using_local_entity( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: dataflow_plan = dataflow_plan_builder.build_plan( @@ -1316,7 +1301,7 @@ def test_semi_additive_join_node( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, consistent_id_object_repository: ConsistentIdObjectRepository, - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: """Tests converting a dataflow plan to a SQL query plan using a SemiAdditiveJoinNode.""" @@ -1324,7 +1309,7 @@ def test_semi_additive_join_node( time_dimension_spec = TimeDimensionSpec(element_name="ds", entity_links=()) measure_source_node = consistent_id_object_repository.simple_model_read_nodes["accounts_source"] - semi_additive_join_node = SemiAdditiveJoinNode[SemanticModelDataSet]( + semi_additive_join_node = SemiAdditiveJoinNode( parent_node=measure_source_node, entity_specs=tuple(), time_dimension_spec=time_dimension_spec, @@ -1344,7 +1329,7 @@ def test_semi_additive_join_node_with_queried_group_by( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, consistent_id_object_repository: ConsistentIdObjectRepository, - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: """Tests converting a dataflow plan to a SQL query plan using a SemiAdditiveJoinNode.""" @@ -1355,7 +1340,7 @@ def test_semi_additive_join_node_with_queried_group_by( ) measure_source_node = consistent_id_object_repository.simple_model_read_nodes["accounts_source"] - semi_additive_join_node = SemiAdditiveJoinNode[SemanticModelDataSet]( + semi_additive_join_node = SemiAdditiveJoinNode( parent_node=measure_source_node, entity_specs=tuple(), time_dimension_spec=time_dimension_spec, @@ -1375,7 +1360,7 @@ def test_semi_additive_join_node_with_grouping( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, 
consistent_id_object_repository: ConsistentIdObjectRepository, - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: """Tests converting a dataflow plan to a SQL query plan using a SemiAdditiveJoinNode with a window_grouping.""" @@ -1388,7 +1373,7 @@ def test_semi_additive_join_node_with_grouping( time_dimension_spec = TimeDimensionSpec(element_name="ds", entity_links=()) measure_source_node = consistent_id_object_repository.simple_model_read_nodes["accounts_source"] - semi_additive_join_node = SemiAdditiveJoinNode[SemanticModelDataSet]( + semi_additive_join_node = SemiAdditiveJoinNode( parent_node=measure_source_node, entity_specs=(entity_spec,), time_dimension_spec=time_dimension_spec, @@ -1406,8 +1391,8 @@ def test_semi_additive_join_node_with_grouping( def test_measure_constraint( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: dataflow_plan = dataflow_plan_builder.build_plan( @@ -1429,8 +1414,8 @@ def test_measure_constraint( # noqa: D def test_measure_constraint_with_reused_measure( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: dataflow_plan = dataflow_plan_builder.build_plan( @@ -1452,8 +1437,8 @@ def test_measure_constraint_with_reused_measure( # noqa: D def test_measure_constraint_with_single_expr_and_alias( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: dataflow_plan = dataflow_plan_builder.build_plan( @@ -1475,8 +1460,8 @@ def test_measure_constraint_with_single_expr_and_alias( # noqa: D def test_derived_metric( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: dataflow_plan = dataflow_plan_builder.build_plan( @@ -1498,8 +1483,8 @@ def test_derived_metric( # noqa: D def test_nested_derived_metric( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: dataflow_plan = dataflow_plan_builder.build_plan( @@ -1522,8 +1507,8 @@ def 
test_join_to_scd_dimension( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, scd_column_association_resolver: ColumnAssociationResolver, - scd_dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - scd_dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + scd_dataflow_plan_builder: DataflowPlanBuilder, + scd_dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: """Tests conversion of a plan using a dimension with a validity window inside a measure constraint.""" @@ -1559,8 +1544,8 @@ def test_join_to_scd_dimension( def test_multi_hop_through_scd_dimension( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - scd_dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - scd_dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + scd_dataflow_plan_builder: DataflowPlanBuilder, + scd_dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: """Tests conversion of a plan using a dimension that is reached through an SCD table.""" @@ -1584,8 +1569,8 @@ def test_multi_hop_through_scd_dimension( def test_multi_hop_to_scd_dimension( request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - scd_dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - scd_dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + scd_dataflow_plan_builder: DataflowPlanBuilder, + scd_dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: """Tests conversion of a plan using an SCD dimension that is reached through another table.""" @@ -1609,8 +1594,8 @@ def test_multi_hop_to_scd_dimension( def test_multiple_metrics_no_dimensions( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: dataflow_plan = dataflow_plan_builder.build_plan( @@ -1634,8 +1619,8 @@ def test_multiple_metrics_no_dimensions( # noqa: D def test_metric_with_measures_from_multiple_sources_no_dimensions( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: dataflow_plan = dataflow_plan_builder.build_plan( @@ -1656,8 +1641,8 @@ def test_metric_with_measures_from_multiple_sources_no_dimensions( # noqa: D def test_common_semantic_model( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: dataflow_plan = dataflow_plan_builder.build_plan( @@ -1679,8 +1664,8 @@ def test_common_semantic_model( # noqa: D def test_derived_metric_with_offset_window( # noqa: D request: FixtureRequest, 
mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: dataflow_plan = dataflow_plan_builder.build_plan( @@ -1702,8 +1687,8 @@ def test_derived_metric_with_offset_window( # noqa: D def test_derived_metric_with_offset_to_grain( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: dataflow_plan = dataflow_plan_builder.build_plan( @@ -1725,8 +1710,8 @@ def test_derived_metric_with_offset_to_grain( # noqa: D def test_derived_metric_with_offset_window_and_offset_to_grain( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: dataflow_plan = dataflow_plan_builder.build_plan( @@ -1748,8 +1733,8 @@ def test_derived_metric_with_offset_window_and_offset_to_grain( # noqa: D def test_derived_offset_metric_with_one_input_metric( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: dataflow_plan = dataflow_plan_builder.build_plan( @@ -1771,8 +1756,8 @@ def test_derived_offset_metric_with_one_input_metric( # noqa: D def test_derived_metric_with_offset_window_and_granularity( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: dataflow_plan = dataflow_plan_builder.build_plan( @@ -1794,8 +1779,8 @@ def test_derived_metric_with_offset_window_and_granularity( # noqa: D def test_derived_metric_with_offset_to_grain_and_granularity( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: dataflow_plan = dataflow_plan_builder.build_plan( @@ -1817,8 +1802,8 @@ def test_derived_metric_with_offset_to_grain_and_granularity( # noqa: D def test_derived_metric_with_offset_window_and_offset_to_grain_and_granularity( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: 
DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: dataflow_plan = dataflow_plan_builder.build_plan( @@ -1840,8 +1825,8 @@ def test_derived_metric_with_offset_window_and_offset_to_grain_and_granularity( def test_derived_offset_cumulative_metric( # noqa: D request: FixtureRequest, mf_test_session_state: MetricFlowTestSessionState, - dataflow_plan_builder: DataflowPlanBuilder[SemanticModelDataSet], - dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter[SemanticModelDataSet], + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, sql_client: SqlClient, ) -> None: dataflow_plan = dataflow_plan_builder.build_plan( From 90495040fd8f5d031ec685c8c1a6a25a9753bf51 Mon Sep 17 00:00:00 2001 From: tlento Date: Thu, 7 Sep 2023 15:49:37 -0700 Subject: [PATCH 2/3] Remove unused SameSemanticModelReferenceChecker In cleaning up our dataset module organization I noticed this helper class was not used anywhere. This removes it. --- metricflow/plan_conversion/sql_dataset.py | 21 +-------------------- 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/metricflow/plan_conversion/sql_dataset.py b/metricflow/plan_conversion/sql_dataset.py index 27775b95eb..4d054834a1 100644 --- a/metricflow/plan_conversion/sql_dataset.py +++ b/metricflow/plan_conversion/sql_dataset.py @@ -1,15 +1,12 @@ from __future__ import annotations -from typing import List, Sequence +from typing import Sequence import more_itertools -from dbt_semantic_interfaces.references import SemanticModelReference from metricflow.dataset.dataset import DataSet from metricflow.instances import ( InstanceSet, - InstanceSetTransform, - SemanticModelElementInstance, ) from metricflow.specs.column_assoc import ColumnAssociation from metricflow.specs.specs import DimensionSpec, EntitySpec, TimeDimensionSpec @@ -119,19 +116,3 @@ def groupable_column_associations(self) -> Sequence[ColumnAssociation]: + self.instance_set.time_dimension_instances ) return tuple(more_itertools.flatten([instance.associated_columns for instance in instances])) - - -class SameSemanticModelReferenceChecker(InstanceSetTransform[bool]): - """Checks to see that all elements in the instance set come from the same semantic model.""" - - def __init__(self, semantic_model_reference: SemanticModelReference) -> None: # noqa: D - self._semantic_model_reference = semantic_model_reference - - def transform(self, instance_set: InstanceSet) -> bool: # noqa: D - combined: List[SemanticModelElementInstance] = [] - combined.extend(instance_set.measure_instances) - combined.extend(instance_set.dimension_instances) - combined.extend(instance_set.time_dimension_instances) - combined.extend(instance_set.entity_instances) - - return all([all([y.is_from(self._semantic_model_reference) for y in x.defined_from]) for x in combined]) From f20fadfff5b384d957c50eaabf247fb6e3b0c530 Mon Sep 17 00:00:00 2001 From: tlento Date: Thu, 7 Sep 2023 15:52:34 -0700 Subject: [PATCH 3/3] Move SqlDataSet class from plan_conversion to dataset package The other DataSet classes are all in one place, and this is no longer strictly a plan_conversion artifact (indeed, it never was, it just looked that way due to the generics). 
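
For downstream consumers this is purely a path move; the module contents
are unchanged (the rename below is a 100% similarity rename). A minimal
before/after sketch of the import update, mirroring the mechanical
rewrites in the diff that follows:

    # Before this patch:
    from metricflow.plan_conversion.sql_dataset import SqlDataSet

    # After this patch:
    from metricflow.dataset.sql_dataset import SqlDataSet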
--- metricflow/dataflow/builder/node_data_set.py | 2 +- metricflow/dataflow/builder/node_evaluator.py | 2 +- metricflow/dataflow/dataflow_plan.py | 2 +- metricflow/dataset/semantic_model_adapter.py | 2 +- metricflow/{plan_conversion => dataset}/sql_dataset.py | 0 metricflow/plan_conversion/dataflow_to_sql.py | 2 +- metricflow/plan_conversion/sql_join_builder.py | 2 +- metricflow/test/dataflow/builder/test_node_data_set.py | 2 +- 8 files changed, 7 insertions(+), 7 deletions(-) rename metricflow/{plan_conversion => dataset}/sql_dataset.py (100%) diff --git a/metricflow/dataflow/builder/node_data_set.py b/metricflow/dataflow/builder/node_data_set.py index 2f1b5b5b5c..45719082c6 100644 --- a/metricflow/dataflow/builder/node_data_set.py +++ b/metricflow/dataflow/builder/node_data_set.py @@ -5,9 +5,9 @@ from metricflow.dataflow.dataflow_plan import ( DataflowPlanNode, ) +from metricflow.dataset.sql_dataset import SqlDataSet from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter -from metricflow.plan_conversion.sql_dataset import SqlDataSet from metricflow.specs.column_assoc import ColumnAssociationResolver diff --git a/metricflow/dataflow/builder/node_evaluator.py b/metricflow/dataflow/builder/node_evaluator.py index 8e09e30c95..5ab12bed9f 100644 --- a/metricflow/dataflow/builder/node_evaluator.py +++ b/metricflow/dataflow/builder/node_evaluator.py @@ -32,10 +32,10 @@ PartitionTimeDimensionJoinDescription, ValidityWindowJoinDescription, ) +from metricflow.dataset.sql_dataset import SqlDataSet from metricflow.instances import InstanceSet from metricflow.model.semantics.semantic_model_join_evaluator import SemanticModelJoinEvaluator from metricflow.plan_conversion.instance_converters import CreateValidityWindowJoinDescription -from metricflow.plan_conversion.sql_dataset import SqlDataSet from metricflow.protocols.semantics import SemanticModelAccessor from metricflow.specs.specs import ( LinkableInstanceSpec, diff --git a/metricflow/dataflow/dataflow_plan.py b/metricflow/dataflow/dataflow_plan.py index a7dde8ac62..89aef6c18c 100644 --- a/metricflow/dataflow/dataflow_plan.py +++ b/metricflow/dataflow/dataflow_plan.py @@ -38,8 +38,8 @@ PartitionTimeDimensionJoinDescription, ) from metricflow.dataflow.sql_table import SqlTable +from metricflow.dataset.sql_dataset import SqlDataSet from metricflow.filters.time_constraint import TimeRangeConstraint -from metricflow.plan_conversion.sql_dataset import SqlDataSet from metricflow.specs.specs import ( InstanceSpecSet, LinklessEntitySpec, diff --git a/metricflow/dataset/semantic_model_adapter.py b/metricflow/dataset/semantic_model_adapter.py index 20d54db2cf..0d857c44c5 100644 --- a/metricflow/dataset/semantic_model_adapter.py +++ b/metricflow/dataset/semantic_model_adapter.py @@ -2,8 +2,8 @@ from dbt_semantic_interfaces.references import SemanticModelReference +from metricflow.dataset.sql_dataset import SqlDataSet from metricflow.instances import InstanceSet -from metricflow.plan_conversion.sql_dataset import SqlDataSet from metricflow.sql.sql_plan import SqlSelectStatementNode diff --git a/metricflow/plan_conversion/sql_dataset.py b/metricflow/dataset/sql_dataset.py similarity index 100% rename from metricflow/plan_conversion/sql_dataset.py rename to metricflow/dataset/sql_dataset.py diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index 74d0249f18..27b6efe4e1 100644 --- 
a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -33,6 +33,7 @@ WriteToResultTableNode, ) from metricflow.dataset.dataset import DataSet +from metricflow.dataset.sql_dataset import SqlDataSet from metricflow.filters.time_constraint import TimeRangeConstraint from metricflow.instances import ( InstanceSet, @@ -62,7 +63,6 @@ CreateSelectCoalescedColumnsForLinkableSpecs, SelectOnlyLinkableSpecs, ) -from metricflow.plan_conversion.sql_dataset import SqlDataSet from metricflow.plan_conversion.sql_join_builder import ( AnnotatedSqlDataSet, ColumnEqualityDescription, diff --git a/metricflow/plan_conversion/sql_join_builder.py b/metricflow/plan_conversion/sql_join_builder.py index 39d06142b5..92ca575721 100644 --- a/metricflow/plan_conversion/sql_join_builder.py +++ b/metricflow/plan_conversion/sql_join_builder.py @@ -4,7 +4,7 @@ from typing import List, Optional, Sequence, Tuple from metricflow.dataflow.dataflow_plan import JoinDescription, JoinOverTimeRangeNode, JoinToTimeSpineNode -from metricflow.plan_conversion.sql_dataset import SqlDataSet +from metricflow.dataset.sql_dataset import SqlDataSet from metricflow.plan_conversion.sql_expression_builders import make_coalesced_expr from metricflow.sql.sql_exprs import ( SqlColumnReference, diff --git a/metricflow/test/dataflow/builder/test_node_data_set.py b/metricflow/test/dataflow/builder/test_node_data_set.py index 11bf8de279..f9f696211a 100644 --- a/metricflow/test/dataflow/builder/test_node_data_set.py +++ b/metricflow/test/dataflow/builder/test_node_data_set.py @@ -9,13 +9,13 @@ from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver from metricflow.dataflow.dataflow_plan import JoinDescription, JoinToBaseOutputNode, ReadSqlSourceNode from metricflow.dataflow.sql_table import SqlTable +from metricflow.dataset.sql_dataset import SqlDataSet from metricflow.instances import ( InstanceSet, MeasureInstance, ) from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow.plan_conversion.column_resolver import DunderColumnAssociationResolver -from metricflow.plan_conversion.sql_dataset import SqlDataSet from metricflow.plan_conversion.time_spine import TimeSpineSource from metricflow.specs.column_assoc import ColumnAssociation, SingleColumnCorrelationKey from metricflow.specs.specs import (