Skip to content

Commit

Permalink
Implement Mergeable for SqlExpressionTreeLineage (#1570)
Browse files (browse the repository at this point in the history)
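This PR implements `Mergeable` for `SqlExpressionTreeLineage` to reuse
code available in `Mergeable`: the hand-written `combine()` static method
is removed, `merge()` and `empty_instance()` are added, and all call sites
now use `merge_iterable()` instead of `combine()`.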
  • Loading branch information
plypaul authored Dec 14, 2024
1 parent bc416f4 commit ebd7073
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 30 deletions.
4 changes: 2 additions & 2 deletions metricflow/sql/optimizer/rewriting_sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _reduce_parents(

@staticmethod
def _statement_contains_difficult_expressions(node: SqlSelectStatementNode) -> bool:
combined_lineage = SqlExpressionTreeLineage.combine(
combined_lineage = SqlExpressionTreeLineage.merge_iterable(
tuple(x.expr.lineage for x in node.select_columns)
+ ((node.where.lineage,) if node.where else ())
+ tuple(x.expr.lineage for x in node.group_bys)
Expand All @@ -133,7 +133,7 @@ def _statement_contains_difficult_expressions(node: SqlSelectStatementNode) -> b

@staticmethod
def _select_columns_contain_string_expressions(select_columns: Tuple[SqlSelectColumn, ...]) -> bool:
combined_lineage = SqlExpressionTreeLineage.combine(tuple(x.expr.lineage for x in select_columns))
combined_lineage = SqlExpressionTreeLineage.merge_iterable(tuple(x.expr.lineage for x in select_columns))

return len(combined_lineage.string_exprs) > 0

Expand Down
2 changes: 1 addition & 1 deletion metricflow/sql/optimizer/tag_required_column_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _search_for_expressions(
if select_node.where:
all_expr_search_results.append(select_node.where.lineage)

return SqlExpressionTreeLineage.combine(all_expr_search_results)
return SqlExpressionTreeLineage.merge_iterable(all_expr_search_results)

@override
def visit_cte_node(self, node: SqlCteNode) -> None:
Expand Down
58 changes: 31 additions & 27 deletions metricflow/sql/sql_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from enum import Enum
from typing import Dict, Generic, List, Mapping, Optional, Sequence, Tuple

import more_itertools
from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.protocols.measure import MeasureAggregationParameters
from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.type_enums.period_agg import PeriodAggregation
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from metricflow_semantics.collection_helpers.merger import Mergeable
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DagNode, DisplayedProperty
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet
Expand Down Expand Up @@ -100,7 +100,7 @@ def matches(self, other: SqlExpressionNode) -> bool:


@dataclass(frozen=True)
class SqlExpressionTreeLineage:
class SqlExpressionTreeLineage(Mergeable):
"""Captures the lineage of an expression node - contains itself and all ancestor nodes."""

string_exprs: Tuple[SqlStringExpression, ...] = ()
Expand All @@ -109,19 +109,6 @@ class SqlExpressionTreeLineage:
column_alias_reference_exprs: Tuple[SqlColumnAliasReferenceExpression, ...] = ()
other_exprs: Tuple[SqlExpressionNode, ...] = ()

@staticmethod
def combine(lineages: Sequence[SqlExpressionTreeLineage]) -> SqlExpressionTreeLineage:
"""Combine multiple lineages into one lineage, without de-duping."""
return SqlExpressionTreeLineage(
string_exprs=tuple(more_itertools.flatten(tuple(x.string_exprs for x in lineages))),
function_exprs=tuple(more_itertools.flatten(tuple(x.function_exprs for x in lineages))),
column_reference_exprs=tuple(more_itertools.flatten(tuple(x.column_reference_exprs for x in lineages))),
column_alias_reference_exprs=tuple(
more_itertools.flatten(tuple(x.column_alias_reference_exprs for x in lineages))
),
other_exprs=tuple(more_itertools.flatten(tuple(x.other_exprs for x in lineages))),
)

@property
def contains_string_exprs(self) -> bool: # noqa: D102
return len(self.string_exprs) > 0
Expand All @@ -138,6 +125,21 @@ def contains_ambiguous_exprs(self) -> bool: # noqa: D102
def contains_aggregate_exprs(self) -> bool: # noqa: D102
return any(x.is_aggregate_function for x in self.function_exprs)

@override
def merge(self, other: SqlExpressionTreeLineage) -> SqlExpressionTreeLineage:
return SqlExpressionTreeLineage(
string_exprs=self.string_exprs + other.string_exprs,
function_exprs=self.function_exprs + other.function_exprs,
column_reference_exprs=self.column_reference_exprs + other.column_reference_exprs,
column_alias_reference_exprs=self.column_alias_reference_exprs + other.column_alias_reference_exprs,
other_exprs=self.other_exprs + other.other_exprs,
)

@classmethod
@override
def empty_instance(cls) -> SqlExpressionTreeLineage:
return SqlExpressionTreeLineage()


class SqlColumnReplacements:
"""When re-writing column references in expressions, this stores the mapping."""
Expand Down Expand Up @@ -604,7 +606,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down Expand Up @@ -803,7 +805,7 @@ def is_aggregate_function(self) -> bool: # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(function_exprs=(self,)),)
)

Expand Down Expand Up @@ -923,7 +925,7 @@ def is_aggregate_function(self) -> bool: # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(function_exprs=(self,)),)
)

Expand Down Expand Up @@ -1084,7 +1086,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(function_exprs=(self,)),)
)

Expand Down Expand Up @@ -1194,7 +1196,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down Expand Up @@ -1241,7 +1243,9 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine([self.arg.lineage, SqlExpressionTreeLineage(other_exprs=(self,))])
return SqlExpressionTreeLineage.merge_iterable(
[self.arg.lineage, SqlExpressionTreeLineage(other_exprs=(self,))]
)

def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
if not isinstance(other, SqlIsNullExpression):
Expand Down Expand Up @@ -1304,7 +1308,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down Expand Up @@ -1351,7 +1355,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down Expand Up @@ -1402,7 +1406,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down Expand Up @@ -1461,7 +1465,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down Expand Up @@ -1526,7 +1530,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down Expand Up @@ -1591,7 +1595,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down

0 comments on commit ebd7073

Please sign in to comment.