From 1db231a196b21bd8621ee69d921e6fff955c60bf Mon Sep 17 00:00:00 2001 From: tlento Date: Mon, 2 Oct 2023 16:56:04 -0700 Subject: [PATCH 1/3] Remove unused grain_to_date from SqlTimeDeltaExpression The original implementation of SqlTimeDeltaExpression had a custom override for grain_to_date intended to support cumulative metrics. This parameter would change the expression from a time delta to a date_trunc. At some point we cleaned up the callsites to invoke date_trunc directly instead of offloading this work to an overloaded class. This commit simply removes the confusing parameter in order to streamline the expression rendering for time delta operations. --- metricflow/sql/render/big_query.py | 10 ---------- metricflow/sql/render/duckdb_renderer.py | 5 ----- metricflow/sql/render/expr_renderer.py | 5 ----- metricflow/sql/render/postgres.py | 5 ----- metricflow/sql/sql_exprs.py | 14 +------------- 5 files changed, 1 insertion(+), 38 deletions(-) diff --git a/metricflow/sql/render/big_query.py b/metricflow/sql/render/big_query.py index 4dcb42aaa3..b059eb6f1c 100644 --- a/metricflow/sql/render/big_query.py +++ b/metricflow/sql/render/big_query.py @@ -145,16 +145,6 @@ def render_date_part(self, date_part: DatePart) -> str: def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> SqlExpressionRenderResult: """Render time delta for BigQuery, which requires ISO prefixing for the WEEK granularity value.""" column = node.arg.accept(self) - if node.grain_to_date: - granularity = node.granularity - if granularity == TimeGranularity.WEEK: - granularity_value = "ISO" + granularity.value.upper() - else: - granularity_value = granularity.value - return SqlExpressionRenderResult( - sql=f"DATE_TRUNC({column.sql}, {granularity_value})", - bind_parameters=column.bind_parameters, - ) return SqlExpressionRenderResult( sql=f"DATE_SUB(CAST({column.sql} AS {self.timestamp_data_type}), INTERVAL {node.count} {node.granularity.value})", diff --git a/metricflow/sql/render/duckdb_renderer.py b/metricflow/sql/render/duckdb_renderer.py index 7158d84de6..eb0815429e 100644 --- a/metricflow/sql/render/duckdb_renderer.py +++ b/metricflow/sql/render/duckdb_renderer.py @@ -37,11 +37,6 @@ def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctio def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> SqlExpressionRenderResult: """Render time delta expression for DuckDB, which requires slightly different syntax from other engines.""" arg_rendered = node.arg.accept(self) - if node.grain_to_date: - return SqlExpressionRenderResult( - sql=f"DATE_TRUNC('{node.granularity.value}', {arg_rendered.sql}::timestamp)", - bind_parameters=arg_rendered.bind_parameters, - ) count = node.count granularity = node.granularity diff --git a/metricflow/sql/render/expr_renderer.py b/metricflow/sql/render/expr_renderer.py index 9a5e58a465..d0328f6e92 100644 --- a/metricflow/sql/render/expr_renderer.py +++ b/metricflow/sql/render/expr_renderer.py @@ -283,11 +283,6 @@ def render_date_part(self, date_part: DatePart) -> str: def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> SqlExpressionRenderResult: # noqa: D arg_rendered = node.arg.accept(self) - if node.grain_to_date: - return SqlExpressionRenderResult( - sql=f"DATE_TRUNC('{node.granularity.value}', {arg_rendered.sql}::timestamp)", - bind_parameters=arg_rendered.bind_parameters, - ) count = node.count granularity = node.granularity diff --git a/metricflow/sql/render/postgres.py b/metricflow/sql/render/postgres.py index f77708d7fa..511dd2a214 100644 --- a/metricflow/sql/render/postgres.py +++ b/metricflow/sql/render/postgres.py @@ -40,11 +40,6 @@ def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctio def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> SqlExpressionRenderResult: """Render time delta operations for PostgreSQL, which needs custom support for quarterly granularity.""" arg_rendered = node.arg.accept(self) - if node.grain_to_date: - return SqlExpressionRenderResult( - sql=f"DATE_TRUNC('{node.granularity.value}', {arg_rendered.sql}::timestamp)", - bind_parameters=arg_rendered.bind_parameters, - ) count = node.count granularity = node.granularity diff --git a/metricflow/sql/sql_exprs.py b/metricflow/sql/sql_exprs.py index 8ab96a29e9..a2ac98d825 100644 --- a/metricflow/sql/sql_exprs.py +++ b/metricflow/sql/sql_exprs.py @@ -1251,13 +1251,11 @@ def __init__( # noqa: D arg: SqlExpressionNode, count: int, granularity: TimeGranularity, - grain_to_date: Optional[TimeGranularity] = None, ) -> None: super().__init__(node_id=self.create_unique_id(), parent_nodes=[arg]) self._count = count self._time_granularity = granularity self._arg = arg - self._grain_to_date = grain_to_date @classmethod def id_prefix(cls) -> str: # noqa: D @@ -1278,10 +1276,6 @@ def description(self) -> str: # noqa: D def arg(self) -> SqlExpressionNode: # noqa: D return self._arg - @property - def grain_to_date(self) -> Optional[TimeGranularity]: # noqa: D - return self._grain_to_date - @property def count(self) -> int: # noqa: D return self._count @@ -1299,7 +1293,6 @@ def rewrite( # noqa: D arg=self.arg.rewrite(column_replacements, should_render_table_alias), count=self.count, granularity=self.granularity, - grain_to_date=self.grain_to_date, ) @property @@ -1311,12 +1304,7 @@ def lineage(self) -> SqlExpressionTreeLineage: # noqa: D def matches(self, other: SqlExpressionNode) -> bool: # noqa: D if not isinstance(other, SqlTimeDeltaExpression): return False - return ( - self.count == other.count - and self.granularity == other.granularity - and self.grain_to_date == other.grain_to_date - and self._parents_match(other) - ) + return self.count == other.count and self.granularity == other.granularity and self._parents_match(other) class SqlCastToTimestampExpression(SqlExpressionNode): From 2fdc42a464bf1e0b3bb94442f766191df05a9fb1 Mon Sep 17 00:00:00 2001 From: tlento Date: Tue, 3 Oct 2023 11:14:35 -0700 Subject: [PATCH 2/3] Rename SqlTimeDeltaExpression to SqlSubtractTimeIntervalExpression The original name for this class was misleading, as a time delta is effectively an interval computed as a difference. This suggested either that the rendering should be a date_diff to produce an interval between two timestamps, or that it should be an interval expression for use in some other operation. In reality, this class provides the functionality of a date_subtract, where we subtract an interval value from a given input timestamp. --- metricflow/plan_conversion/sql_join_builder.py | 6 +++--- metricflow/sql/render/big_query.py | 4 ++-- metricflow/sql/render/duckdb_renderer.py | 4 ++-- metricflow/sql/render/expr_renderer.py | 4 ++-- metricflow/sql/render/postgres.py | 4 ++-- metricflow/sql/sql_exprs.py | 16 +++++++++++----- .../test/integration/test_configured_cases.py | 4 ++-- 7 files changed, 24 insertions(+), 18 deletions(-) diff --git a/metricflow/plan_conversion/sql_join_builder.py b/metricflow/plan_conversion/sql_join_builder.py index 92ca575721..20b98d6ab5 100644 --- a/metricflow/plan_conversion/sql_join_builder.py +++ b/metricflow/plan_conversion/sql_join_builder.py @@ -15,7 +15,7 @@ SqlIsNullExpression, SqlLogicalExpression, SqlLogicalOperator, - SqlTimeDeltaExpression, + SqlSubtractTimeIntervalExpression, ) from metricflow.sql.sql_plan import SqlExpressionNode, SqlJoinDescription, SqlJoinType, SqlSelectStatementNode @@ -441,7 +441,7 @@ def make_cumulative_metric_time_range_join_description( start_of_range_comparison_expr = SqlComparisonExpression( left_expr=metric_time_column_expr, comparison=SqlComparison.GREATER_THAN, - right_expr=SqlTimeDeltaExpression( + right_expr=SqlSubtractTimeIntervalExpression( arg=time_spine_column_expr, count=node.window.count, granularity=node.window.granularity, @@ -481,7 +481,7 @@ def make_join_to_time_spine_join_description( col_ref=SqlColumnReference(table_alias=time_spine_alias, column_name=metric_time_dimension_column_name) ) if node.offset_window: - left_expr = SqlTimeDeltaExpression( + left_expr = SqlSubtractTimeIntervalExpression( arg=left_expr, count=node.offset_window.count, granularity=node.offset_window.granularity ) elif node.offset_to_grain: diff --git a/metricflow/sql/render/big_query.py b/metricflow/sql/render/big_query.py index b059eb6f1c..aa360e154a 100644 --- a/metricflow/sql/render/big_query.py +++ b/metricflow/sql/render/big_query.py @@ -21,7 +21,7 @@ SqlGenerateUuidExpression, SqlPercentileExpression, SqlPercentileFunctionType, - SqlTimeDeltaExpression, + SqlSubtractTimeIntervalExpression, ) from metricflow.sql.sql_plan import SqlSelectColumn from metricflow.time.date_part import DatePart @@ -142,7 +142,7 @@ def render_date_part(self, date_part: DatePart) -> str: return super().render_date_part(date_part) @override - def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> SqlExpressionRenderResult: + def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult: """Render time delta for BigQuery, which requires ISO prefixing for the WEEK granularity value.""" column = node.arg.accept(self) diff --git a/metricflow/sql/render/duckdb_renderer.py b/metricflow/sql/render/duckdb_renderer.py index eb0815429e..451fb05398 100644 --- a/metricflow/sql/render/duckdb_renderer.py +++ b/metricflow/sql/render/duckdb_renderer.py @@ -17,7 +17,7 @@ SqlGenerateUuidExpression, SqlPercentileExpression, SqlPercentileFunctionType, - SqlTimeDeltaExpression, + SqlSubtractTimeIntervalExpression, ) @@ -34,7 +34,7 @@ def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctio } @override - def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> SqlExpressionRenderResult: + def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult: """Render time delta expression for DuckDB, which requires slightly different syntax from other engines.""" arg_rendered = node.arg.accept(self) diff --git a/metricflow/sql/render/expr_renderer.py b/metricflow/sql/render/expr_renderer.py index d0328f6e92..84df8d2395 100644 --- a/metricflow/sql/render/expr_renderer.py +++ b/metricflow/sql/render/expr_renderer.py @@ -33,7 +33,7 @@ SqlRatioComputationExpression, SqlStringExpression, SqlStringLiteralExpression, - SqlTimeDeltaExpression, + SqlSubtractTimeIntervalExpression, SqlWindowFunctionExpression, ) from metricflow.sql.sql_plan import SqlSelectColumn @@ -281,7 +281,7 @@ def render_date_part(self, date_part: DatePart) -> str: """Render DATE PART for an EXTRACT expression.""" return date_part.value - def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> SqlExpressionRenderResult: # noqa: D + def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult: # noqa: D arg_rendered = node.arg.accept(self) count = node.count diff --git a/metricflow/sql/render/postgres.py b/metricflow/sql/render/postgres.py index 511dd2a214..9ffc08dbf1 100644 --- a/metricflow/sql/render/postgres.py +++ b/metricflow/sql/render/postgres.py @@ -18,7 +18,7 @@ SqlGenerateUuidExpression, SqlPercentileExpression, SqlPercentileFunctionType, - SqlTimeDeltaExpression, + SqlSubtractTimeIntervalExpression, ) @@ -37,7 +37,7 @@ def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctio return {SqlPercentileFunctionType.CONTINUOUS, SqlPercentileFunctionType.DISCRETE} @override - def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> SqlExpressionRenderResult: + def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult: """Render time delta operations for PostgreSQL, which needs custom support for quarterly granularity.""" arg_rendered = node.arg.accept(self) diff --git a/metricflow/sql/sql_exprs.py b/metricflow/sql/sql_exprs.py index a2ac98d825..bb6c13b851 100644 --- a/metricflow/sql/sql_exprs.py +++ b/metricflow/sql/sql_exprs.py @@ -225,7 +225,7 @@ def visit_extract_expr(self, node: SqlExtractExpression) -> VisitorOutputT: # n pass @abstractmethod - def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> VisitorOutputT: # noqa: D + def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> VisitorOutputT: # noqa: D pass @abstractmethod @@ -1243,8 +1243,14 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D return self._parents_match(other) -class SqlTimeDeltaExpression(SqlExpressionNode): - """create time delta between eg `DATE_SUB(ds, 2, month)`.""" +class SqlSubtractTimeIntervalExpression(SqlExpressionNode): + """Represents an interval subtraction from a given timestamp. + + This node contains the information required to produce a SQL statement which subtracts an interval with the given + count and granularity (which together define the interval duration) from the input timestamp expression. The return + value from the SQL rendering for this expression should be a timestamp expression offset from the initial input + value. + """ def __init__( # noqa: D self, @@ -1289,7 +1295,7 @@ def rewrite( # noqa: D column_replacements: Optional[SqlColumnReplacements] = None, should_render_table_alias: Optional[bool] = None, ) -> SqlExpressionNode: - return SqlTimeDeltaExpression( + return SqlSubtractTimeIntervalExpression( arg=self.arg.rewrite(column_replacements, should_render_table_alias), count=self.count, granularity=self.granularity, @@ -1302,7 +1308,7 @@ def lineage(self) -> SqlExpressionTreeLineage: # noqa: D ) def matches(self, other: SqlExpressionNode) -> bool: # noqa: D - if not isinstance(other, SqlTimeDeltaExpression): + if not isinstance(other, SqlSubtractTimeIntervalExpression): return False return self.count == other.count and self.granularity == other.granularity and self._parents_match(other) diff --git a/metricflow/test/integration/test_configured_cases.py b/metricflow/test/integration/test_configured_cases.py index eec106d555..95579a11d1 100644 --- a/metricflow/test/integration/test_configured_cases.py +++ b/metricflow/test/integration/test_configured_cases.py @@ -31,7 +31,7 @@ SqlPercentileExpressionArgument, SqlPercentileFunctionType, SqlStringExpression, - SqlTimeDeltaExpression, + SqlSubtractTimeIntervalExpression, ) from metricflow.test.compare_df import assert_dataframes_equal from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState @@ -85,7 +85,7 @@ def render_date_sub( granularity: TimeGranularity, ) -> str: """Renders a date subtract expression.""" - expr = SqlTimeDeltaExpression( + expr = SqlSubtractTimeIntervalExpression( arg=SqlColumnReferenceExpression(SqlColumnReference(table_alias, column_alias)), count=count, granularity=granularity, From c7b37c419465337435ac2334cc9986fd22b71a03 Mon Sep 17 00:00:00 2001 From: tlento Date: Tue, 3 Oct 2023 11:22:23 -0700 Subject: [PATCH 3/3] Fix copy/paste errors on sql expr IDs We had a couple of copy/paste issues with sql expression nodes identifying themselves as IS_NULL expressions. This tidies that up. --- metricflow/dag/id_generation.py | 1 + metricflow/sql/sql_exprs.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/metricflow/dag/id_generation.py b/metricflow/dag/id_generation.py index e08fca06de..1fd5e03581 100644 --- a/metricflow/dag/id_generation.py +++ b/metricflow/dag/id_generation.py @@ -35,6 +35,7 @@ SQL_EXPR_IS_NULL_PREFIX = "isn" SQL_EXPR_CAST_TO_TIMESTAMP_PREFIX = "ctt" SQL_EXPR_DATE_TRUNC = "dt" +SQL_EXPR_SUBTRACT_TIME_INTERVAL_PREFIX = "sti" SQL_EXPR_EXTRACT = "ex" SQL_EXPR_RATIO_COMPUTATION = "rc" SQL_EXPR_BETWEEN_PREFIX = "betw" diff --git a/metricflow/sql/sql_exprs.py b/metricflow/sql/sql_exprs.py index bb6c13b851..2206691e7c 100644 --- a/metricflow/sql/sql_exprs.py +++ b/metricflow/sql/sql_exprs.py @@ -16,6 +16,7 @@ from metricflow.dag.id_generation import ( SQL_EXPR_BETWEEN_PREFIX, + SQL_EXPR_CAST_TO_TIMESTAMP_PREFIX, SQL_EXPR_COLUMN_REFERENCE_ID_PREFIX, SQL_EXPR_COMPARISON_ID_PREFIX, SQL_EXPR_DATE_TRUNC, @@ -29,6 +30,7 @@ SQL_EXPR_RATIO_COMPUTATION, SQL_EXPR_STRING_ID_PREFIX, SQL_EXPR_STRING_LITERAL_PREFIX, + SQL_EXPR_SUBTRACT_TIME_INTERVAL_PREFIX, SQL_EXPR_WINDOW_FUNCTION_ID_PREFIX, ) from metricflow.dag.mf_dag import DagNode, DisplayedProperty, NodeId @@ -1265,7 +1267,7 @@ def __init__( # noqa: D @classmethod def id_prefix(cls) -> str: # noqa: D - return SQL_EXPR_IS_NULL_PREFIX + return SQL_EXPR_SUBTRACT_TIME_INTERVAL_PREFIX @property def requires_parenthesis(self) -> bool: # noqa: D @@ -1321,7 +1323,7 @@ def __init__(self, arg: SqlExpressionNode) -> None: # noqa: D @classmethod def id_prefix(cls) -> str: # noqa: D - return SQL_EXPR_IS_NULL_PREFIX + return SQL_EXPR_CAST_TO_TIMESTAMP_PREFIX @property def requires_parenthesis(self) -> bool: # noqa: D