From 2f7e2a9ee927d51e5fda0d29a7d9a0bafcd7aa0c Mon Sep 17 00:00:00 2001
From: Courtney Holcomb
Date: Tue, 10 Dec 2024 12:52:49 -0800
Subject: [PATCH] Create SqlAddTimeExpression

---
 .../metricflow_semantics/dag/id_prefix.py |  1 +
 metricflow/sql/render/big_query.py        | 12 ++++
 metricflow/sql/render/duckdb_renderer.py  | 17 +++++
 metricflow/sql/render/expr_renderer.py    | 19 +++++-
 metricflow/sql/render/postgres.py         | 17 +++++
 metricflow/sql/render/trino.py            | 17 +++++
 metricflow/sql/sql_exprs.py               | 67 ++++++++++++++++++-
 7 files changed, 146 insertions(+), 4 deletions(-)

diff --git a/metricflow-semantics/metricflow_semantics/dag/id_prefix.py b/metricflow-semantics/metricflow_semantics/dag/id_prefix.py
index 2cf673bb22..5c995d698c 100644
--- a/metricflow-semantics/metricflow_semantics/dag/id_prefix.py
+++ b/metricflow-semantics/metricflow_semantics/dag/id_prefix.py
@@ -68,6 +68,7 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper):
     SQL_EXPR_CAST_TO_TIMESTAMP_PREFIX = "ctt"
     SQL_EXPR_DATE_TRUNC = "dt"
     SQL_EXPR_SUBTRACT_TIME_INTERVAL_PREFIX = "sti"
+    SQL_EXPR_ADD_TIME_PREFIX = "ati"
     SQL_EXPR_EXTRACT = "ex"
     SQL_EXPR_RATIO_COMPUTATION = "rc"
     SQL_EXPR_BETWEEN_PREFIX = "betw"
diff --git a/metricflow/sql/render/big_query.py b/metricflow/sql/render/big_query.py
index 0f1f989dbd..a63b2d06c6 100644
--- a/metricflow/sql/render/big_query.py
+++ b/metricflow/sql/render/big_query.py
@@ -18,6 +18,7 @@
 )
 from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer
 from metricflow.sql.sql_exprs import (
+    SqlAddTimeExpression,
     SqlCastToTimestampExpression,
     SqlDateTruncExpression,
     SqlExtractExpression,
@@ -176,6 +177,17 @@ def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpress
             bind_parameter_set=column.bind_parameter_set,
         )
 
+    @override
+    def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult:
+        """Render time addition for BigQuery via DATE_ADD with an INTERVAL in the node's granularity."""
+        column = node.arg.accept(self)
+        count = node.count_expr.accept(self).sql
+
+        return SqlExpressionRenderResult(
+            sql=f"DATE_ADD(CAST({column.sql} AS {self.timestamp_data_type}), INTERVAL {count} {node.granularity.value})",
+            bind_parameter_set=column.bind_parameter_set,
+        )
+
     @override
     def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpressionRenderResult:
         return SqlExpressionRenderResult(
diff --git a/metricflow/sql/render/duckdb_renderer.py b/metricflow/sql/render/duckdb_renderer.py
index ed69b980ff..833bbedcf9 100644
--- a/metricflow/sql/render/duckdb_renderer.py
+++ b/metricflow/sql/render/duckdb_renderer.py
@@ -15,6 +15,7 @@
 )
 from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer
 from metricflow.sql.sql_exprs import (
+    SqlAddTimeExpression,
     SqlGenerateUuidExpression,
     SqlPercentileExpression,
     SqlPercentileFunctionType,
@@ -52,6 +53,22 @@ def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpress
             bind_parameter_set=arg_rendered.bind_parameter_set,
         )
 
