diff --git a/metricflow/collections/__init__.py b/metricflow/collections/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/metricflow/collections/merger.py b/metricflow/collections/merger.py new file mode 100644 index 0000000000..476b8f6c0b --- /dev/null +++ b/metricflow/collections/merger.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import functools +from abc import ABC, abstractmethod +from typing import Iterable, Type, TypeVar + +from typing_extensions import Self + +MergeableT = TypeVar("MergeableT", bound="Mergeable") + + +class Mergeable(ABC): + """Objects that can be merged together to form a superset object of the same type. + + Merging objects is frequently needed in MetricFlow as there are several recursive operations where the output is + the superset of the result of the recursive calls. + + e.g. + * The validation issue set of a derived metric includes the issues of the parent metrics. + * The output of a node in the dataflow plan can include the outputs of the parent nodes. + * The query-time where filter is useful to combine with the metric-defined where filter. + + Having a common interface also gives a consistent name to this operation so that we don't end up with multiple names + to describe this operation (e.g. combine, add, concat). + + This is used to streamline the following case that occurs in the codebase: + + items_to_merge: List[ItemType] = [] + ... + if ... + items_to_merge.append(...) + ... + if ... + items_to_merge.append(...) + ... + if ... + ... + items_to_merge.append(...) + return merge_items(items_to_merge) + ... + return merge_items(items_to_merge) + + This centralizes the definition of the merge_items() call. 
+ """ + + @abstractmethod + def merge(self: Self, other: Self) -> Self: + """Return a new object that is the result of merging self with other.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def empty_instance(cls: Type[MergeableT]) -> MergeableT: + """Create an empty instance to handle merging of empty sequences of items. + + As merge_iterable() returns an empty instance for an empty iterable, there needs to be a way of creating one. + """ + raise NotImplementedError + + @classmethod + def merge_iterable(cls: Type[MergeableT], items: Iterable[MergeableT]) -> MergeableT: + """Merge all items into a single instance. + + If an empty iterable has been passed in, this returns an empty instance. + """ + return functools.reduce(cls.merge, items, cls.empty_instance()) diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index b4ed743b9f..6eda727816 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -674,17 +674,15 @@ def _build_aggregated_measure_from_measure_source_node( # Extraneous linkable specs are specs that are used in this phase that should not show up in the final result # unless it was already a requested spec in the query - extraneous_linkable_specs = LinkableSpecSet() + linkable_spec_sets_to_merge: List[LinkableSpecSet] = [] if where_constraint: - extraneous_linkable_specs = LinkableSpecSet.merge( - (extraneous_linkable_specs, where_constraint.linkable_spec_set) - ) + linkable_spec_sets_to_merge.append(where_constraint.linkable_spec_set) if non_additive_dimension_spec: - extraneous_linkable_specs = LinkableSpecSet.merge( - (extraneous_linkable_specs, non_additive_dimension_spec.linkable_specs) - ) + linkable_spec_sets_to_merge.append(non_additive_dimension_spec.linkable_specs) + + extraneous_linkable_specs = LinkableSpecSet.merge_iterable(linkable_spec_sets_to_merge).dedupe() + 
required_linkable_specs = queried_linkable_specs.merge(extraneous_linkable_specs).dedupe() - required_linkable_specs = LinkableSpecSet.merge((queried_linkable_specs, extraneous_linkable_specs)) logger.info( f"Looking for a recipe to get:\n" f"{pformat_big_objects(measure_specs=[measure_spec], required_linkable_set=required_linkable_specs)}" @@ -736,11 +734,8 @@ def _build_aggregated_measure_from_measure_source_node( # Only get the required measure and the local linkable instances so that aggregations work correctly. filtered_measure_source_node = FilterElementsNode( parent_node=join_to_time_spine_node or time_range_node or measure_recipe.source_node, - include_specs=InstanceSpecSet.merge( - ( - InstanceSpecSet(measure_specs=(measure_spec,)), - InstanceSpecSet.create_from_linkable_specs(measure_recipe.required_local_linkable_specs), - ) + include_specs=InstanceSpecSet(measure_specs=(measure_spec,)).merge( + InstanceSpecSet.create_from_linkable_specs(measure_recipe.required_local_linkable_specs), ), ) @@ -752,11 +747,8 @@ def _build_aggregated_measure_from_measure_source_node( join_targets=join_targets, ) - specs_to_keep_after_join = InstanceSpecSet.merge( - ( - InstanceSpecSet(measure_specs=(measure_spec,)), - required_linkable_specs.as_spec_set, - ) + specs_to_keep_after_join = InstanceSpecSet(measure_specs=(measure_spec,)).merge( + required_linkable_specs.as_spec_set, ) after_join_filtered_node = FilterElementsNode( @@ -814,10 +806,9 @@ def _build_aggregated_measure_from_measure_source_node( # e.g. for "bookings" by "ds" where "is_instant", "is_instant" should not be in the results. 
pre_aggregate_node = FilterElementsNode( parent_node=pre_aggregate_node, - include_specs=InstanceSpecSet.merge( - (InstanceSpecSet(measure_specs=(measure_spec,)), queried_linkable_specs.as_spec_set) - ), + include_specs=InstanceSpecSet(measure_specs=(measure_spec,)).merge(queried_linkable_specs.as_spec_set), ) + aggregate_measures_node = AggregateMeasuresNode( parent_node=pre_aggregate_node, metric_input_measure_specs=(metric_input_measure_spec,), diff --git a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py index 2155caeb3e..d14e7b1411 100644 --- a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py +++ b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py @@ -25,7 +25,6 @@ WriteToResultTableNode, ) from metricflow.dataflow.optimizer.source_scan.matching_linkable_specs import MatchingLinkableSpecsTransform -from metricflow.specs.specs import InstanceSpecSet logger = logging.getLogger(__name__) @@ -360,9 +359,7 @@ def visit_pass_elements_filter_node( # noqa: D # specs since any branch that is merged together needs to output the same set of dimensions. 
combined_node = FilterElementsNode( parent_node=combined_parent_node, - include_specs=InstanceSpecSet.merge( - (self._current_left_node.include_specs, current_right_node.include_specs) - ).dedupe(), + include_specs=self._current_left_node.include_specs.merge(current_right_node.include_specs).dedupe(), ) self._log_combine_success( left_node=self._current_left_node, diff --git a/metricflow/specs/specs.py b/metricflow/specs/specs.py index dc1683dd39..3602c3dbdd 100644 --- a/metricflow/specs/specs.py +++ b/metricflow/specs/specs.py @@ -14,7 +14,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from hashlib import sha1 -from typing import Any, Generic, List, Optional, Sequence, Tuple, TypeVar +from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar from dbt_semantic_interfaces.dataclass_serialization import SerializableDataclass from dbt_semantic_interfaces.implementations.metric import PydanticMetricTimeWindow @@ -28,8 +28,10 @@ from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity +from typing_extensions import override from metricflow.aggregation_properties import AggregationState +from metricflow.collections.merger import Mergeable from metricflow.filters.time_constraint import TimeRangeConstraint from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName from metricflow.sql.sql_bind_parameters import SqlBindParameters @@ -491,7 +493,7 @@ class FilterSpec(SerializableDataclass): # noqa: D @dataclass(frozen=True) -class LinkableSpecSet(SerializableDataclass): +class LinkableSpecSet(Mergeable, SerializableDataclass): """Groups linkable specs.""" dimension_specs: Tuple[DimensionSpec, ...] 
= () @@ -502,28 +504,38 @@ class LinkableSpecSet(SerializableDataclass): def as_tuple(self) -> Tuple[LinkableInstanceSpec, ...]: # noqa: D return tuple(itertools.chain(self.dimension_specs, self.time_dimension_specs, self.entity_specs)) - @staticmethod - def merge(spec_sets: Sequence[LinkableSpecSet]) -> LinkableSpecSet: - """Merges and dedupes the linkable specs.""" - dimension_specs: List[DimensionSpec] = [] - time_dimension_specs: List[TimeDimensionSpec] = [] - entity_specs: List[EntitySpec] = [] - - for spec_set in spec_sets: - for dimension_spec in spec_set.dimension_specs: - if dimension_spec not in dimension_specs: - dimension_specs.append(dimension_spec) - for time_dimension_spec in spec_set.time_dimension_specs: - if time_dimension_spec not in time_dimension_specs: - time_dimension_specs.append(time_dimension_spec) - for entity_spec in spec_set.entity_specs: - if entity_spec not in entity_specs: - entity_specs.append(entity_spec) + @override + def merge(self, other: LinkableSpecSet) -> LinkableSpecSet: + return LinkableSpecSet( + dimension_specs=self.dimension_specs + other.dimension_specs, + time_dimension_specs=self.time_dimension_specs + other.time_dimension_specs, + entity_specs=self.entity_specs + other.entity_specs, + ) + + @override + @classmethod + def empty_instance(cls) -> LinkableSpecSet: + return LinkableSpecSet() + + def dedupe(self) -> LinkableSpecSet: # noqa: D + # Use dictionaries to dedupe as it preserves insertion order. 
+ + dimension_spec_dict: Dict[DimensionSpec, None] = {} + for dimension_spec in self.dimension_specs: + dimension_spec_dict[dimension_spec] = None + + time_dimension_spec_dict: Dict[TimeDimensionSpec, None] = {} + for time_dimension_spec in self.time_dimension_specs: + time_dimension_spec_dict[time_dimension_spec] = None + + entity_spec_dict: Dict[EntitySpec, None] = {} + for entity_spec in self.entity_specs: + entity_spec_dict[entity_spec] = None return LinkableSpecSet( - dimension_specs=tuple(dimension_specs), - time_dimension_specs=tuple(time_dimension_specs), - entity_specs=tuple(entity_specs), + dimension_specs=tuple(dimension_spec_dict.keys()), + time_dimension_specs=tuple(time_dimension_spec_dict.keys()), + entity_specs=tuple(entity_spec_dict.keys()), ) def is_subset_of(self, other_set: LinkableSpecSet) -> bool: # noqa: D @@ -582,7 +594,7 @@ def transform(self, spec_set: InstanceSpecSet) -> TransformOutputT: # noqa: D @dataclass(frozen=True) -class InstanceSpecSet(SerializableDataclass): +class InstanceSpecSet(Mergeable, SerializableDataclass): """Consolidates all specs used in an instance set.""" metric_specs: Tuple[MetricSpec, ...] = () @@ -592,18 +604,22 @@ class InstanceSpecSet(SerializableDataclass): time_dimension_specs: Tuple[TimeDimensionSpec, ...] = () metadata_specs: Tuple[MetadataSpec, ...] 
= () - @staticmethod - def merge(others: Sequence[InstanceSpecSet]) -> InstanceSpecSet: - """Merge all sets into one set, without de-duplication.""" + @override + def merge(self, other: InstanceSpecSet) -> InstanceSpecSet: return InstanceSpecSet( - metric_specs=tuple(itertools.chain.from_iterable([x.metric_specs for x in others])), - measure_specs=tuple(itertools.chain.from_iterable([x.measure_specs for x in others])), - dimension_specs=tuple(itertools.chain.from_iterable([x.dimension_specs for x in others])), - entity_specs=tuple(itertools.chain.from_iterable([x.entity_specs for x in others])), - time_dimension_specs=tuple(itertools.chain.from_iterable([x.time_dimension_specs for x in others])), - metadata_specs=tuple(itertools.chain.from_iterable([x.metadata_specs for x in others])), + metric_specs=self.metric_specs + other.metric_specs, + measure_specs=self.measure_specs + other.measure_specs, + dimension_specs=self.dimension_specs + other.dimension_specs, + entity_specs=self.entity_specs + other.entity_specs, + time_dimension_specs=self.time_dimension_specs + other.time_dimension_specs, + metadata_specs=self.metadata_specs + other.metadata_specs, ) + @override + @classmethod + def empty_instance(cls) -> InstanceSpecSet: + return InstanceSpecSet() + def dedupe(self) -> InstanceSpecSet: """De-duplicates repeated elements. 
@@ -665,7 +681,9 @@ def transform(self, transform_function: InstanceSpecSetTransform[TransformOutput @staticmethod def create_from_linkable_specs(linkable_specs: Sequence[LinkableInstanceSpec]) -> InstanceSpecSet: # noqa: D - return InstanceSpecSet.merge(tuple(x.as_linkable_spec_set.as_spec_set for x in linkable_specs)) + return InstanceSpecSet.merge_iterable( + tuple(linkable_spec.as_linkable_spec_set.as_spec_set for linkable_spec in linkable_specs) + ) @dataclass(frozen=True) @@ -716,5 +734,5 @@ def combine(self, other: WhereFilterSpec) -> WhereFilterSpec: # noqa: D return WhereFilterSpec( where_sql=f"({self.where_sql}) AND ({other.where_sql})", bind_parameters=self.bind_parameters.combine(other.bind_parameters), - linkable_spec_set=LinkableSpecSet.merge([self.linkable_spec_set, other.linkable_spec_set]), + linkable_spec_set=self.linkable_spec_set.merge(other.linkable_spec_set), ) diff --git a/metricflow/test/collections/__init__.py b/metricflow/test/collections/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/metricflow/test/collections/test_merger.py b/metricflow/test/collections/test_merger.py new file mode 100644 index 0000000000..a481206759 --- /dev/null +++ b/metricflow/test/collections/test_merger.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import List, Tuple + +from typing_extensions import override + +from metricflow.collections.merger import Mergeable + + +@dataclass(frozen=True) +class NumberTuple(Mergeable): # noqa: D + numbers: Tuple[int, ...] 
= field(default_factory=tuple) + + @override + def merge(self: NumberTuple, other: NumberTuple) -> NumberTuple: + return NumberTuple(self.numbers + other.numbers) + + @override + @classmethod + def empty_instance(cls) -> NumberTuple: + return NumberTuple() + + +def test_merger() -> None: # noqa: D + items_to_merge: List[NumberTuple] = [ + NumberTuple(()), + NumberTuple((1,)), + NumberTuple((2, 3)), + ] + + assert NumberTuple.merge_iterable(items_to_merge) == NumberTuple((1, 2, 3)) diff --git a/metricflow/test/test_specs.py b/metricflow/test/test_specs.py index 172b07fa7a..8b1aca8109 100644 --- a/metricflow/test/test_specs.py +++ b/metricflow/test/test_specs.py @@ -111,7 +111,7 @@ def test_merge_spec_set() -> None: # noqa: D dimension_specs=(DimensionSpec(element_name="is_instant", entity_links=(EntityReference("booking"),)),) ) - assert InstanceSpecSet.merge((spec_set1, spec_set2)) == InstanceSpecSet( + assert spec_set1.merge(spec_set2) == InstanceSpecSet( metric_specs=(MetricSpec(element_name="bookings"),), dimension_specs=(DimensionSpec(element_name="is_instant", entity_links=(EntityReference("booking"),)),), )