Ensure column name is backticked during alter #861

Merged: 3 commits, Dec 3, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -23,6 +23,7 @@

- Replace array indexing with 'get' in split_part so as not to raise exception when indexing beyond bounds ([839](https://github.com/databricks/dbt-databricks/pull/839))
- Set queue enabled for Python notebook jobs ([856](https://github.com/databricks/dbt-databricks/pull/856))
- Ensure columns that are added get backticked ([859](https://github.com/databricks/dbt-databricks/pull/859))

### Under the Hood

16 changes: 15 additions & 1 deletion dbt/adapters/databricks/column.py
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import Any, ClassVar, Optional

from dbt.adapters.databricks.utils import quote
from dbt.adapters.spark.column import SparkColumn


@@ -28,3 +29,16 @@ def data_type(self) -> str:

def __repr__(self) -> str:
return "<DatabricksColumn {} ({})>".format(self.name, self.data_type)

@staticmethod
def get_name(column: dict[str, Any]) -> str:
name = column["name"]
return quote(name) if column.get("quote", False) else name

@staticmethod
def format_remove_column_list(columns: list["DatabricksColumn"]) -> str:
return ", ".join([quote(c.name) for c in columns])

@staticmethod
def format_add_column_list(columns: list["DatabricksColumn"]) -> str:
return ", ".join([f"{quote(c.name)} {c.data_type}" for c in columns])
4 changes: 4 additions & 0 deletions dbt/adapters/databricks/utils.py
@@ -73,3 +73,7 @@ def handle_missing_objects(exec: Callable[[], T], default: T) -> T:
if check_not_found_error(errmsg):
return default
raise e


def quote(name: str) -> str:
return f"`{name}`"
17 changes: 17 additions & 0 deletions dbt/include/databricks/macros/adapters/columns.sql
@@ -25,3 +25,20 @@

{% do return(load_result('get_columns_comments_via_information_schema').table) %}
{% endmacro %}

{% macro databricks__alter_relation_add_remove_columns(relation, add_columns, remove_columns) %}
{% if remove_columns %}
{% if not relation.is_delta %}
{{ exceptions.raise_compiler_error('Delta format required for dropping columns from tables') }}
{% endif %}
{%- call statement('alter_relation_remove_columns') -%}
ALTER TABLE {{ relation }} DROP COLUMNS ({{ api.Column.format_remove_column_list(remove_columns) }})
{%- endcall -%}
{% endif %}

{% if add_columns %}
{%- call statement('alter_relation_add_columns') -%}
ALTER TABLE {{ relation }} ADD COLUMNS ({{ api.Column.format_add_column_list(add_columns) }})
{%- endcall -%}
{% endif %}
{% endmacro %}
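A sketch of the SQL this macro renders, built with the same helpers in Python; the relation and column names here are hypothetical stand-ins for what dbt supplies at runtime.

```python
from dbt.adapters.databricks.column import DatabricksColumn

relation = "`catalog`.`schema`.`my_table`"  # illustrative; dbt passes the real relation
add = [DatabricksColumn("new_col", "string")]
remove = [DatabricksColumn("old_col", "int")]

drop_sql = f"ALTER TABLE {relation} DROP COLUMNS ({DatabricksColumn.format_remove_column_list(remove)})"
add_sql = f"ALTER TABLE {relation} ADD COLUMNS ({DatabricksColumn.format_add_column_list(add)})"
print(drop_sql)  # ALTER TABLE `catalog`.`schema`.`my_table` DROP COLUMNS (`old_col`)
print(add_sql)   # ALTER TABLE `catalog`.`schema`.`my_table` ADD COLUMNS (`new_col` string)
```
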
23 changes: 3 additions & 20 deletions dbt/include/databricks/macros/adapters/persist_docs.sql
@@ -1,12 +1,10 @@
{% macro databricks__alter_column_comment(relation, column_dict) %}
{% if config.get('file_format', default='delta') in ['delta', 'hudi'] %}
{% for column_name in column_dict %}
{% set comment = column_dict[column_name]['description'] %}
{% for column in column_dict.values() %}
{% set comment = column['description'] %}
{% set escaped_comment = comment | replace('\'', '\\\'') %}
{% set comment_query %}
alter table {{ relation }} change column
{{ adapter.quote(column_name) if column_dict[column_name]['quote'] else column_name }}
comment '{{ escaped_comment }}';
alter table {{ relation }} change column {{ api.Column.get_name(column) }} comment '{{ escaped_comment }}';
{% endset %}
{% do run_query(comment_query) %}
{% endfor %}
@@ -30,18 +28,3 @@
{% do alter_column_comment(relation, columns_to_persist_docs) %}
{% endif %}
{% endmacro %}

{% macro get_column_comment_sql(column_name, column_dict) -%}
{% if column_name in column_dict and column_dict[column_name]["description"] -%}
{% set escaped_description = column_dict[column_name]["description"] | replace("'", "\\'") %}
{% set column_comment_clause = "comment '" ~ escaped_description ~ "'" %}
{%- endif -%}
{{ adapter.quote(column_name) }} {{ column_comment_clause }}
{% endmacro %}

{% macro get_persist_docs_column_list(model_columns, query_columns) %}
{% for column_name in query_columns %}
{{ get_column_comment_sql(column_name, model_columns) }}
{{- ", " if not loop.last else "" }}
{% endfor %}
{% endmacro %}
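For the revised `alter_column_comment` loop above, a rough Python sketch of the statement it renders, assuming a column dict shaped like dbt's model columns; the relation, column name, and description are illustrative only.

```python
from dbt.adapters.databricks.column import DatabricksColumn

relation = "`catalog`.`schema`.`my_table`"  # illustrative
column = {"name": "select", "quote": True, "description": "user's choice"}

# Single quotes in the description are escaped, matching the Jinja replace filter.
escaped_comment = column["description"].replace("'", "\\'")
stmt = (
    f"alter table {relation} change column "
    f"{DatabricksColumn.get_name(column)} comment '{escaped_comment}';"
)
print(stmt)
# alter table `catalog`.`schema`.`my_table` change column `select` comment 'user\'s choice';
```
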
6 changes: 3 additions & 3 deletions dbt/include/databricks/macros/relations/constraints.sql
@@ -133,7 +133,7 @@
{% for column_name in column_names %}
{% set column = model.get('columns', {}).get(column_name) %}
{% if column %}
{% set quoted_name = adapter.quote(column['name']) if column['quote'] else column['name'] %}
{% set quoted_name = api.Column.get_name(column) %}
{% set stmt = "alter table " ~ relation ~ " change column " ~ quoted_name ~ " set not null " ~ (constraint.expression or "") ~ ";" %}
{% do statements.append(stmt) %}
{% else %}
@@ -154,7 +154,7 @@
{% if not column %}
{{ exceptions.warn('Invalid primary key column: ' ~ column_name) }}
{% else %}
{% set quoted_name = adapter.quote(column['name']) if column['quote'] else column['name'] %}
{% set quoted_name = api.Column.get_name(column) %}
{% do quoted_names.append(quoted_name) %}
{% endif %}
{% endfor %}
@@ -203,7 +203,7 @@
{% if not column %}
{{ exceptions.warn('Invalid foreign key column: ' ~ column_name) }}
{% else %}
{% set quoted_name = adapter.quote(column['name']) if column['quote'] else column['name'] %}
{% set quoted_name = api.Column.get_name(column) %}
{% do quoted_names.append(quoted_name) %}
{% endif %}
{% endfor %}
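A small sketch of how the constraint macros now quote names through `get_name` rather than `adapter.quote`; the not_null statement below uses a hypothetical relation and a reserved-word column to show why the backticks matter.

```python
from dbt.adapters.databricks.column import DatabricksColumn

relation = "`catalog`.`schema`.`my_table`"  # illustrative
column = {"name": "order", "quote": True}   # reserved word, so quoting matters

quoted_name = DatabricksColumn.get_name(column)
stmt = f"alter table {relation} change column {quoted_name} set not null ;"
print(stmt)  # alter table `catalog`.`schema`.`my_table` change column `order` set not null ;
```
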
1 change: 1 addition & 0 deletions pyproject.toml
@@ -76,6 +76,7 @@ dependencies = [
"freezegun",
"mypy",
"pre-commit",
"ruff",
"types-requests",
"debugpy",
]
4 changes: 4 additions & 0 deletions tests/unit/macros/relations/test_constraint_macros.py
@@ -1,5 +1,8 @@
from unittest.mock import Mock

import pytest

from dbt.adapters.databricks.column import DatabricksColumn
from tests.unit.macros.base import MacroTestBase


@@ -16,6 +19,7 @@ def macro_folders_to_load(self) -> list:
def modify_context(self, default_context) -> None:
# Mock local_md5
default_context["local_md5"] = lambda s: f"hash({s})"
default_context["api"] = Mock(Column=DatabricksColumn)

def render_constraints(self, template, *args):
return self.run_macro(template, "databricks_constraints_to_dbt", *args)
40 changes: 40 additions & 0 deletions tests/unit/test_column.py
@@ -1,3 +1,5 @@
import pytest

from dbt.adapters.databricks.column import DatabricksColumn


@@ -24,3 +26,41 @@ def test_convert_table_stats_with_bytes_and_rows(self):
"stats:rows:label": "rows",
"stats:rows:value": 12345678,
}


class TestColumnStatics:
@pytest.mark.parametrize(
"column, expected",
[
({"name": "foo", "quote": True}, "`foo`"),
({"name": "foo", "quote": False}, "foo"),
({"name": "foo"}, "foo"),
],
)
def test_get_name(self, column, expected):
assert DatabricksColumn.get_name(column) == expected

@pytest.mark.parametrize(
"columns, expected",
[
([], ""),
([DatabricksColumn("foo", "string")], "`foo`"),
([DatabricksColumn("foo", "string"), DatabricksColumn("bar", "int")], "`foo`, `bar`"),
],
)
def test_format_remove_column_list(self, columns, expected):
assert DatabricksColumn.format_remove_column_list(columns) == expected

@pytest.mark.parametrize(
"columns, expected",
[
([], ""),
([DatabricksColumn("foo", "string")], "`foo` string"),
(
[DatabricksColumn("foo", "string"), DatabricksColumn("bar", "int")],
"`foo` string, `bar` int",
),
],
)
def test_format_add_column_list(self, columns, expected):
assert DatabricksColumn.format_add_column_list(columns) == expected
5 changes: 4 additions & 1 deletion tests/unit/test_utils.py
@@ -1,4 +1,4 @@
from dbt.adapters.databricks.utils import redact_credentials, remove_ansi
from dbt.adapters.databricks.utils import quote, redact_credentials, remove_ansi


class TestDatabricksUtils:
@@ -64,3 +64,6 @@ def test_remove_ansi(self):
72 # how to execute python model in notebook
"""
assert remove_ansi(test_string) == expected_string

def test_quote(self):
assert quote("table") == "`table`"