+    @override
+    def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult:
+        """Render time addition for DuckDB, which needs parentheses around a non-literal INTERVAL amount."""
+        arg_rendered = node.arg.accept(self)
+        count_rendered = node.count_expr.accept(self).sql
+
+        granularity = node.granularity
+        if granularity == TimeGranularity.QUARTER:
+            granularity = TimeGranularity.MONTH
+            count_rendered = f"({count_rendered}) * 3"  # Parenthesize so a compound count is scaled as a whole.
+
+        return SqlExpressionRenderResult(
+            sql=f"{arg_rendered.sql} + INTERVAL ({count_rendered}) {granularity.value}",
+            bind_parameter_set=arg_rendered.bind_parameter_set,
+        )
+
     @override
     def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpressionRenderResult:
         return SqlExpressionRenderResult(
diff --git a/metricflow/sql/render/expr_renderer.py b/metricflow/sql/render/expr_renderer.py
index f7e3fcf758..c54ecb8eea 100644
--- a/metricflow/sql/render/expr_renderer.py
+++ b/metricflow/sql/render/expr_renderer.py
@@ -16,6 +16,7 @@
 
 from metricflow.sql.render.rendering_constants import SqlRenderingConstants
 from metricflow.sql.sql_exprs import (
+    SqlAddTimeExpression,
     SqlAggregateFunctionExpression,
     SqlBetweenExpression,
     SqlCastToTimestampExpression,
@@ -303,9 +304,9 @@ def render_date_part(self, date_part: DatePart) -> str:
 
         return date_part.value
 
-    def visit_subtract_time_interval_expr(
+    def visit_subtract_time_interval_expr(  # noqa: D102
         self, node: SqlSubtractTimeIntervalExpression
-    ) -> SqlExpressionRenderResult:  # noqa: D102
+    ) -> SqlExpressionRenderResult:
         arg_rendered = node.arg.accept(self)
 
         count = node.count
@@ -318,6 +319,20 @@ def visit_subtract_time_interval_expr(
             bind_parameter_set=arg_rendered.bind_parameter_set,
         )
 
+    def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult:  # noqa: D102
+        arg_rendered = node.arg.accept(self)
+        count_rendered = node.count_expr.accept(self).sql
+
+        granularity = node.granularity
+        if granularity == TimeGranularity.QUARTER:
+            granularity = TimeGranularity.MONTH
+            count_rendered = f"({count_rendered}) * 3"  # Parenthesize so a compound count is scaled as a whole.
+
+        return SqlExpressionRenderResult(
+            sql=f"DATEADD({granularity.value}, {count_rendered}, {arg_rendered.sql})",
+            bind_parameter_set=arg_rendered.bind_parameter_set,
+        )
+
     def visit_ratio_computation_expr(self, node: SqlRatioComputationExpression) -> SqlExpressionRenderResult:
         """Render the ratio computation for a ratio metric.
 
diff --git a/metricflow/sql/render/postgres.py b/metricflow/sql/render/postgres.py
index 121622c567..5652899748 100644
--- a/metricflow/sql/render/postgres.py
+++ b/metricflow/sql/render/postgres.py
@@ -16,6 +16,7 @@
 )
 from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer
 from metricflow.sql.sql_exprs import (
+    SqlAddTimeExpression,
     SqlGenerateUuidExpression,
     SqlPercentileExpression,
     SqlPercentileFunctionType,
@@ -54,6 +55,22 @@ def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpress
             bind_parameter_set=arg_rendered.bind_parameter_set,
         )
 
+    @override
+    def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult:
+        """Render time delta operations for PostgreSQL, which needs custom support for quarterly granularity."""
+        arg_rendered = node.arg.accept(self)
+        count_rendered = node.count_expr.accept(self).sql
+
+        granularity = node.granularity
+        if granularity == TimeGranularity.QUARTER:
+            granularity = TimeGranularity.MONTH
+            count_rendered = f"({count_rendered}) * 3"  # Parenthesize so a compound count is scaled as a whole.
+
+        return SqlExpressionRenderResult(
+            sql=f"{arg_rendered.sql} + MAKE_INTERVAL({granularity.value}s => {count_rendered})",
+            bind_parameter_set=arg_rendered.bind_parameter_set,
+        )
+
     @override
     def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpressionRenderResult:
         return SqlExpressionRenderResult(
diff --git a/metricflow/sql/render/trino.py b/metricflow/sql/render/trino.py
index 155fb6e466..ac902aa176 100644
--- a/metricflow/sql/render/trino.py
+++ b/metricflow/sql/render/trino.py
@@ -17,6 +17,7 @@
 )
 from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer
 from metricflow.sql.sql_exprs import (
+    SqlAddTimeExpression,
     SqlBetweenExpression,
     SqlGenerateUuidExpression,
     SqlPercentileExpression,
@@ -59,6 +60,22 @@ def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpress
             bind_parameter_set=arg_rendered.bind_parameter_set,
         )
 
+    @override
+    def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult:
+        """Render time delta for Trino, which requires the granularity in quotes and a different function name."""
+        arg_rendered = node.arg.accept(self)
+        count_rendered = node.count_expr.accept(self).sql
+
+        granularity = node.granularity
+        if granularity == TimeGranularity.QUARTER:
+            granularity = TimeGranularity.MONTH
+            count_rendered = f"({count_rendered}) * 3"  # Parenthesize so a compound count is scaled as a whole.
+
+        return SqlExpressionRenderResult(
+            sql=f"DATE_ADD('{granularity.value}', {count_rendered}, {arg_rendered.sql})",
+            bind_parameter_set=arg_rendered.bind_parameter_set,
+        )
+
     @override
     def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionRenderResult:
         """Render a percentile expression for Trino."""
diff --git a/metricflow/sql/sql_exprs.py b/metricflow/sql/sql_exprs.py
index 16671de048..15b7268c50 100644
--- a/metricflow/sql/sql_exprs.py
+++ b/metricflow/sql/sql_exprs.py
@@ -210,9 +210,13 @@ def visit_extract_expr(self, node: SqlExtractExpression) -> VisitorOutputT:  # n
         pass
 
     @abstractmethod
-    def visit_subtract_time_interval_expr(
+    def visit_subtract_time_interval_expr(  # noqa: D102
         self, node: SqlSubtractTimeIntervalExpression
-    ) -> VisitorOutputT:  # noqa: D102
+    ) -> VisitorOutputT:
+        pass
+
+    @abstractmethod
+    def visit_add_time_expr(self, node: SqlAddTimeExpression) -> VisitorOutputT:  # noqa: D102
         pass
 
     @abstractmethod
@@ -1316,6 +1320,65 @@ def matches(self, other: SqlExpressionNode) -> bool:  # noqa: D102
         return self.count == other.count and self.granularity == other.granularity and self._parents_match(other)
 
 
+@dataclass(frozen=True, eq=False)
+class SqlAddTimeExpression(SqlExpressionNode):
+    """Add `count_expr` intervals of `granularity` to the timestamp expression `arg`."""
+
+    arg: SqlExpressionNode
+    count_expr: SqlExpressionNode
+    granularity: TimeGranularity
+
+    @staticmethod
+    def create(  # noqa: D102
+        arg: SqlExpressionNode,
+        count_expr: SqlExpressionNode,
+        granularity: TimeGranularity,
+    ) -> SqlAddTimeExpression:
+        return SqlAddTimeExpression(
+            parent_nodes=(arg, count_expr),
+            arg=arg,
+            count_expr=count_expr,
+            granularity=granularity,
+        )
+
+    @classmethod
+    def id_prefix(cls) -> IdPrefix:  # noqa: D102
+        return StaticIdPrefix.SQL_EXPR_ADD_TIME_PREFIX
+
+    @property
+    def requires_parenthesis(self) -> bool:  # noqa: D102
+        return False
+
+    def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT:  # noqa: D102
+        return visitor.visit_add_time_expr(self)
+
+    @property
+    def description(self) -> str:  # noqa: D102
+        return "Add time interval"
+
+    def rewrite(  # noqa: D102
+        self,
+        column_replacements: Optional[SqlColumnReplacements] = None,
+        should_render_table_alias: Optional[bool] = None,
+    ) -> SqlExpressionNode:
+        return SqlAddTimeExpression.create(
+            arg=self.arg.rewrite(column_replacements, should_render_table_alias),
+            count_expr=self.count_expr.rewrite(column_replacements, should_render_table_alias),
+            granularity=self.granularity,
+        )
+
+    @property
+    def lineage(self) -> SqlExpressionTreeLineage:  # noqa: D102
+        return SqlExpressionTreeLineage.combine(
+            tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),)
+        )
+
+    def matches(self, other: SqlExpressionNode) -> bool:  # noqa: D102
+        if not isinstance(other, SqlAddTimeExpression):
+            return False
+        return self.granularity == other.granularity and self._parents_match(other)
+
+
 @dataclass(frozen=True, eq=False)
 class SqlCastToTimestampExpression(SqlExpressionNode):
     """Cast to the timestamp type like CAST('2020-01-01' AS TIMESTAMP)."""