diff --git a/docs/concepts/models/model_kinds.md b/docs/concepts/models/model_kinds.md index a332cc9ed..dc35ee536 100644 --- a/docs/concepts/models/model_kinds.md +++ b/docs/concepts/models/model_kinds.md @@ -334,8 +334,10 @@ MODEL ( name db.employees, kind INCREMENTAL_BY_UNIQUE_KEY ( unique_key name, - when_matched WHEN MATCHED AND source.value IS NULL THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary), - WHEN MATCHED THEN UPDATE SET target.title = COALESCE(source.title, target.title) + when_matched ( + WHEN MATCHED AND source.value IS NULL THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary) + WHEN MATCHED THEN UPDATE SET target.title = COALESCE(source.title, target.title) + ) ) ); ``` diff --git a/setup.py b/setup.py index 0f5c3695e..cb6db7f56 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ "rich[jupyter]", "ruamel.yaml", "setuptools; python_version>='3.12'", - "sqlglot[rs]~=25.34.1", + "sqlglot[rs]~=26.0.0", "tenacity", ], extras_require={ diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index 7f0625f6a..3efc81fa1 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -409,13 +409,12 @@ def _parse_props(self: Parser) -> t.Optional[exp.Expression]: return None name = key.name.lower() - if name == "when_matched": - value: t.Optional[t.Union[exp.Expression, t.List[exp.Expression]]] = ( - self._parse_when_matched() # type: ignore - ) - elif name == "time_data_type": + if name == "time_data_type": # TODO: if we make *_data_type a convention to parse things into exp.DataType, we could make this more generic value = self._parse_types(schema=True) + elif name == "when_matched": + # Parentheses around the WHEN clauses can be used to disambiguate them from other properties + value = self._parse_wrapped(self._parse_when_matched, optional=True) elif self._match(TokenType.L_PAREN): value = self.expression(exp.Tuple, expressions=self._parse_csv(self._parse_equality)) self._match_r_paren() @@ -605,15 
def _whens_sql(self: Generator, expression: exp.Whens) -> str:
    """Generate SQL for a group of `WHEN` clauses.

    Args:
        self: The generator doing the rendering.
        expression: The `Whens` node to render.

    Returns:
        The generator's own `whens_sql` output when the node's parent is a
        MERGE statement; otherwise the clauses joined by a single space (no
        indentation) and passed through `self.wrap`.
    """
    belongs_to_merge = isinstance(expression.parent, exp.Merge)
    if not belongs_to_merge:
        # `WHEN` clauses that live outside of a MERGE statement (e.g. in the
        # `MODEL` DDL) are wrapped with parentheses.
        clauses = self.expressions(expression, sep=" ", indent=False)
        return self.wrap(clauses)

    return self.whens_sql(expression)
- self.execute( - exp.Merge( - this=this, - using=using, - on=on, - expressions=match_expressions, - ) - ) + self.execute(exp.Merge(this=this, using=using, on=on, whens=whens)) def scd_type_2_by_time( self, @@ -1807,7 +1800,7 @@ def merge( source_table: QueryOrDF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]], unique_key: t.Sequence[exp.Expression], - when_matched: t.Optional[t.Union[exp.When, t.List[exp.When]]] = None, + when_matched: t.Optional[exp.Whens] = None, ) -> None: source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( source_table, columns_to_types, target_table=target_table @@ -1820,17 +1813,23 @@ def merge( ) ) if not when_matched: - when_matched = exp.When( - matched=True, - source=False, - then=exp.Update( - expressions=[ - exp.column(col, MERGE_TARGET_ALIAS).eq(exp.column(col, MERGE_SOURCE_ALIAS)) - for col in columns_to_types - ], + when_matched = exp.Whens() + when_matched.append( + "expressions", + exp.When( + matched=True, + source=False, + then=exp.Update( + expressions=[ + exp.column(col, MERGE_TARGET_ALIAS).eq( + exp.column(col, MERGE_SOURCE_ALIAS) + ) + for col in columns_to_types + ], + ), ), ) - when_matched = ensure_list(when_matched) + when_not_matched = exp.When( matched=False, source=False, @@ -1841,14 +1840,15 @@ def merge( ), ), ) - match_expressions = when_matched + [when_not_matched] + when_matched.append("expressions", when_not_matched) + for source_query in source_queries: with source_query as query: self._merge( target_table=target_table, query=query, on=on, - match_expressions=match_expressions, + whens=when_matched, ) def rename_table( diff --git a/sqlmesh/core/engine_adapter/mixins.py b/sqlmesh/core/engine_adapter/mixins.py index 40668ac59..d60f77056 100644 --- a/sqlmesh/core/engine_adapter/mixins.py +++ b/sqlmesh/core/engine_adapter/mixins.py @@ -29,7 +29,7 @@ def merge( source_table: QueryOrDF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]], unique_key: 
t.Sequence[exp.Expression], - when_matched: t.Optional[t.Union[exp.When, t.List[exp.When]]] = None, + when_matched: t.Optional[exp.Whens] = None, ) -> None: logical_merge( self, @@ -104,7 +104,9 @@ def _insert_overwrite_by_condition( target_table=table_name, query=query, on=exp.false(), - match_expressions=[when_not_matched_by_source, when_not_matched_by_target], + whens=exp.Whens( + expressions=[when_not_matched_by_source, when_not_matched_by_target] + ), ) @@ -400,7 +402,7 @@ def logical_merge( source_table: QueryOrDF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]], unique_key: t.Sequence[exp.Expression], - when_matched: t.Optional[t.Union[exp.When, t.List[exp.When]]] = None, + when_matched: t.Optional[exp.Whens] = None, ) -> None: """ Merge implementation for engine adapters that do not support merge natively. diff --git a/sqlmesh/core/engine_adapter/postgres.py b/sqlmesh/core/engine_adapter/postgres.py index f1677435a..4b0f9875f 100644 --- a/sqlmesh/core/engine_adapter/postgres.py +++ b/sqlmesh/core/engine_adapter/postgres.py @@ -106,7 +106,7 @@ def merge( source_table: QueryOrDF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]], unique_key: t.Sequence[exp.Expression], - when_matched: t.Optional[t.Union[exp.When, t.List[exp.When]]] = None, + when_matched: t.Optional[exp.Whens] = None, ) -> None: # Merge isn't supported until Postgres 15 merge_impl = ( diff --git a/sqlmesh/core/model/kind.py b/sqlmesh/core/model/kind.py index 3200c3d8f..36a55fab4 100644 --- a/sqlmesh/core/model/kind.py +++ b/sqlmesh/core/model/kind.py @@ -5,7 +5,6 @@ from pydantic import Field from sqlglot import exp -from sqlglot.helper import ensure_list from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlglot.optimizer.qualify_columns import quote_identifiers from sqlglot.optimizer.simplify import gen @@ -423,48 +422,38 @@ class IncrementalByUniqueKeyKind(_IncrementalBy): ModelKindName.INCREMENTAL_BY_UNIQUE_KEY ) unique_key: SQLGlotListOfFields - 
@field_validator("when_matched", mode="before")
@field_validator_v1_args
def _when_matched_validator(
    cls,
    v: t.Optional[
        t.Union[str, exp.Whens, exp.When, t.List[exp.When], t.List[str]]
    ],
    values: t.Dict[str, t.Any],
) -> t.Optional[exp.Whens]:
    """Normalize the `when_matched` property into a single `exp.Whens` node.

    Args:
        v: The raw value. May be `None`, a SQL string (optionally wrapped in
            one pair of parentheses), an `exp.Whens` node, or one of the
            shapes this validator accepted before the field became
            `exp.Whens`: a single `exp.When`, a list of `exp.When`, or a
            list of SQL strings.
        values: The already-validated sibling fields; passed to
            `get_dialect` to pick the dialect used for parsing strings.

    Returns:
        `None` when no value (or an empty list) was given. For string input,
        the string parsed into `exp.Whens`. For expression input, a
        transformed copy in which columns qualified with `source` / `target`
        are re-qualified with `MERGE_SOURCE_ALIAS` / `MERGE_TARGET_ALIAS`.
    """

    def replace_table_references(expression: exp.Expression) -> exp.Expression:
        # Imported here rather than at module level, as in the original code
        # (presumably to avoid a circular import -- TODO confirm).
        from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS

        if isinstance(expression, exp.Column):
            table_name = expression.table.lower()
            if table_name == "target":
                expression.set("table", exp.to_identifier(MERGE_TARGET_ALIAS))
            elif table_name == "source":
                expression.set("table", exp.to_identifier(MERGE_SOURCE_ALIAS))

        return expression

    if v is None:
        return v

    # Backward compatibility: a lone `exp.When` used to be accepted and is
    # handled like a one-element list.
    if isinstance(v, exp.When):
        v = [v]

    if isinstance(v, (list, tuple)):
        if not v:
            return None
        if all(isinstance(item, str) for item in v):
            # Consecutive WHEN clauses are separated by whitespace only, so
            # the individual strings can be parsed as one `Whens` string.
            v = " ".join(v)
        else:
            # Copy the clauses so the caller's nodes are not re-parented.
            v = exp.Whens(expressions=[item.copy() for item in v])

    if isinstance(v, str):
        # Whens wrap the WHEN clauses, but the parentheses aren't parsed by sqlglot
        v = v.strip()
        # Only unwrap a balanced pair; blindly dropping the last character of
        # a string that merely starts with "(" would silently truncate it.
        # An unbalanced string is left untouched (presumably the parser then
        # rejects it -- TODO confirm).
        if v.startswith("(") and v.endswith(")"):
            v = v[1:-1]

        return t.cast(exp.Whens, d.parse_one(v, into=exp.Whens, dialect=get_dialect(values)))

    return t.cast(exp.Whens, v.transform(replace_table_references))
100644 --- a/sqlmesh/core/model/meta.py +++ b/sqlmesh/core/model/meta.py @@ -430,7 +430,7 @@ def managed_columns(self) -> t.Dict[str, exp.DataType]: return getattr(self.kind, "managed_columns", {}) @property - def when_matched(self) -> t.Optional[t.List[exp.When]]: + def when_matched(self) -> t.Optional[exp.Whens]: if isinstance(self.kind, IncrementalByUniqueKeyKind): return self.kind.when_matched return None diff --git a/tests/core/engine_adapter/test_base.py b/tests/core/engine_adapter/test_base.py index 67921c4c9..a52c74391 100644 --- a/tests/core/engine_adapter/test_base.py +++ b/tests/core/engine_adapter/test_base.py @@ -1014,20 +1014,26 @@ def test_merge_when_matched(make_mocked_engine_adapter: t.Callable, assert_exp_e "val": exp.DataType.build("int"), }, unique_key=[exp.to_identifier("ID", quoted=True)], - when_matched=exp.When( - matched=True, - source=False, - then=exp.Update( - expressions=[ - exp.column("val", "__MERGE_TARGET__").eq(exp.column("val", "__MERGE_SOURCE__")), - exp.column("ts", "__MERGE_TARGET__").eq( - exp.Coalesce( - this=exp.column("ts", "__MERGE_SOURCE__"), - expressions=[exp.column("ts", "__MERGE_TARGET__")], - ) + when_matched=exp.Whens( + expressions=[ + exp.When( + matched=True, + source=False, + then=exp.Update( + expressions=[ + exp.column("val", "__MERGE_TARGET__").eq( + exp.column("val", "__MERGE_SOURCE__") + ), + exp.column("ts", "__MERGE_TARGET__").eq( + exp.Coalesce( + this=exp.column("ts", "__MERGE_SOURCE__"), + expressions=[exp.column("ts", "__MERGE_TARGET__")], + ) + ), + ], ), - ], - ), + ) + ] ), ) @@ -1061,42 +1067,44 @@ def test_merge_when_matched_multiple(make_mocked_engine_adapter: t.Callable, ass "val": exp.DataType.build("int"), }, unique_key=[exp.to_identifier("ID", quoted=True)], - when_matched=[ - exp.When( - matched=True, - condition=exp.column("ID", "__MERGE_SOURCE__").eq(exp.Literal.number(1)), - then=exp.Update( - expressions=[ - exp.column("val", "__MERGE_TARGET__").eq( - exp.column("val", 
"__MERGE_SOURCE__") - ), - exp.column("ts", "__MERGE_TARGET__").eq( - exp.Coalesce( - this=exp.column("ts", "__MERGE_SOURCE__"), - expressions=[exp.column("ts", "__MERGE_TARGET__")], - ) - ), - ], + when_matched=exp.Whens( + expressions=[ + exp.When( + matched=True, + condition=exp.column("ID", "__MERGE_SOURCE__").eq(exp.Literal.number(1)), + then=exp.Update( + expressions=[ + exp.column("val", "__MERGE_TARGET__").eq( + exp.column("val", "__MERGE_SOURCE__") + ), + exp.column("ts", "__MERGE_TARGET__").eq( + exp.Coalesce( + this=exp.column("ts", "__MERGE_SOURCE__"), + expressions=[exp.column("ts", "__MERGE_TARGET__")], + ) + ), + ], + ), ), - ), - exp.When( - matched=True, - source=False, - then=exp.Update( - expressions=[ - exp.column("val", "__MERGE_TARGET__").eq( - exp.column("val", "__MERGE_SOURCE__") - ), - exp.column("ts", "__MERGE_TARGET__").eq( - exp.Coalesce( - this=exp.column("ts", "__MERGE_SOURCE__"), - expressions=[exp.column("ts", "__MERGE_TARGET__")], - ) - ), - ], + exp.When( + matched=True, + source=False, + then=exp.Update( + expressions=[ + exp.column("val", "__MERGE_TARGET__").eq( + exp.column("val", "__MERGE_SOURCE__") + ), + exp.column("ts", "__MERGE_TARGET__").eq( + exp.Coalesce( + this=exp.column("ts", "__MERGE_SOURCE__"), + expressions=[exp.column("ts", "__MERGE_TARGET__")], + ) + ), + ], + ), ), - ), - ], + ] + ), ) assert_exp_eq( diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 00e1142c2..b5a109c23 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -3944,15 +3944,54 @@ def test_when_matched(): """ ) - expected_when_matched = "WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.salary = COALESCE(__MERGE_SOURCE__.salary, __MERGE_TARGET__.salary)" + expected_when_matched = "(WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.salary = COALESCE(__MERGE_SOURCE__.salary, __MERGE_TARGET__.salary))" model = load_sql_based_model(expressions, dialect="hive") - assert len(model.kind.when_matched) == 1 - assert 
model.kind.when_matched[0].sql() == expected_when_matched + assert model.kind.when_matched.sql() == expected_when_matched model = SqlModel.parse_raw(model.json()) - assert len(model.kind.when_matched) == 1 - assert model.kind.when_matched[0].sql() == expected_when_matched + assert model.kind.when_matched.sql() == expected_when_matched + + expressions = d.parse( + """ + MODEL ( + name @{macro_val}.test, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key purchase_order_id, + when_matched ( + WHEN MATCHED AND source._operation = 1 THEN DELETE + WHEN MATCHED AND source._operation <> 1 THEN UPDATE SET target.purchase_order_id = 1 + ) + ) + ); + + SELECT + purchase_order_id + FROM @{macro_val}.upstream + """ + ) + + model = SqlModel.parse_raw(load_sql_based_model(expressions).json()) + assert d.format_model_expressions(model.render_definition()) == ( + """MODEL ( + name @{macro_val}.test, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key ("purchase_order_id"), + when_matched ( + WHEN MATCHED AND __MERGE_SOURCE__._operation = 1 THEN DELETE + WHEN MATCHED AND __MERGE_SOURCE__._operation <> 1 THEN UPDATE SET __MERGE_TARGET__.purchase_order_id = 1 + ), + batch_concurrency 1, + forward_only FALSE, + disable_restatement FALSE, + on_destructive_change 'ERROR' + ) +); + +SELECT + purchase_order_id +FROM @{macro_val}.upstream""" + ) def test_when_matched_multiple(): @@ -3977,14 +4016,16 @@ def test_when_matched_multiple(): ] model = load_sql_based_model(expressions, dialect="hive", variables={"schema": "db"}) - assert len(model.kind.when_matched) == 2 - assert model.kind.when_matched[0].sql() == expected_when_matched[0] - assert model.kind.when_matched[1].sql() == expected_when_matched[1] + whens = model.kind.when_matched + assert len(whens.expressions) == 2 + assert whens.expressions[0].sql() == expected_when_matched[0] + assert whens.expressions[1].sql() == expected_when_matched[1] model = SqlModel.parse_raw(model.json()) - assert len(model.kind.when_matched) == 2 - assert 
model.kind.when_matched[0].sql() == expected_when_matched[0] - assert model.kind.when_matched[1].sql() == expected_when_matched[1] + whens = model.kind.when_matched + assert len(whens.expressions) == 2 + assert whens.expressions[0].sql() == expected_when_matched[0] + assert whens.expressions[1].sql() == expected_when_matched[1] def test_default_catalog_sql(assert_exp_eq): @@ -5549,7 +5590,7 @@ def test_model_kind_to_expression(): .sql() == """INCREMENTAL_BY_UNIQUE_KEY ( unique_key ("a"), -when_matched ARRAY(WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.b = COALESCE(__MERGE_SOURCE__.b, __MERGE_TARGET__.b)), +when_matched (WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.b = COALESCE(__MERGE_SOURCE__.b, __MERGE_TARGET__.b)), batch_concurrency 1, forward_only FALSE, disable_restatement FALSE, @@ -5577,7 +5618,7 @@ def test_model_kind_to_expression(): .sql() == """INCREMENTAL_BY_UNIQUE_KEY ( unique_key ("a"), -when_matched ARRAY(WHEN MATCHED AND __MERGE_SOURCE__.x = 1 THEN UPDATE SET __MERGE_TARGET__.b = COALESCE(__MERGE_SOURCE__.b, __MERGE_TARGET__.b), WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.b = COALESCE(__MERGE_SOURCE__.b, __MERGE_TARGET__.b)), +when_matched (WHEN MATCHED AND __MERGE_SOURCE__.x = 1 THEN UPDATE SET __MERGE_TARGET__.b = COALESCE(__MERGE_SOURCE__.b, __MERGE_TARGET__.b) WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.b = COALESCE(__MERGE_SOURCE__.b, __MERGE_TARGET__.b)), batch_concurrency 1, forward_only FALSE, disable_restatement FALSE, diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index 106ee7a0f..cd327578c 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -2018,25 +2018,27 @@ def test_create_incremental_by_unique_key_updated_at_exp(adapter_mock, make_snap "updated_at": exp.DataType.build("TIMESTAMP"), }, unique_key=[exp.to_column("id", quoted=True)], - when_matched=[ - exp.When( - matched=True, - source=False, - then=exp.Update( - expressions=[ - 
exp.column("name", MERGE_TARGET_ALIAS).eq( - exp.column("name", MERGE_SOURCE_ALIAS) - ), - exp.column("updated_at", MERGE_TARGET_ALIAS).eq( - exp.Coalesce( - this=exp.column("updated_at", MERGE_SOURCE_ALIAS), - expressions=[exp.column("updated_at", MERGE_TARGET_ALIAS)], - ) - ), - ], - ), - ) - ], + when_matched=exp.Whens( + expressions=[ + exp.When( + matched=True, + source=False, + then=exp.Update( + expressions=[ + exp.column("name", MERGE_TARGET_ALIAS).eq( + exp.column("name", MERGE_SOURCE_ALIAS) + ), + exp.column("updated_at", MERGE_TARGET_ALIAS).eq( + exp.Coalesce( + this=exp.column("updated_at", MERGE_SOURCE_ALIAS), + expressions=[exp.column("updated_at", MERGE_TARGET_ALIAS)], + ) + ), + ], + ), + ) + ] + ), ) @@ -2080,42 +2082,44 @@ def test_create_incremental_by_unique_key_multiple_updated_at_exp(adapter_mock, "updated_at": exp.DataType.build("TIMESTAMP"), }, unique_key=[exp.to_column("id", quoted=True)], - when_matched=[ - exp.When( - matched=True, - condition=exp.column("id", MERGE_SOURCE_ALIAS).eq(exp.Literal.number(1)), - then=exp.Update( - expressions=[ - exp.column("name", MERGE_TARGET_ALIAS).eq( - exp.column("name", MERGE_SOURCE_ALIAS) - ), - exp.column("updated_at", MERGE_TARGET_ALIAS).eq( - exp.Coalesce( - this=exp.column("updated_at", MERGE_SOURCE_ALIAS), - expressions=[exp.column("updated_at", MERGE_TARGET_ALIAS)], - ) - ), - ], + when_matched=exp.Whens( + expressions=[ + exp.When( + matched=True, + condition=exp.column("id", MERGE_SOURCE_ALIAS).eq(exp.Literal.number(1)), + then=exp.Update( + expressions=[ + exp.column("name", MERGE_TARGET_ALIAS).eq( + exp.column("name", MERGE_SOURCE_ALIAS) + ), + exp.column("updated_at", MERGE_TARGET_ALIAS).eq( + exp.Coalesce( + this=exp.column("updated_at", MERGE_SOURCE_ALIAS), + expressions=[exp.column("updated_at", MERGE_TARGET_ALIAS)], + ) + ), + ], + ), ), - ), - exp.When( - matched=True, - source=False, - then=exp.Update( - expressions=[ - exp.column("name", MERGE_TARGET_ALIAS).eq( - exp.column("name", 
MERGE_SOURCE_ALIAS) - ), - exp.column("updated_at", MERGE_TARGET_ALIAS).eq( - exp.Coalesce( - this=exp.column("updated_at", MERGE_SOURCE_ALIAS), - expressions=[exp.column("updated_at", MERGE_TARGET_ALIAS)], - ) - ), - ], + exp.When( + matched=True, + source=False, + then=exp.Update( + expressions=[ + exp.column("name", MERGE_TARGET_ALIAS).eq( + exp.column("name", MERGE_SOURCE_ALIAS) + ), + exp.column("updated_at", MERGE_TARGET_ALIAS).eq( + exp.Coalesce( + this=exp.column("updated_at", MERGE_SOURCE_ALIAS), + expressions=[exp.column("updated_at", MERGE_TARGET_ALIAS)], + ) + ), + ], + ), ), - ), - ], + ], + ), )