Skip to content

Commit

Permalink
Create SqlAddTimeExpression
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Dec 11, 2024
1 parent 4bcec27 commit 94579a1
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 4 deletions.
1 change: 1 addition & 0 deletions metricflow-semantics/metricflow_semantics/dag/id_prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper):
SQL_EXPR_CAST_TO_TIMESTAMP_PREFIX = "ctt"
SQL_EXPR_DATE_TRUNC = "dt"
SQL_EXPR_SUBTRACT_TIME_INTERVAL_PREFIX = "sti"
SQL_EXPR_ADD_TIME_PREFIX = "ati"
SQL_EXPR_EXTRACT = "ex"
SQL_EXPR_RATIO_COMPUTATION = "rc"
SQL_EXPR_BETWEEN_PREFIX = "betw"
Expand Down
12 changes: 12 additions & 0 deletions metricflow/sql/render/big_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer
from metricflow.sql.sql_exprs import (
SqlAddTimeExpression,
SqlCastToTimestampExpression,
SqlDateTruncExpression,
SqlExtractExpression,
Expand Down Expand Up @@ -176,6 +177,17 @@ def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpress
bind_parameter_set=column.bind_parameter_set,
)

@override
def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult:
"""Render time delta for BigQuery, which requires ISO prefixing for the WEEK granularity value."""
column = node.arg.accept(self)
count = node.count_expr.accept(self)

return SqlExpressionRenderResult(
sql=f"DATE_ADD(CAST({column.sql} AS {self.timestamp_data_type}), INTERVAL {count} {node.granularity.value})",
bind_parameter_set=column.bind_parameter_set,
)

