From 5066e76d7c6aa248b7a4f9d1863ecd3d46ac9bce Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Tue, 3 Dec 2024 09:40:50 -0800 Subject: [PATCH] Ensure column name is backticked during alter (#861) --- CHANGELOG.md | 1 + dbt/adapters/databricks/column.py | 16 +++++++- dbt/adapters/databricks/utils.py | 4 ++ .../databricks/macros/adapters/columns.sql | 17 ++++++++ .../macros/adapters/persist_docs.sql | 23 ++--------- .../macros/relations/constraints.sql | 6 +-- pyproject.toml | 1 + .../relations/test_constraint_macros.py | 4 ++ tests/unit/test_column.py | 40 +++++++++++++++++++ tests/unit/test_utils.py | 5 ++- 10 files changed, 92 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dfce1d4e..f8e07129 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ - Replace array indexing with 'get' in split_part so as not to raise exception when indexing beyond bounds ([839](https://github.com/databricks/dbt-databricks/pull/839)) - Set queue enabled for Python notebook jobs ([856](https://github.com/databricks/dbt-databricks/pull/856)) +- Ensure columns that are added get backticked ([859](https://github.com/databricks/dbt-databricks/pull/859)) ### Under the Hood diff --git a/dbt/adapters/databricks/column.py b/dbt/adapters/databricks/column.py index 4d08ad4d..df2cdb2d 100644 --- a/dbt/adapters/databricks/column.py +++ b/dbt/adapters/databricks/column.py @@ -1,6 +1,7 @@ from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import Any, ClassVar, Optional +from dbt.adapters.databricks.utils import quote from dbt.adapters.spark.column import SparkColumn @@ -28,3 +29,16 @@ def data_type(self) -> str: def __repr__(self) -> str: return "<DatabricksColumn {} ({})>".format(self.name, self.data_type) + + @staticmethod + def get_name(column: dict[str, Any]) -> str: + name = column["name"] + return quote(name) if column.get("quote", False) else name + + @staticmethod + def 
format_remove_column_list(columns: list["DatabricksColumn"]) -> str: + return ", ".join([quote(c.name) for c in columns]) + + @staticmethod + def format_add_column_list(columns: list["DatabricksColumn"]) -> str: + return ", ".join([f"{quote(c.name)} {c.data_type}" for c in columns]) diff --git a/dbt/adapters/databricks/utils.py b/dbt/adapters/databricks/utils.py index 616e5b36..3dfd4096 100644 --- a/dbt/adapters/databricks/utils.py +++ b/dbt/adapters/databricks/utils.py @@ -73,3 +73,7 @@ def handle_missing_objects(exec: Callable[[], T], default: T) -> T: if check_not_found_error(errmsg): return default raise e + + +def quote(name: str) -> str: + return f"`{name}`" diff --git a/dbt/include/databricks/macros/adapters/columns.sql b/dbt/include/databricks/macros/adapters/columns.sql index e1fc1d11..7fe40e6f 100644 --- a/dbt/include/databricks/macros/adapters/columns.sql +++ b/dbt/include/databricks/macros/adapters/columns.sql @@ -25,3 +25,20 @@ {% do return(load_result('get_columns_comments_via_information_schema').table) %} {% endmacro %} + +{% macro databricks__alter_relation_add_remove_columns(relation, add_columns, remove_columns) %} + {% if remove_columns %} + {% if not relation.is_delta %} + {{ exceptions.raise_compiler_error('Delta format required for dropping columns from tables') }} + {% endif %} + {%- call statement('alter_relation_remove_columns') -%} + ALTER TABLE {{ relation }} DROP COLUMNS ({{ api.Column.format_remove_column_list(remove_columns) }}) + {%- endcall -%} + {% endif %} + + {% if add_columns %} + {%- call statement('alter_relation_add_columns') -%} + ALTER TABLE {{ relation }} ADD COLUMNS ({{ api.Column.format_add_column_list(add_columns) }}) + {%- endcall -%} + {% endif %} +{% endmacro %} \ No newline at end of file diff --git a/dbt/include/databricks/macros/adapters/persist_docs.sql b/dbt/include/databricks/macros/adapters/persist_docs.sql index 8e959a9f..873039e8 100644 --- a/dbt/include/databricks/macros/adapters/persist_docs.sql +++ 
b/dbt/include/databricks/macros/adapters/persist_docs.sql @@ -1,12 +1,10 @@ {% macro databricks__alter_column_comment(relation, column_dict) %} {% if config.get('file_format', default='delta') in ['delta', 'hudi'] %} - {% for column_name in column_dict %} - {% set comment = column_dict[column_name]['description'] %} + {% for column in column_dict.values() %} + {% set comment = column['description'] %} {% set escaped_comment = comment | replace('\'', '\\\'') %} {% set comment_query %} - alter table {{ relation }} change column - {{ adapter.quote(column_name) if column_dict[column_name]['quote'] else column_name }} - comment '{{ escaped_comment }}'; + alter table {{ relation }} change column {{ api.Column.get_name(column) }} comment '{{ escaped_comment }}'; {% endset %} {% do run_query(comment_query) %} {% endfor %} @@ -30,18 +28,3 @@ {% do alter_column_comment(relation, columns_to_persist_docs) %} {% endif %} {% endmacro %} - -{% macro get_column_comment_sql(column_name, column_dict) -%} - {% if column_name in column_dict and column_dict[column_name]["description"] -%} - {% set escaped_description = column_dict[column_name]["description"] | replace("'", "\\'") %} - {% set column_comment_clause = "comment '" ~ escaped_description ~ "'" %} - {%- endif -%} - {{ adapter.quote(column_name) }} {{ column_comment_clause }} -{% endmacro %} - -{% macro get_persist_docs_column_list(model_columns, query_columns) %} - {% for column_name in query_columns %} - {{ get_column_comment_sql(column_name, model_columns) }} - {{- ", " if not loop.last else "" }} - {% endfor %} -{% endmacro %} diff --git a/dbt/include/databricks/macros/relations/constraints.sql b/dbt/include/databricks/macros/relations/constraints.sql index bb77145f..68f3a44f 100644 --- a/dbt/include/databricks/macros/relations/constraints.sql +++ b/dbt/include/databricks/macros/relations/constraints.sql @@ -133,7 +133,7 @@ {% for column_name in column_names %} {% set column = model.get('columns', {}).get(column_name) %} 
{% if column %} - {% set quoted_name = adapter.quote(column['name']) if column['quote'] else column['name'] %} + {% set quoted_name = api.Column.get_name(column) %} {% set stmt = "alter table " ~ relation ~ " change column " ~ quoted_name ~ " set not null " ~ (constraint.expression or "") ~ ";" %} {% do statements.append(stmt) %} {% else %} @@ -154,7 +154,7 @@ {% if not column %} {{ exceptions.warn('Invalid primary key column: ' ~ column_name) }} {% else %} - {% set quoted_name = adapter.quote(column['name']) if column['quote'] else column['name'] %} + {% set quoted_name = api.Column.get_name(column) %} {% do quoted_names.append(quoted_name) %} {% endif %} {% endfor %} @@ -203,7 +203,7 @@ {% if not column %} {{ exceptions.warn('Invalid foreign key column: ' ~ column_name) }} {% else %} - {% set quoted_name = adapter.quote(column['name']) if column['quote'] else column['name'] %} + {% set quoted_name = api.Column.get_name(column) %} {% do quoted_names.append(quoted_name) %} {% endif %} {% endfor %} diff --git a/pyproject.toml b/pyproject.toml index f1f680ea..cce1a700 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,7 @@ dependencies = [ "freezegun", "mypy", "pre-commit", + "ruff", "types-requests", "debugpy", ] diff --git a/tests/unit/macros/relations/test_constraint_macros.py b/tests/unit/macros/relations/test_constraint_macros.py index 351ca8cb..feac1797 100644 --- a/tests/unit/macros/relations/test_constraint_macros.py +++ b/tests/unit/macros/relations/test_constraint_macros.py @@ -1,5 +1,8 @@ +from unittest.mock import Mock + import pytest +from dbt.adapters.databricks.column import DatabricksColumn from tests.unit.macros.base import MacroTestBase @@ -16,6 +19,7 @@ def macro_folders_to_load(self) -> list: def modify_context(self, default_context) -> None: # Mock local_md5 default_context["local_md5"] = lambda s: f"hash({s})" + default_context["api"] = Mock(Column=DatabricksColumn) def render_constraints(self, template, *args): return 
self.run_macro(template, "databricks_constraints_to_dbt", *args) diff --git a/tests/unit/test_column.py b/tests/unit/test_column.py index 47a7fe1e..f0aa6562 100644 --- a/tests/unit/test_column.py +++ b/tests/unit/test_column.py @@ -1,3 +1,5 @@ +import pytest + from dbt.adapters.databricks.column import DatabricksColumn @@ -24,3 +26,41 @@ def test_convert_table_stats_with_bytes_and_rows(self): "stats:rows:label": "rows", "stats:rows:value": 12345678, } + + +class TestColumnStatics: + @pytest.mark.parametrize( + "column, expected", + [ + ({"name": "foo", "quote": True}, "`foo`"), + ({"name": "foo", "quote": False}, "foo"), + ({"name": "foo"}, "foo"), + ], + ) + def test_get_name(self, column, expected): + assert DatabricksColumn.get_name(column) == expected + + @pytest.mark.parametrize( + "columns, expected", + [ + ([], ""), + ([DatabricksColumn("foo", "string")], "`foo`"), + ([DatabricksColumn("foo", "string"), DatabricksColumn("bar", "int")], "`foo`, `bar`"), + ], + ) + def test_format_remove_column_list(self, columns, expected): + assert DatabricksColumn.format_remove_column_list(columns) == expected + + @pytest.mark.parametrize( + "columns, expected", + [ + ([], ""), + ([DatabricksColumn("foo", "string")], "`foo` string"), + ( + [DatabricksColumn("foo", "string"), DatabricksColumn("bar", "int")], + "`foo` string, `bar` int", + ), + ], + ) + def test_format_add_column_list(self, columns, expected): + assert DatabricksColumn.format_add_column_list(columns) == expected diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 81e77211..3e692a00 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,4 +1,4 @@ -from dbt.adapters.databricks.utils import redact_credentials, remove_ansi +from dbt.adapters.databricks.utils import quote, redact_credentials, remove_ansi class TestDatabricksUtils: @@ -64,3 +64,6 @@ def test_remove_ansi(self): 72 # how to execute python model in notebook """ assert remove_ansi(test_string) == expected_string + 
+ def test_quote(self): + assert quote("table") == "`table`"