diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index b4ed743b9f..31fc50a14e 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -736,11 +736,8 @@ def _build_aggregated_measure_from_measure_source_node( # Only get the required measure and the local linkable instances so that aggregations work correctly. filtered_measure_source_node = FilterElementsNode( parent_node=join_to_time_spine_node or time_range_node or measure_recipe.source_node, - include_specs=InstanceSpecSet.merge( - ( - InstanceSpecSet(measure_specs=(measure_spec,)), - InstanceSpecSet.create_from_linkable_specs(measure_recipe.required_local_linkable_specs), - ) + include_specs=InstanceSpecSet(measure_specs=(measure_spec,)).merge( + InstanceSpecSet.create_from_linkable_specs(measure_recipe.required_local_linkable_specs), ), ) @@ -752,11 +749,8 @@ def _build_aggregated_measure_from_measure_source_node( join_targets=join_targets, ) - specs_to_keep_after_join = InstanceSpecSet.merge( - ( - InstanceSpecSet(measure_specs=(measure_spec,)), - required_linkable_specs.as_spec_set, - ) + specs_to_keep_after_join = InstanceSpecSet(measure_specs=(measure_spec,)).merge( + required_linkable_specs.as_spec_set, ) after_join_filtered_node = FilterElementsNode( @@ -814,10 +808,9 @@ def _build_aggregated_measure_from_measure_source_node( # e.g. for "bookings" by "ds" where "is_instant", "is_instant" should not be in the results. pre_aggregate_node = FilterElementsNode( parent_node=pre_aggregate_node, - include_specs=InstanceSpecSet.merge( - (InstanceSpecSet(measure_specs=(measure_spec,)), queried_linkable_specs.as_spec_set) - ), + include_specs=InstanceSpecSet(measure_specs=(measure_spec,)).merge(queried_linkable_specs.as_spec_set), ) + aggregate_measures_node = AggregateMeasuresNode( parent_node=pre_aggregate_node, metric_input_measure_specs=(metric_input_measure_spec,), diff --git a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py index 2155caeb3e..d14e7b1411 100644 --- a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py +++ b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py @@ -25,7 +25,6 @@ WriteToResultTableNode, ) from metricflow.dataflow.optimizer.source_scan.matching_linkable_specs import MatchingLinkableSpecsTransform -from metricflow.specs.specs import InstanceSpecSet logger = logging.getLogger(__name__) @@ -360,9 +359,7 @@ def visit_pass_elements_filter_node( # noqa: D # specs since any branch that is merged together needs to output the same set of dimensions. combined_node = FilterElementsNode( parent_node=combined_parent_node, - include_specs=InstanceSpecSet.merge( - (self._current_left_node.include_specs, current_right_node.include_specs) - ).dedupe(), + include_specs=self._current_left_node.include_specs.merge(current_right_node.include_specs).dedupe(), ) self._log_combine_success( left_node=self._current_left_node, diff --git a/metricflow/specs/specs.py b/metricflow/specs/specs.py index dc1683dd39..618b6f81bd 100644 --- a/metricflow/specs/specs.py +++ b/metricflow/specs/specs.py @@ -28,8 +28,10 @@ from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity +from typing_extensions import override from metricflow.aggregation_properties import AggregationState +from metricflow.collections.merger import Mergeable from metricflow.filters.time_constraint import TimeRangeConstraint from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName from metricflow.sql.sql_bind_parameters import SqlBindParameters @@ -582,7 +584,7 @@ def transform(self, spec_set: InstanceSpecSet) -> TransformOutputT: # noqa: D @dataclass(frozen=True) -class InstanceSpecSet(SerializableDataclass): +class InstanceSpecSet(Mergeable, SerializableDataclass): """Consolidates all specs used in an instance set.""" metric_specs: Tuple[MetricSpec, ...] = () @@ -592,18 +594,22 @@ class InstanceSpecSet(SerializableDataclass): time_dimension_specs: Tuple[TimeDimensionSpec, ...] = () metadata_specs: Tuple[MetadataSpec, ...] = () - @staticmethod - def merge(others: Sequence[InstanceSpecSet]) -> InstanceSpecSet: - """Merge all sets into one set, without de-duplication.""" + @override + def merge(self, other: InstanceSpecSet) -> InstanceSpecSet: return InstanceSpecSet( - metric_specs=tuple(itertools.chain.from_iterable([x.metric_specs for x in others])), - measure_specs=tuple(itertools.chain.from_iterable([x.measure_specs for x in others])), - dimension_specs=tuple(itertools.chain.from_iterable([x.dimension_specs for x in others])), - entity_specs=tuple(itertools.chain.from_iterable([x.entity_specs for x in others])), - time_dimension_specs=tuple(itertools.chain.from_iterable([x.time_dimension_specs for x in others])), - metadata_specs=tuple(itertools.chain.from_iterable([x.metadata_specs for x in others])), + metric_specs=self.metric_specs + other.metric_specs, + measure_specs=self.measure_specs + other.measure_specs, + dimension_specs=self.dimension_specs + other.dimension_specs, + entity_specs=self.entity_specs + other.entity_specs, + time_dimension_specs=self.time_dimension_specs + other.time_dimension_specs, + metadata_specs=self.metadata_specs + other.metadata_specs, ) + @override + @classmethod + def empty_instance(cls) -> InstanceSpecSet: + return InstanceSpecSet() + def dedupe(self) -> InstanceSpecSet: """De-duplicates repeated elements. @@ -665,7 +671,9 @@ def transform(self, transform_function: InstanceSpecSetTransform[TransformOutput @staticmethod def create_from_linkable_specs(linkable_specs: Sequence[LinkableInstanceSpec]) -> InstanceSpecSet: # noqa: D - return InstanceSpecSet.merge(tuple(x.as_linkable_spec_set.as_spec_set for x in linkable_specs)) + return InstanceSpecSet.merge_iterable( + tuple(linkable_spec.as_linkable_spec_set.as_spec_set for linkable_spec in linkable_specs) + ) @dataclass(frozen=True) diff --git a/metricflow/test/test_specs.py b/metricflow/test/test_specs.py index 172b07fa7a..8b1aca8109 100644 --- a/metricflow/test/test_specs.py +++ b/metricflow/test/test_specs.py @@ -111,7 +111,7 @@ def test_merge_spec_set() -> None: # noqa: D dimension_specs=(DimensionSpec(element_name="is_instant", entity_links=(EntityReference("booking"),)),) ) - assert InstanceSpecSet.merge((spec_set1, spec_set2)) == InstanceSpecSet( + assert spec_set1.merge(spec_set2) == InstanceSpecSet( metric_specs=(MetricSpec(element_name="bookings"),), dimension_specs=(DimensionSpec(element_name="is_instant", entity_links=(EntityReference("booking"),)),), )