Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Mergeable to Simplify Merging Collections #790

Merged
merged 3 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
66 changes: 66 additions & 0 deletions metricflow/collections/merger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from __future__ import annotations

import functools
from abc import ABC, abstractmethod
from typing import Iterable, Type, TypeVar

from typing_extensions import Self

# Type variable bound to Mergeable so classmethods like merge_iterable() are typed
# to return the concrete subclass rather than the abstract base class.
MergeableT = TypeVar("MergeableT", bound="Mergeable")


class Mergeable(ABC):
    """Interface for objects that can be combined into a superset object of the same type.

    Merging is needed frequently in MetricFlow, as several recursive operations produce
    an output that is the superset of the results of the recursive calls.

    e.g.
    * The validation issue set of a derived metric includes the issues of the parent metrics.
    * The output of a node in the dataflow plan can include the outputs of the parent nodes.
    * The query-time where filter is useful to combine with the metric-defined where filter.

    A shared interface also keeps this operation under a single consistent name, instead of
    accumulating several (e.g. combine, add, concat), and it streamlines the following
    pattern that occurs in the codebase:

        items_to_merge: List[ItemType] = []
        ...
        if ...
            items_to_merge.append(...)
        ...
        if ...
            items_to_merge.append(...)
        ...
        if ...
            ...
            items_to_merge.append(...)
            return merge_items(items_to_merge)
        ...
        return merge_items(items_to_merge)

    by centralizing the definition of the merge_items() call.
    """

    @abstractmethod
    def merge(self: Self, other: Self) -> Self:
        """Return a new object combining the contents of self and other."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def empty_instance(cls: Type[MergeableT]) -> MergeableT:
        """Return an instance with no contents.

        Required so that merge_iterable() has a well-defined result for an empty iterable.
        """
        raise NotImplementedError

    @classmethod
    def merge_iterable(cls: Type[MergeableT], items: Iterable[MergeableT]) -> MergeableT:
        """Fold all items into a single instance via repeated merge() calls.

        An empty iterable yields an empty instance.
        """
        # Start from the empty instance so a zero-length iterable is handled uniformly.
        merged = cls.empty_instance()
        for item in items:
            merged = merged.merge(item)
        return merged
33 changes: 12 additions & 21 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,17 +674,15 @@ def _build_aggregated_measure_from_measure_source_node(

# Extraneous linkable specs are specs that are used in this phase that should not show up in the final result
# unless it was already a requested spec in the query
extraneous_linkable_specs = LinkableSpecSet()
linkable_spec_sets_to_merge: List[LinkableSpecSet] = []
if where_constraint:
extraneous_linkable_specs = LinkableSpecSet.merge(
(extraneous_linkable_specs, where_constraint.linkable_spec_set)
)
linkable_spec_sets_to_merge.append(where_constraint.linkable_spec_set)
if non_additive_dimension_spec:
extraneous_linkable_specs = LinkableSpecSet.merge(
(extraneous_linkable_specs, non_additive_dimension_spec.linkable_specs)
)
linkable_spec_sets_to_merge.append(non_additive_dimension_spec.linkable_specs)

extraneous_linkable_specs = LinkableSpecSet.merge_iterable(linkable_spec_sets_to_merge).dedupe()
required_linkable_specs = queried_linkable_specs.merge(extraneous_linkable_specs).dedupe()

required_linkable_specs = LinkableSpecSet.merge((queried_linkable_specs, extraneous_linkable_specs))
logger.info(
f"Looking for a recipe to get:\n"
f"{pformat_big_objects(measure_specs=[measure_spec], required_linkable_set=required_linkable_specs)}"
Expand Down Expand Up @@ -736,11 +734,8 @@ def _build_aggregated_measure_from_measure_source_node(
# Only get the required measure and the local linkable instances so that aggregations work correctly.
filtered_measure_source_node = FilterElementsNode(
parent_node=join_to_time_spine_node or time_range_node or measure_recipe.source_node,
include_specs=InstanceSpecSet.merge(
(
InstanceSpecSet(measure_specs=(measure_spec,)),
InstanceSpecSet.create_from_linkable_specs(measure_recipe.required_local_linkable_specs),
)
include_specs=InstanceSpecSet(measure_specs=(measure_spec,)).merge(
InstanceSpecSet.create_from_linkable_specs(measure_recipe.required_local_linkable_specs),
),
)

Expand All @@ -752,11 +747,8 @@ def _build_aggregated_measure_from_measure_source_node(
join_targets=join_targets,
)

specs_to_keep_after_join = InstanceSpecSet.merge(
(
InstanceSpecSet(measure_specs=(measure_spec,)),
required_linkable_specs.as_spec_set,
)
specs_to_keep_after_join = InstanceSpecSet(measure_specs=(measure_spec,)).merge(
required_linkable_specs.as_spec_set,
)

after_join_filtered_node = FilterElementsNode(
Expand Down Expand Up @@ -814,10 +806,9 @@ def _build_aggregated_measure_from_measure_source_node(
# e.g. for "bookings" by "ds" where "is_instant", "is_instant" should not be in the results.
pre_aggregate_node = FilterElementsNode(
parent_node=pre_aggregate_node,
include_specs=InstanceSpecSet.merge(
(InstanceSpecSet(measure_specs=(measure_spec,)), queried_linkable_specs.as_spec_set)
),
include_specs=InstanceSpecSet(measure_specs=(measure_spec,)).merge(queried_linkable_specs.as_spec_set),
)

aggregate_measures_node = AggregateMeasuresNode(
parent_node=pre_aggregate_node,
metric_input_measure_specs=(metric_input_measure_spec,),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
WriteToResultTableNode,
)
from metricflow.dataflow.optimizer.source_scan.matching_linkable_specs import MatchingLinkableSpecsTransform
from metricflow.specs.specs import InstanceSpecSet

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -360,9 +359,7 @@ def visit_pass_elements_filter_node( # noqa: D
# specs since any branch that is merged together needs to output the same set of dimensions.
combined_node = FilterElementsNode(
parent_node=combined_parent_node,
include_specs=InstanceSpecSet.merge(
(self._current_left_node.include_specs, current_right_node.include_specs)
).dedupe(),
include_specs=self._current_left_node.include_specs.merge(current_right_node.include_specs).dedupe(),
)
self._log_combine_success(
left_node=self._current_left_node,
Expand Down
86 changes: 52 additions & 34 deletions metricflow/specs/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from hashlib import sha1
from typing import Any, Generic, List, Optional, Sequence, Tuple, TypeVar
from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar

from dbt_semantic_interfaces.dataclass_serialization import SerializableDataclass
from dbt_semantic_interfaces.implementations.metric import PydanticMetricTimeWindow
Expand All @@ -28,8 +28,10 @@
from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from typing_extensions import override

from metricflow.aggregation_properties import AggregationState
from metricflow.collections.merger import Mergeable
from metricflow.filters.time_constraint import TimeRangeConstraint
from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow.sql.sql_bind_parameters import SqlBindParameters
Expand Down Expand Up @@ -491,7 +493,7 @@ class FilterSpec(SerializableDataclass): # noqa: D


@dataclass(frozen=True)
class LinkableSpecSet(SerializableDataclass):
class LinkableSpecSet(Mergeable, SerializableDataclass):
"""Groups linkable specs."""

dimension_specs: Tuple[DimensionSpec, ...] = ()
Expand All @@ -502,28 +504,38 @@ class LinkableSpecSet(SerializableDataclass):
def as_tuple(self) -> Tuple[LinkableInstanceSpec, ...]: # noqa: D
return tuple(itertools.chain(self.dimension_specs, self.time_dimension_specs, self.entity_specs))

@staticmethod
def merge(spec_sets: Sequence[LinkableSpecSet]) -> LinkableSpecSet:
"""Merges and dedupes the linkable specs."""
dimension_specs: List[DimensionSpec] = []
time_dimension_specs: List[TimeDimensionSpec] = []
entity_specs: List[EntitySpec] = []

for spec_set in spec_sets:
for dimension_spec in spec_set.dimension_specs:
if dimension_spec not in dimension_specs:
dimension_specs.append(dimension_spec)
for time_dimension_spec in spec_set.time_dimension_specs:
if time_dimension_spec not in time_dimension_specs:
time_dimension_specs.append(time_dimension_spec)
for entity_spec in spec_set.entity_specs:
if entity_spec not in entity_specs:
entity_specs.append(entity_spec)
@override
def merge(self, other: LinkableSpecSet) -> LinkableSpecSet:
return LinkableSpecSet(
dimension_specs=self.dimension_specs + other.dimension_specs,
time_dimension_specs=self.time_dimension_specs + other.time_dimension_specs,
entity_specs=self.entity_specs + other.entity_specs,
)

@override
@classmethod
def empty_instance(cls) -> LinkableSpecSet:
return LinkableSpecSet()

def dedupe(self) -> LinkableSpecSet: # noqa: D
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we separating these? Is there ever a case where we want duplicates in this output?

If not, this feels like we're adding an old bug factory back to the codebase.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Preferring to make dedupe() a separate method as it makes the Mergeable interface more flexible. I disagree on adding an old bug factory back to the codebase. If there is always de-duplication behavior, it could hide bugs in testing. e.g. the output of a merged result is checked in a test, and it's not expected for the sources of the merge to produce duplicate items. However, the de-duplication of the merged result hides that.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first two words (if not) in that bug factory sentence are extremely important. If we never want duplicates to come out of a merge, then any case where we have them is a bug, and requiring developers to manually call a dedupe method can only lead to bugs, while encapsulating that behavior can only prevent them. If there is an upstream bug that the encapsulation hides that's a separate problem, and it's one that having developers blindly call dedupe() on their merge outputs (which is a thing we should expect to happen if our contributor base grows) will also mask.

As it stands right now every merge call either explicitly requests deduped results or implicitly resolves to them, which suggests that we shouldn't force developers to call it in order to get the behavior they need.

Ultimately, this isn't a big deal, but it's interesting to me that the only extant cases where we don't call dedupe() are still cases where we need deduplicated results.

# Use dictionaries to dedupe as it preserves insertion order.

dimension_spec_dict: Dict[DimensionSpec, None] = {}
for dimension_spec in self.dimension_specs:
dimension_spec_dict[dimension_spec] = None

time_dimension_spec_dict: Dict[TimeDimensionSpec, None] = {}
for time_dimension_spec in self.time_dimension_specs:
time_dimension_spec_dict[time_dimension_spec] = None

entity_spec_dict: Dict[EntitySpec, None] = {}
for entity_spec in self.entity_specs:
entity_spec_dict[entity_spec] = None

return LinkableSpecSet(
dimension_specs=tuple(dimension_specs),
time_dimension_specs=tuple(time_dimension_specs),
entity_specs=tuple(entity_specs),
dimension_specs=tuple(dimension_spec_dict.keys()),
time_dimension_specs=tuple(time_dimension_spec_dict.keys()),
entity_specs=tuple(entity_spec_dict.keys()),
)

def is_subset_of(self, other_set: LinkableSpecSet) -> bool: # noqa: D
Expand Down Expand Up @@ -582,7 +594,7 @@ def transform(self, spec_set: InstanceSpecSet) -> TransformOutputT: # noqa: D


@dataclass(frozen=True)
class InstanceSpecSet(SerializableDataclass):
class InstanceSpecSet(Mergeable, SerializableDataclass):
"""Consolidates all specs used in an instance set."""

metric_specs: Tuple[MetricSpec, ...] = ()
Expand All @@ -592,18 +604,22 @@ class InstanceSpecSet(SerializableDataclass):
time_dimension_specs: Tuple[TimeDimensionSpec, ...] = ()
metadata_specs: Tuple[MetadataSpec, ...] = ()

@staticmethod
def merge(others: Sequence[InstanceSpecSet]) -> InstanceSpecSet:
"""Merge all sets into one set, without de-duplication."""
@override
def merge(self, other: InstanceSpecSet) -> InstanceSpecSet:
return InstanceSpecSet(
metric_specs=tuple(itertools.chain.from_iterable([x.metric_specs for x in others])),
measure_specs=tuple(itertools.chain.from_iterable([x.measure_specs for x in others])),
dimension_specs=tuple(itertools.chain.from_iterable([x.dimension_specs for x in others])),
entity_specs=tuple(itertools.chain.from_iterable([x.entity_specs for x in others])),
time_dimension_specs=tuple(itertools.chain.from_iterable([x.time_dimension_specs for x in others])),
metadata_specs=tuple(itertools.chain.from_iterable([x.metadata_specs for x in others])),
metric_specs=self.metric_specs + other.metric_specs,
measure_specs=self.measure_specs + other.measure_specs,
dimension_specs=self.dimension_specs + other.dimension_specs,
entity_specs=self.entity_specs + other.entity_specs,
time_dimension_specs=self.time_dimension_specs + other.time_dimension_specs,
metadata_specs=self.metadata_specs + other.metadata_specs,
)

@override
@classmethod
def empty_instance(cls) -> InstanceSpecSet:
return InstanceSpecSet()

def dedupe(self) -> InstanceSpecSet:
"""De-duplicates repeated elements.

