From 998b50d873c67d4ab58cf49612c683b22410f1d3 Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Tue, 10 Dec 2024 20:16:13 +0000 Subject: [PATCH] Fix(bigquery): Pass catalog when checking for clustering key changes --- .circleci/continue_config.yml | 16 ++++---- sqlmesh/core/engine_adapter/mixins.py | 10 ++++- tests/core/engine_adapter/test_bigquery.py | 43 ++++++++++++++++++++++ 3 files changed, 59 insertions(+), 10 deletions(-) diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml index a23dc3dd5..3786c2d95 100644 --- a/.circleci/continue_config.yml +++ b/.circleci/continue_config.yml @@ -340,15 +340,15 @@ workflows: parameters: engine: - snowflake - - databricks - - redshift + #- databricks + #- redshift - bigquery - - clickhouse-cloud - - athena - filters: - branches: - only: - - main + #- clickhouse-cloud + #- athena + #filters: + # branches: + # only: + # - main - trigger_private_tests: requires: - style_and_slow_tests diff --git a/sqlmesh/core/engine_adapter/mixins.py b/sqlmesh/core/engine_adapter/mixins.py index 40668ac59..fb78f4c36 100644 --- a/sqlmesh/core/engine_adapter/mixins.py +++ b/sqlmesh/core/engine_adapter/mixins.py @@ -9,6 +9,7 @@ from sqlmesh.core.engine_adapter.base import EngineAdapter from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, SourceQuery from sqlmesh.core.node import IntervalUnit +from sqlmesh.core.dialect import schema_ from sqlmesh.utils.errors import SQLMeshError if t.TYPE_CHECKING: @@ -356,10 +357,15 @@ def get_alter_expressions( current_table = exp.to_table(current_table_name) target_table = exp.to_table(target_table_name) + current_table_schema = schema_(current_table.db, catalog=current_table.catalog) + target_table_schema = schema_(target_table.db, catalog=target_table.catalog) + current_table_info = seq_get( - self.get_data_objects(current_table.db, {current_table.name}), 0 + self.get_data_objects(current_table_schema, {current_table.name}), 0 + ) + target_table_info = seq_get( + self.get_data_objects(target_table_schema, {target_table.name}), 0 ) - target_table_info = seq_get(self.get_data_objects(target_table.db, {target_table.name}), 0) if current_table_info and target_table_info: if target_table_info.is_clustered: diff --git a/tests/core/engine_adapter/test_bigquery.py b/tests/core/engine_adapter/test_bigquery.py index 24b729b6c..e9460d15b 100644 --- a/tests/core/engine_adapter/test_bigquery.py +++ b/tests/core/engine_adapter/test_bigquery.py @@ -18,6 +18,11 @@ pytestmark = [pytest.mark.bigquery, pytest.mark.engine] +@pytest.fixture +def adapter(make_mocked_engine_adapter: t.Callable) -> BigQueryEngineAdapter: + return make_mocked_engine_adapter(BigQueryEngineAdapter) + + def test_insert_overwrite_by_time_partition_query( make_mocked_engine_adapter: t.Callable, mocker: MockerFixture ): @@ -893,3 +898,41 @@ def test_nested_fields_update(make_mocked_engine_adapter: t.Callable, mocker: Mo bigquery.SchemaField("details", "STRING", "REPEATED"), ] assert adapter._build_nested_fields(current_schema, new_nested_fields) == expected + + +def test_get_alter_expressions_includes_catalog( + adapter: BigQueryEngineAdapter, mocker: MockerFixture +): + adapter._default_catalog = "test_project" + + columns_mock = mocker.patch( + "sqlmesh.core.engine_adapter.bigquery.BigQueryEngineAdapter.columns" + ) + columns_mock.return_value = { + "a": exp.DataType.build("int"), + } + + get_data_objects_mock = mocker.patch( + "sqlmesh.core.engine_adapter.bigquery.BigQueryEngineAdapter.get_data_objects" + ) + get_data_objects_mock.return_value = [] + + adapter.get_alter_expressions("catalog1.foo.bar", "catalog2.bar.bing") + + assert get_data_objects_mock.call_count == 2 + + schema, tables = get_data_objects_mock.call_args_list[0][0] + assert isinstance(schema, exp.Table) + assert isinstance(tables, set) + assert schema.catalog == "catalog1" + assert schema.db == "foo" + assert schema.sql(dialect="bigquery") == "catalog1.foo" + assert tables == {"bar"} + + schema, tables = get_data_objects_mock.call_args_list[1][0] + assert isinstance(schema, exp.Table) + assert isinstance(tables, set) + assert schema.catalog == "catalog2" + assert schema.db == "bar" + assert schema.sql(dialect="bigquery") == "catalog2.bar" + assert tables == {"bing"}