Skip to content

Commit

Permalink
Implement Mergeable for InstanceSpecSet.
Browse files Browse the repository at this point in the history
InstanceSpecSet already has a merge() call, so this updates the class to
imlement the Mergeable interface for consistency.
  • Loading branch information
plypaul committed Nov 14, 2023
1 parent c7386e7 commit 7a06886
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 29 deletions.
19 changes: 6 additions & 13 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,11 +736,8 @@ def _build_aggregated_measure_from_measure_source_node(
# Only get the required measure and the local linkable instances so that aggregations work correctly.
filtered_measure_source_node = FilterElementsNode(
parent_node=join_to_time_spine_node or time_range_node or measure_recipe.source_node,
include_specs=InstanceSpecSet.merge(
(
InstanceSpecSet(measure_specs=(measure_spec,)),
InstanceSpecSet.create_from_linkable_specs(measure_recipe.required_local_linkable_specs),
)
include_specs=InstanceSpecSet(measure_specs=(measure_spec,)).merge(
InstanceSpecSet.create_from_linkable_specs(measure_recipe.required_local_linkable_specs),
),
)

Expand All @@ -752,11 +749,8 @@ def _build_aggregated_measure_from_measure_source_node(
join_targets=join_targets,
)

specs_to_keep_after_join = InstanceSpecSet.merge(
(
InstanceSpecSet(measure_specs=(measure_spec,)),
required_linkable_specs.as_spec_set,
)
specs_to_keep_after_join = InstanceSpecSet(measure_specs=(measure_spec,)).merge(
required_linkable_specs.as_spec_set,
)

after_join_filtered_node = FilterElementsNode(
Expand Down Expand Up @@ -814,10 +808,9 @@ def _build_aggregated_measure_from_measure_source_node(
# e.g. for "bookings" by "ds" where "is_instant", "is_instant" should not be in the results.
pre_aggregate_node = FilterElementsNode(
parent_node=pre_aggregate_node,
include_specs=InstanceSpecSet.merge(
(InstanceSpecSet(measure_specs=(measure_spec,)), queried_linkable_specs.as_spec_set)
),
include_specs=InstanceSpecSet(measure_specs=(measure_spec,)).merge(queried_linkable_specs.as_spec_set),
)

aggregate_measures_node = AggregateMeasuresNode(
parent_node=pre_aggregate_node,
metric_input_measure_specs=(metric_input_measure_spec,),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
WriteToResultTableNode,
)
from metricflow.dataflow.optimizer.source_scan.matching_linkable_specs import MatchingLinkableSpecsTransform
from metricflow.specs.specs import InstanceSpecSet

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -360,9 +359,7 @@ def visit_pass_elements_filter_node( # noqa: D
# specs since any branch that is merged together needs to output the same set of dimensions.
combined_node = FilterElementsNode(
parent_node=combined_parent_node,
include_specs=InstanceSpecSet.merge(
(self._current_left_node.include_specs, current_right_node.include_specs)
).dedupe(),
include_specs=self._current_left_node.include_specs.merge(current_right_node.include_specs).dedupe(),
)
self._log_combine_success(
left_node=self._current_left_node,
Expand Down
30 changes: 19 additions & 11 deletions metricflow/specs/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from typing_extensions import override

from metricflow.aggregation_properties import AggregationState
from metricflow.collections.merger import Mergeable
from metricflow.filters.time_constraint import TimeRangeConstraint
from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow.sql.sql_bind_parameters import SqlBindParameters
Expand Down Expand Up @@ -582,7 +584,7 @@ def transform(self, spec_set: InstanceSpecSet) -> TransformOutputT: # noqa: D


@dataclass(frozen=True)
class InstanceSpecSet(SerializableDataclass):
class InstanceSpecSet(Mergeable, SerializableDataclass):
"""Consolidates all specs used in an instance set."""

metric_specs: Tuple[MetricSpec, ...] = ()
Expand All @@ -592,18 +594,22 @@ class InstanceSpecSet(SerializableDataclass):
time_dimension_specs: Tuple[TimeDimensionSpec, ...] = ()
metadata_specs: Tuple[MetadataSpec, ...] = ()

@staticmethod
def merge(others: Sequence[InstanceSpecSet]) -> InstanceSpecSet:
"""Merge all sets into one set, without de-duplication."""
@override
def merge(self, other: InstanceSpecSet) -> InstanceSpecSet:
return InstanceSpecSet(
metric_specs=tuple(itertools.chain.from_iterable([x.metric_specs for x in others])),
measure_specs=tuple(itertools.chain.from_iterable([x.measure_specs for x in others])),
dimension_specs=tuple(itertools.chain.from_iterable([x.dimension_specs for x in others])),
entity_specs=tuple(itertools.chain.from_iterable([x.entity_specs for x in others])),
time_dimension_specs=tuple(itertools.chain.from_iterable([x.time_dimension_specs for x in others])),
metadata_specs=tuple(itertools.chain.from_iterable([x.metadata_specs for x in others])),
metric_specs=self.metric_specs + other.metric_specs,
measure_specs=self.measure_specs + other.measure_specs,
dimension_specs=self.dimension_specs + other.dimension_specs,
entity_specs=self.entity_specs + other.entity_specs,
time_dimension_specs=self.time_dimension_specs + other.time_dimension_specs,
metadata_specs=self.metadata_specs + other.metadata_specs,
)

@override
@classmethod
def empty_instance(cls) -> InstanceSpecSet:
return InstanceSpecSet()

def dedupe(self) -> InstanceSpecSet:
"""De-duplicates repeated elements.
Expand Down Expand Up @@ -665,7 +671,9 @@ def transform(self, transform_function: InstanceSpecSetTransform[TransformOutput

@staticmethod
def create_from_linkable_specs(linkable_specs: Sequence[LinkableInstanceSpec]) -> InstanceSpecSet: # noqa: D
return InstanceSpecSet.merge(tuple(x.as_linkable_spec_set.as_spec_set for x in linkable_specs))
return InstanceSpecSet.merge_iterable(
tuple(linkable_spec.as_linkable_spec_set.as_spec_set for linkable_spec in linkable_specs)
)


@dataclass(frozen=True)
Expand Down
2 changes: 1 addition & 1 deletion metricflow/test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_merge_spec_set() -> None: # noqa: D
dimension_specs=(DimensionSpec(element_name="is_instant", entity_links=(EntityReference("booking"),)),)
)

assert InstanceSpecSet.merge((spec_set1, spec_set2)) == InstanceSpecSet(
assert spec_set1.merge(spec_set2) == InstanceSpecSet(
metric_specs=(MetricSpec(element_name="bookings"),),
dimension_specs=(DimensionSpec(element_name="is_instant", entity_links=(EntityReference("booking"),)),),
)
Expand Down

0 comments on commit 7a06886

Please sign in to comment.