Skip to content

Commit

Permalink
/* PR_START p--misc 06 */ Implement Mergeable for `SqlExpressionTre…
Browse files Browse the repository at this point in the history
…eLineage`
  • Loading branch information
plypaul committed Dec 13, 2024
1 parent 438f228 commit bcf6209
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 24 deletions.
4 changes: 2 additions & 2 deletions metricflow/sql/optimizer/rewriting_sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _reduce_parents(

@staticmethod
def _statement_contains_difficult_expressions(node: SqlSelectStatementNode) -> bool:
combined_lineage = SqlExpressionTreeLineage.combine(
combined_lineage = SqlExpressionTreeLineage.merge_iterable(
tuple(x.expr.lineage for x in node.select_columns)
+ ((node.where.lineage,) if node.where else ())
+ tuple(x.expr.lineage for x in node.group_bys)
Expand All @@ -133,7 +133,7 @@ def _statement_contains_difficult_expressions(node: SqlSelectStatementNode) -> b

@staticmethod
def _select_columns_contain_string_expressions(select_columns: Tuple[SqlSelectColumn, ...]) -> bool:
combined_lineage = SqlExpressionTreeLineage.combine(tuple(x.expr.lineage for x in select_columns))
combined_lineage = SqlExpressionTreeLineage.merge_iterable(tuple(x.expr.lineage for x in select_columns))

return len(combined_lineage.string_exprs) > 0

Expand Down
2 changes: 1 addition & 1 deletion metricflow/sql/optimizer/tag_required_column_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _search_for_expressions(
if select_node.where:
all_expr_search_results.append(select_node.where.lineage)

return SqlExpressionTreeLineage.combine(all_expr_search_results)
return SqlExpressionTreeLineage.merge_iterable(all_expr_search_results)

@override
def visit_cte_node(self, node: SqlCteNode) -> None:
Expand Down
61 changes: 40 additions & 21 deletions metricflow/sql/sql_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Generic, List, Mapping, Optional, Sequence, Tuple
from typing import Dict, Generic, Iterable, List, Mapping, Optional, Sequence, Tuple

import more_itertools
from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
Expand All @@ -15,6 +15,7 @@
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.type_enums.period_agg import PeriodAggregation
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from metricflow_semantics.collection_helpers.merger import Mergeable
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DagNode, DisplayedProperty
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet
Expand Down Expand Up @@ -100,7 +101,7 @@ def matches(self, other: SqlExpressionNode) -> bool:


@dataclass(frozen=True)
class SqlExpressionTreeLineage:
class SqlExpressionTreeLineage(Mergeable):
"""Captures the lineage of an expression node - contains itself and all ancestor nodes."""

string_exprs: Tuple[SqlStringExpression, ...] = ()
Expand All @@ -109,17 +110,18 @@ class SqlExpressionTreeLineage:
column_alias_reference_exprs: Tuple[SqlColumnAliasReferenceExpression, ...] = ()
other_exprs: Tuple[SqlExpressionNode, ...] = ()

@staticmethod
def combine(lineages: Sequence[SqlExpressionTreeLineage]) -> SqlExpressionTreeLineage:
@classmethod
@override
def merge_iterable(cls, items: Iterable[SqlExpressionTreeLineage]) -> SqlExpressionTreeLineage:
"""Combine multiple lineages into one lineage, without de-duping."""
return SqlExpressionTreeLineage(
string_exprs=tuple(more_itertools.flatten(tuple(x.string_exprs for x in lineages))),
function_exprs=tuple(more_itertools.flatten(tuple(x.function_exprs for x in lineages))),
column_reference_exprs=tuple(more_itertools.flatten(tuple(x.column_reference_exprs for x in lineages))),
string_exprs=tuple(more_itertools.flatten(tuple(x.string_exprs for x in items))),
function_exprs=tuple(more_itertools.flatten(tuple(x.function_exprs for x in items))),
column_reference_exprs=tuple(more_itertools.flatten(tuple(x.column_reference_exprs for x in items))),
column_alias_reference_exprs=tuple(
more_itertools.flatten(tuple(x.column_alias_reference_exprs for x in lineages))
more_itertools.flatten(tuple(x.column_alias_reference_exprs for x in items))
),
other_exprs=tuple(more_itertools.flatten(tuple(x.other_exprs for x in lineages))),
other_exprs=tuple(more_itertools.flatten(tuple(x.other_exprs for x in items))),
)

@property
Expand All @@ -138,6 +140,21 @@ def contains_ambiguous_exprs(self) -> bool: # noqa: D102
def contains_aggregate_exprs(self) -> bool: # noqa: D102
return any(x.is_aggregate_function for x in self.function_exprs)

@override
def merge(self, other: SqlExpressionTreeLineage) -> SqlExpressionTreeLineage:
return SqlExpressionTreeLineage(
string_exprs=self.string_exprs + other.string_exprs,
function_exprs=self.function_exprs + other.function_exprs,
column_reference_exprs=self.column_reference_exprs + other.column_reference_exprs,
column_alias_reference_exprs=self.column_alias_reference_exprs + other.column_alias_reference_exprs,
other_exprs=self.other_exprs + other.other_exprs,
)

@classmethod
@override
def empty_instance(cls) -> SqlExpressionTreeLineage:
return SqlExpressionTreeLineage()


class SqlColumnReplacements:
"""When re-writing column references in expressions, this stores the mapping."""
Expand Down Expand Up @@ -604,7 +621,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down Expand Up @@ -803,7 +820,7 @@ def is_aggregate_function(self) -> bool: # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(function_exprs=(self,)),)
)

Expand Down Expand Up @@ -923,7 +940,7 @@ def is_aggregate_function(self) -> bool: # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(function_exprs=(self,)),)
)

Expand Down Expand Up @@ -1084,7 +1101,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(function_exprs=(self,)),)
)

Expand Down Expand Up @@ -1194,7 +1211,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down Expand Up @@ -1241,7 +1258,9 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine([self.arg.lineage, SqlExpressionTreeLineage(other_exprs=(self,))])
return SqlExpressionTreeLineage.merge_iterable(
[self.arg.lineage, SqlExpressionTreeLineage(other_exprs=(self,))]
)

def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
if not isinstance(other, SqlIsNullExpression):
Expand Down Expand Up @@ -1304,7 +1323,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down Expand Up @@ -1351,7 +1370,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down Expand Up @@ -1402,7 +1421,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down Expand Up @@ -1461,7 +1480,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down Expand Up @@ -1526,7 +1545,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down Expand Up @@ -1591,7 +1610,7 @@ def rewrite( # noqa: D102

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

Expand Down

0 comments on commit bcf6209

Please sign in to comment.