diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index 31fc50a14e..6eda727816 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -674,17 +674,15 @@ def _build_aggregated_measure_from_measure_source_node( # Extraneous linkable specs are specs that are used in this phase that should not show up in the final result # unless it was already a requested spec in the query - extraneous_linkable_specs = LinkableSpecSet() + linkable_spec_sets_to_merge: List[LinkableSpecSet] = [] if where_constraint: - extraneous_linkable_specs = LinkableSpecSet.merge( - (extraneous_linkable_specs, where_constraint.linkable_spec_set) - ) + linkable_spec_sets_to_merge.append(where_constraint.linkable_spec_set) if non_additive_dimension_spec: - extraneous_linkable_specs = LinkableSpecSet.merge( - (extraneous_linkable_specs, non_additive_dimension_spec.linkable_specs) - ) + linkable_spec_sets_to_merge.append(non_additive_dimension_spec.linkable_specs) + + extraneous_linkable_specs = LinkableSpecSet.merge_iterable(linkable_spec_sets_to_merge).dedupe() + required_linkable_specs = queried_linkable_specs.merge(extraneous_linkable_specs).dedupe() - required_linkable_specs = LinkableSpecSet.merge((queried_linkable_specs, extraneous_linkable_specs)) logger.info( f"Looking for a recipe to get:\n" f"{pformat_big_objects(measure_specs=[measure_spec], required_linkable_set=required_linkable_specs)}" diff --git a/metricflow/specs/specs.py b/metricflow/specs/specs.py index 618b6f81bd..3602c3dbdd 100644 --- a/metricflow/specs/specs.py +++ b/metricflow/specs/specs.py @@ -14,7 +14,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from hashlib import sha1 -from typing import Any, Generic, List, Optional, Sequence, Tuple, TypeVar +from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar from dbt_semantic_interfaces.dataclass_serialization import SerializableDataclass from dbt_semantic_interfaces.implementations.metric import PydanticMetricTimeWindow @@ -493,7 +493,7 @@ class FilterSpec(SerializableDataclass): # noqa: D @dataclass(frozen=True) -class LinkableSpecSet(SerializableDataclass): +class LinkableSpecSet(Mergeable, SerializableDataclass): """Groups linkable specs.""" dimension_specs: Tuple[DimensionSpec, ...] = () @@ -504,28 +504,38 @@ class LinkableSpecSet(SerializableDataclass): def as_tuple(self) -> Tuple[LinkableInstanceSpec, ...]: # noqa: D return tuple(itertools.chain(self.dimension_specs, self.time_dimension_specs, self.entity_specs)) - @staticmethod - def merge(spec_sets: Sequence[LinkableSpecSet]) -> LinkableSpecSet: - """Merges and dedupes the linkable specs.""" - dimension_specs: List[DimensionSpec] = [] - time_dimension_specs: List[TimeDimensionSpec] = [] - entity_specs: List[EntitySpec] = [] - - for spec_set in spec_sets: - for dimension_spec in spec_set.dimension_specs: - if dimension_spec not in dimension_specs: - dimension_specs.append(dimension_spec) - for time_dimension_spec in spec_set.time_dimension_specs: - if time_dimension_spec not in time_dimension_specs: - time_dimension_specs.append(time_dimension_spec) - for entity_spec in spec_set.entity_specs: - if entity_spec not in entity_specs: - entity_specs.append(entity_spec) + @override + def merge(self, other: LinkableSpecSet) -> LinkableSpecSet: + return LinkableSpecSet( + dimension_specs=self.dimension_specs + other.dimension_specs, + time_dimension_specs=self.time_dimension_specs + other.time_dimension_specs, + entity_specs=self.entity_specs + other.entity_specs, + ) + + @override + @classmethod + def empty_instance(cls) -> LinkableSpecSet: + return LinkableSpecSet() + + def dedupe(self) -> LinkableSpecSet: # noqa: D + # Use dictionaries to dedupe as it preserves insertion order. + + dimension_spec_dict: Dict[DimensionSpec, None] = {} + for dimension_spec in self.dimension_specs: + dimension_spec_dict[dimension_spec] = None + + time_dimension_spec_dict: Dict[TimeDimensionSpec, None] = {} + for time_dimension_spec in self.time_dimension_specs: + time_dimension_spec_dict[time_dimension_spec] = None + + entity_spec_dict: Dict[EntitySpec, None] = {} + for entity_spec in self.entity_specs: + entity_spec_dict[entity_spec] = None return LinkableSpecSet( - dimension_specs=tuple(dimension_specs), - time_dimension_specs=tuple(time_dimension_specs), - entity_specs=tuple(entity_specs), + dimension_specs=tuple(dimension_spec_dict.keys()), + time_dimension_specs=tuple(time_dimension_spec_dict.keys()), + entity_specs=tuple(entity_spec_dict.keys()), ) def is_subset_of(self, other_set: LinkableSpecSet) -> bool: # noqa: D @@ -724,5 +734,5 @@ def combine(self, other: WhereFilterSpec) -> WhereFilterSpec: # noqa: D return WhereFilterSpec( where_sql=f"({self.where_sql}) AND ({other.where_sql})", bind_parameters=self.bind_parameters.combine(other.bind_parameters), - linkable_spec_set=LinkableSpecSet.merge([self.linkable_spec_set, other.linkable_spec_set]), + linkable_spec_set=self.linkable_spec_set.merge(other.linkable_spec_set), )