Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SQL exprs needed for custom offset window #1575

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions metricflow-semantics/metricflow_semantics/dag/id_prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper):
SQL_EXPR_BETWEEN_PREFIX = "betw"
SQL_EXPR_WINDOW_FUNCTION_ID_PREFIX = "wfnc"
SQL_EXPR_GENERATE_UUID_PREFIX = "uuid"
SQL_EXPR_CASE_PREFIX = "case"
SQL_EXPR_ARITHMETIC_PREFIX = "arit"
SQL_EXPR_INTEGER_PREFIX = "int"

SQL_PLAN_SELECT_STATEMENT_ID_PREFIX = "ss"
SQL_PLAN_TABLE_FROM_CLAUSE_ID_PREFIX = "tfc"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.type_enums.period_agg import PeriodAggregation
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from typing_extensions import override

from metricflow_semantics.collection_helpers.merger import Mergeable
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DagNode, DisplayedProperty
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet
from metricflow_semantics.visitor import Visitable, VisitorOutputT
from typing_extensions import override


@dataclass(frozen=True, eq=False)
Expand Down Expand Up @@ -237,6 +238,18 @@ def visit_window_function_expr(self, node: SqlWindowFunctionExpression) -> Visit
def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_case_expr(self, node: SqlCaseExpression) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_arithmetic_expr(self, node: SqlArithmeticExpression) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_integer_expr(self, node: SqlIntegerExpression) -> VisitorOutputT: # noqa: D102
pass


@dataclass(frozen=True, eq=False)
class SqlStringExpression(SqlExpressionNode):
Expand Down Expand Up @@ -375,6 +388,59 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return self.literal_value == other.literal_value


@dataclass(frozen=True, eq=False)
class SqlIntegerExpression(SqlExpressionNode):
"""An integer like 1."""

integer_value: int

@staticmethod
def create(integer_value: int) -> SqlIntegerExpression: # noqa: D102
return SqlIntegerExpression(parent_nodes=(), integer_value=integer_value)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.SQL_EXPR_INTEGER_PREFIX

def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_integer_expr(self)

@property
def description(self) -> str: # noqa: D102
return f"Integer: {self.integer_value}"

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + (DisplayedProperty("value", self.integer_value),)

@property
def requires_parenthesis(self) -> bool: # noqa: D102
return False

@property
def bind_parameter_set(self) -> SqlBindParameterSet: # noqa: D102
return SqlBindParameterSet()

def __repr__(self) -> str: # noqa: D105
return f"{self.__class__.__name__}(node_id={self.node_id}, integer_value={self.integer_value})"

def rewrite( # noqa: D102
self,
column_replacements: Optional[SqlColumnReplacements] = None,
should_render_table_alias: Optional[bool] = None,
) -> SqlExpressionNode:
return self

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage(other_exprs=(self,))

def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
if not isinstance(other, SqlIntegerExpression):
return False
return self.integer_value == other.integer_value


@dataclass(frozen=True)
class SqlColumnReference:
"""Used with string expressions to specify what columns are referred to in the string expression."""
Expand Down Expand Up @@ -950,17 +1016,38 @@ class SqlWindowFunction(Enum):
FIRST_VALUE = "FIRST_VALUE"
LAST_VALUE = "LAST_VALUE"
AVERAGE = "AVG"
ROW_NUMBER = "ROW_NUMBER"
LAG = "LAG"

@property
def requires_ordering(self) -> bool:
"""Asserts whether or not ordering the window function will have an impact on the resulting value."""
if self is SqlWindowFunction.FIRST_VALUE or self is SqlWindowFunction.LAST_VALUE:
if (
self is SqlWindowFunction.FIRST_VALUE
or self is SqlWindowFunction.LAST_VALUE
or self is SqlWindowFunction.ROW_NUMBER
or self is SqlWindowFunction.LAG
):
return True
elif self is SqlWindowFunction.AVERAGE:
return False
else:
assert_values_exhausted(self)

@property
def allows_frame_clause(self) -> bool:
"""Whether the function allows a frame clause, e.g., 'ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING'."""
if (
self is SqlWindowFunction.FIRST_VALUE
or self is SqlWindowFunction.LAST_VALUE
or self is SqlWindowFunction.AVERAGE
):
return True
if self is SqlWindowFunction.ROW_NUMBER or self is SqlWindowFunction.LAG:
return False
else:
assert_values_exhausted(self)

@classmethod
def get_window_function_for_period_agg(cls, period_agg: PeriodAggregation) -> SqlWindowFunction:
"""Get the window function to use for given period agg option."""
Expand Down Expand Up @@ -1106,7 +1193,8 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return (
self.sql_function == other.sql_function
and self.order_by_args == other.order_by_args
and self._parents_match(other)
and self.partition_by_args == other.partition_by_args
and self.sql_function_args == other.sql_function_args
)


Expand Down Expand Up @@ -1367,7 +1455,7 @@ def rewrite( # noqa: D102
) -> SqlExpressionNode:
return SqlAddTimeExpression.create(
arg=self.arg.rewrite(column_replacements, should_render_table_alias),
count_expr=self.count_expr,
count_expr=self.count_expr.rewrite(column_replacements, should_render_table_alias),
granularity=self.granularity,
)