@override
def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpressionRenderResult:
return SqlExpressionRenderResult(
Expand Down
17 changes: 17 additions & 0 deletions metricflow/sql/render/duckdb_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer
from metricflow.sql.sql_exprs import (
SqlAddTimeExpression,
SqlGenerateUuidExpression,
SqlPercentileExpression,
SqlPercentileFunctionType,
Expand Down Expand Up @@ -52,6 +53,22 @@ def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpress
bind_parameter_set=arg_rendered.bind_parameter_set,
)

@override
def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult:
"""Render time delta expression for DuckDB, which requires slightly different syntax from other engines."""
arg_rendered = node.arg.accept(self)
count_rendered = node.count_expr.accept(self).sql

granularity = node.granularity
if granularity == TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count_rendered = f"{count_rendered} * 3"

return SqlExpressionRenderResult(
sql=f"{arg_rendered.sql} + INTERVAL {count_rendered} {granularity.value}",
bind_parameter_set=arg_rendered.bind_parameter_set,
)

@override
def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpressionRenderResult:
return SqlExpressionRenderResult(
Expand Down
19 changes: 17 additions & 2 deletions metricflow/sql/render/expr_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from metricflow.sql.render.rendering_constants import SqlRenderingConstants
from metricflow.sql.sql_exprs import (
SqlAddTimeExpression,
SqlAggregateFunctionExpression,
SqlBetweenExpression,
SqlCastToTimestampExpression,
Expand Down Expand Up @@ -303,9 +304,9 @@ def render_date_part(self, date_part: DatePart) -> str:

return date_part.value

def visit_subtract_time_interval_expr(
def visit_subtract_time_interval_expr( # noqa: D102
self, node: SqlSubtractTimeIntervalExpression
) -> SqlExpressionRenderResult: # noqa: D102
) -> SqlExpressionRenderResult:
arg_rendered = node.arg.accept(self)

count = node.count
Expand All @@ -318,6 +319,20 @@ def visit_subtract_time_interval_expr(
bind_parameter_set=arg_rendered.bind_parameter_set,
)

def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult: # noqa: D102
arg_rendered = node.arg.accept(self)
count_rendered = node.count_expr.accept(self).sql

granularity = node.granularity
if granularity == TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count_rendered = f"{count_rendered} * 3"

return SqlExpressionRenderResult(
sql=f"DATEADD({granularity.value}, {count_rendered}, {arg_rendered.sql})",
bind_parameter_set=arg_rendered.bind_parameter_set,
)

def visit_ratio_computation_expr(self, node: SqlRatioComputationExpression) -> SqlExpressionRenderResult:
"""Render the ratio computation for a ratio metric.
Expand Down
17 changes: 17 additions & 0 deletions metricflow/sql/render/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer
from metricflow.sql.sql_exprs import (
SqlAddTimeExpression,
SqlGenerateUuidExpression,
SqlPercentileExpression,
SqlPercentileFunctionType,
Expand Down Expand Up @@ -54,6 +55,22 @@ def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpress
bind_parameter_set=arg_rendered.bind_parameter_set,
)

@override
def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult:
"""Render time delta operations for PostgreSQL, which needs custom support for quarterly granularity."""
arg_rendered = node.arg.accept(self)
count_rendered = node.count_expr.accept(self).sql

granularity = node.granularity
if granularity == TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count_rendered = f"{count_rendered} * 3"

return SqlExpressionRenderResult(
sql=f"{arg_rendered.sql} + MAKE_INTERVAL({granularity.value}s => {count_rendered})",
bind_parameter_set=arg_rendered.bind_parameter_set,
)

@override
def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpressionRenderResult:
return SqlExpressionRenderResult(
Expand Down
17 changes: 17 additions & 0 deletions metricflow/sql/render/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer
from metricflow.sql.sql_exprs import (
SqlAddTimeExpression,
SqlBetweenExpression,
SqlGenerateUuidExpression,
SqlPercentileExpression,
Expand Down Expand Up @@ -59,6 +60,22 @@ def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpress
bind_parameter_set=arg_rendered.bind_parameter_set,
)

@override
def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult:
"""Render time delta for Trino, require granularity in quotes and function name change."""
arg_rendered = node.arg.accept(self)
count_rendered = node.count_expr.accept(self).sql

granularity = node.granularity
if granularity == TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count_rendered = f"{count_rendered} * 3"

return SqlExpressionRenderResult(
sql=f"DATE_ADD('{granularity.value}', {count_rendered}, {arg_rendered.sql})",
bind_parameter_set=arg_rendered.bind_parameter_set,
)

@override
def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionRenderResult:
"""Render a percentile expression for Trino."""
Expand Down
67 changes: 65 additions & 2 deletions metricflow/sql/sql_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,13 @@ def visit_extract_expr(self, node: SqlExtractExpression) -> VisitorOutputT: # n
pass

@abstractmethod
def visit_subtract_time_interval_expr(
def visit_subtract_time_interval_expr( # noqa: D102
self, node: SqlSubtractTimeIntervalExpression
) -> VisitorOutputT: # noqa: D102
) -> VisitorOutputT:
pass

@abstractmethod
def visit_add_time_expr(self, node: SqlAddTimeExpression) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
Expand Down Expand Up @@ -1316,6 +1320,65 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return self.count == other.count and self.granularity == other.granularity and self._parents_match(other)


@dataclass(frozen=True, eq=False)
class SqlAddTimeExpression(SqlExpressionNode):
"""Add a time interval expr to a timestamp."""

arg: SqlExpressionNode
count_expr: SqlExpressionNode
granularity: TimeGranularity

@staticmethod
def create( # noqa: D102
arg: SqlExpressionNode,
count_expr: SqlExpressionNode,
granularity: TimeGranularity,
) -> SqlAddTimeExpression:
return SqlAddTimeExpression(
parent_nodes=(arg, count_expr),
arg=arg,
count_expr=count_expr,
granularity=granularity,
)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.SQL_EXPR_ADD_TIME_PREFIX

@property
def requires_parenthesis(self) -> bool: # noqa: D102
return False

def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_add_time_expr(self)

@property
def description(self) -> str: # noqa: D102
return "Add time interval"

def rewrite( # noqa: D102
self,
column_replacements: Optional[SqlColumnReplacements] = None,
should_render_table_alias: Optional[bool] = None,
) -> SqlExpressionNode:
return SqlAddTimeExpression.create(
arg=self.arg.rewrite(column_replacements, should_render_table_alias),
count_expr=self.count_expr,
granularity=self.granularity,
)

@property
def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
return SqlExpressionTreeLineage.combine(
tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
)

def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
if not isinstance(other, SqlAddTimeExpression):
return False
return self.count_expr == other.count_expr and self.granularity == other.granularity and self.arg == other.arg


@dataclass(frozen=True, eq=False)
class SqlCastToTimestampExpression(SqlExpressionNode):
"""Cast to the timestamp type like CAST('2020-01-01' AS TIMESTAMP)."""
Expand Down

0 comments on commit 94579a1

Please sign in to comment.