diff --git a/sqlmesh/migrations/v0064_join_when_matched_strings.py b/sqlmesh/migrations/v0064_join_when_matched_strings.py new file mode 100644 index 000000000..324f6837a --- /dev/null +++ b/sqlmesh/migrations/v0064_join_when_matched_strings.py @@ -0,0 +1,86 @@ +"""Join list of `WHEN [NOT] MATCHED` strings into a single string.""" + +import json + +import pandas as pd +from sqlglot import exp + +from sqlmesh.utils.migration import index_text_type, blob_text_type + + +def migrate(state_sync, **kwargs): # type: ignore + engine_adapter = state_sync.engine_adapter + schema = state_sync.schema + snapshots_table = "_snapshots" + index_type = index_text_type(engine_adapter.dialect) + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + new_snapshots = [] + + for ( + name, + identifier, + version, + snapshot, + kind_name, + updated_ts, + unpaused_ts, + ttl_ms, + unrestorable, + ) in engine_adapter.fetchall( + exp.select( + "name", + "identifier", + "version", + "snapshot", + "kind_name", + "updated_ts", + "unpaused_ts", + "ttl_ms", + "unrestorable", + ).from_(snapshots_table), + quote_identifiers=True, + ): + parsed_snapshot = json.loads(snapshot) + node = parsed_snapshot["node"] + parsed_snapshot = json.loads(snapshot) + + if "kind" in node: + kind = node["kind"] + if isinstance(when_matched := kind.get("when_matched"), list): + kind["when_matched"] = " ".join(when_matched) + + new_snapshots.append( + { + "name": name, + "identifier": identifier, + "version": version, + "snapshot": json.dumps(parsed_snapshot), + "kind_name": kind_name, + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "ttl_ms": ttl_ms, + "unrestorable": unrestorable, + } + ) + + if new_snapshots: + engine_adapter.delete_from(snapshots_table, "TRUE") + blob_type = blob_text_type(engine_adapter.dialect) + + engine_adapter.insert_append( + snapshots_table, + pd.DataFrame(new_snapshots), + columns_to_types={ + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "snapshot": exp.DataType.build(blob_type), + "kind_name": exp.DataType.build(index_type), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + }, + )