Expand Down Expand Up @@ -1719,3 +1807,158 @@ def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102

def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return False


@dataclass(frozen=True, eq=False)
class SqlCaseExpression(SqlExpressionNode):
"""Renders a CASE WHEN expression."""

when_to_then_exprs: Dict[SqlExpressionNode, SqlExpressionNode]
else_expr: Optional[SqlExpressionNode]

@staticmethod
def create( # noqa: D102
when_to_then_exprs: Dict[SqlExpressionNode, SqlExpressionNode], else_expr: Optional[SqlExpressionNode] = None
) -> SqlCaseExpression:
parent_nodes: Tuple[SqlExpressionNode, ...] = ()
for when, then in when_to_then_exprs.items():
parent_nodes += (when,)
parent_nodes += (then,)

if else_expr:
parent_nodes += (else_expr,)

return SqlCaseExpression(parent_nodes=parent_nodes, when_to_then_exprs=when_to_then_exprs, else_expr=else_expr)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.SQL_EXPR_CASE_PREFIX

def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_case_expr(self)

@property
def description(self) -> str: # noqa: D102
return "Case expression"

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return super().displayed_properties

@property
def requires_parenthesis(self) -> bool: # noqa: D102
return False

@property
def bind_parameter_set(self) -> SqlBindParameterSet: # noqa: D102
return SqlBindParameterSet()

def __repr__(self) -> str: # noqa: D105
return f"{self.__class__.__name__}(node_id={self.node_id})"

def rewrite( # noqa: D102
self,
column_replacements: Optional[SqlColumnReplacements] = None,
should_render_table_alias: Optional[bool] = None,
) -> SqlExpressionNode:
return SqlCaseExpression.create(
when_to_then_exprs={
when.rewrite(column_replacements, should_render_table_alias): then.rewrite(
column_replacements, should_render_table_alias
)
for when, then in self.when_to_then_exprs.items()
},
else_expr=(
self.else_expr.rewrite(column_replacements, should_render_table_alias) if self.else_expr else None
),
)

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
if not isinstance(other, SqlCaseExpression):
return False
return self.when_to_then_exprs == other.when_to_then_exprs and self.else_expr == other.else_expr


class SqlArithmeticOperator(Enum):
"""Arithmetic operator used to do math in a SQL expression."""

ADD = "+"
SUBTRACT = "-"
MULTIPLY = "*"
DIVIDE = "/"


@dataclass(frozen=True, eq=False)
class SqlArithmeticExpression(SqlExpressionNode):
"""An arithmetic expression using +, -, *, /.

e.g. my_table.my_column + my_table.other_column

Attributes:
left_expr: The expression on the left side of the operator
operator: The operator to use on the expressions
right_expr: The expression on the right side of the operator
"""

left_expr: SqlExpressionNode
operator: SqlArithmeticOperator
right_expr: SqlExpressionNode

@staticmethod
def create( # noqa: D102
left_expr: SqlExpressionNode, operator: SqlArithmeticOperator, right_expr: SqlExpressionNode
) -> SqlArithmeticExpression:
return SqlArithmeticExpression(
parent_nodes=(left_expr, right_expr), left_expr=left_expr, operator=operator, right_expr=right_expr
)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.SQL_EXPR_ARITHMETIC_PREFIX

def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_arithmetic_expr(self)

@property
def description(self) -> str: # noqa: D102
return "Arithmetic Expression"

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + (
DisplayedProperty("left_expr", self.left_expr),
DisplayedProperty("operator", self.operator.value),
DisplayedProperty("right_expr", self.right_expr),
)

@property
def requires_parenthesis(self) -> bool: # noqa: D102
return True

def rewrite( # noqa: D102
self,
column_replacements: Optional[SqlColumnReplacements] = None,
should_render_table_alias: Optional[bool] = None,
) -> SqlExpressionNode:
return SqlArithmeticExpression.create(
left_expr=self.left_expr.rewrite(column_replacements, should_render_table_alias),
operator=self.operator,
right_expr=self.right_expr.rewrite(column_replacements, should_render_table_alias),
)

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.merge_iterable(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
if not isinstance(other, SqlArithmeticExpression):
return False
return self.operator == other.operator and self._parents_match(other)
14 changes: 7 additions & 7 deletions metricflow/dataset/convert_semantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,20 @@
from metricflow_semantics.specs.dimension_spec import DimensionSpec
from metricflow_semantics.specs.entity_spec import EntitySpec
from metricflow_semantics.specs.time_dimension_spec import DEFAULT_TIME_GRANULARITY, TimeDimensionSpec
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.time.granularity import ExpandedTimeGranularity
from metricflow_semantics.time.time_spine_source import TimeSpineSource

from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet
from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.sql.sql_exprs import (
from metricflow_semantics.sql.sql_exprs import (
SqlColumnReference,
SqlColumnReferenceExpression,
SqlDateTruncExpression,
SqlExpressionNode,
SqlExtractExpression,
SqlStringExpression,
)
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.time.granularity import ExpandedTimeGranularity
from metricflow_semantics.time.time_spine_source import TimeSpineSource

from metricflow.dataset.semantic_model_adapter import SemanticModelDataSet
from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.sql.sql_plan import (
SqlSelectColumn,
SqlSelectStatementNode,
Expand Down
Loading
Loading