Expand Down Expand Up @@ -665,7 +681,9 @@ def transform(self, transform_function: InstanceSpecSetTransform[TransformOutput

@staticmethod
def create_from_linkable_specs(linkable_specs: Sequence[LinkableInstanceSpec]) -> InstanceSpecSet: # noqa: D
return InstanceSpecSet.merge(tuple(x.as_linkable_spec_set.as_spec_set for x in linkable_specs))
return InstanceSpecSet.merge_iterable(
tuple(linkable_spec.as_linkable_spec_set.as_spec_set for linkable_spec in linkable_specs)
)


@dataclass(frozen=True)
Expand Down Expand Up @@ -716,5 +734,5 @@ def combine(self, other: WhereFilterSpec) -> WhereFilterSpec: # noqa: D
return WhereFilterSpec(
where_sql=f"({self.where_sql}) AND ({other.where_sql})",
bind_parameters=self.bind_parameters.combine(other.bind_parameters),
linkable_spec_set=LinkableSpecSet.merge([self.linkable_spec_set, other.linkable_spec_set]),
linkable_spec_set=self.linkable_spec_set.merge(other.linkable_spec_set),
)
Empty file.
32 changes: 32 additions & 0 deletions metricflow/test/collections/test_merger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import List, Tuple

from typing_extensions import override

from metricflow.collections.merger import Mergeable


@dataclass(frozen=True)
class NumberTuple(Mergeable):
    """Minimal Mergeable implementation for exercising the interface: a tuple of ints."""

    numbers: Tuple[int, ...] = field(default_factory=tuple)

    @override
    def merge(self: NumberTuple, other: NumberTuple) -> NumberTuple:
        # Concatenate the two tuples into a new frozen instance.
        return NumberTuple(numbers=(*self.numbers, *other.numbers))

    @override
    @classmethod
    def empty_instance(cls) -> NumberTuple:
        # The default-constructed instance holds the empty tuple.
        return NumberTuple(numbers=())


def test_merger() -> None:
    """Merging an iterable concatenates the member tuples in order, skipping empties."""
    merged = NumberTuple.merge_iterable(
        (
            NumberTuple(()),
            NumberTuple((1,)),
            NumberTuple((2, 3)),
        )
    )

    assert merged == NumberTuple((1, 2, 3))
2 changes: 1 addition & 1 deletion metricflow/test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_merge_spec_set() -> None: # noqa: D
dimension_specs=(DimensionSpec(element_name="is_instant", entity_links=(EntityReference("booking"),)),)
)

assert InstanceSpecSet.merge((spec_set1, spec_set2)) == InstanceSpecSet(
assert spec_set1.merge(spec_set2) == InstanceSpecSet(
metric_specs=(MetricSpec(element_name="bookings"),),
dimension_specs=(DimensionSpec(element_name="is_instant", entity_links=(EntityReference("booking"),)),),
)
Expand Down