Skip to content

Commit

Permalink
Ensure inferred primary_key is a List[str] (#10984)
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk authored Nov 6, 2024
1 parent 81067d4 commit e451a37
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 4 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20241106-144656.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: 'Ensure inferred primary_key is a List[str] with no null values '
time: 2024-11-06T14:46:56.652963-05:00
custom:
Author: michelleark
Issue: "10983"
17 changes: 13 additions & 4 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,11 +541,20 @@ def infer_primary_key(self, data_tests: List["GenericTestNode"]) -> List[str]:
columns_with_disabled_unique_tests = set()
columns_with_not_null_tests = set()
for test in data_tests:
columns = []
if "column_name" in test.test_metadata.kwargs:
columns: List[str] = []
# extract columns from test kwargs, ensuring columns is a List[str] given tests can have custom (user or pacakge-defined) kwarg types
if "column_name" in test.test_metadata.kwargs and isinstance(
test.test_metadata.kwargs["column_name"], str
):
columns = [test.test_metadata.kwargs["column_name"]]
elif "combination_of_columns" in test.test_metadata.kwargs:
columns = test.test_metadata.kwargs["combination_of_columns"]
elif "combination_of_columns" in test.test_metadata.kwargs and isinstance(
test.test_metadata.kwargs["combination_of_columns"], list
):
columns = [
column
for column in test.test_metadata.kwargs["combination_of_columns"]
if isinstance(column, str)
]

for column in columns:
if test.test_metadata.name in ["unique", "unique_combination_of_columns"]:
Expand Down
20 changes: 20 additions & 0 deletions tests/functional/primary_keys/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@
- unique
"""

invalid_model_unique_test = """
models:
- name: simple_model
data_tests:
- unique:
column_name: null
columns:
- name: id
"""

simple_model_disabled_unique_test = """
models:
- name: simple_model
Expand Down Expand Up @@ -40,6 +50,16 @@
combination_of_columns: [id, color]
"""

invalid_model_unique_combo_of_columns = """
models:
- name: simple_model
tests:
- dbt_utils.unique_combination_of_columns:
combination_of_columns: [null]
- dbt_utils.unique_combination_of_columns:
combination_of_columns: "test"
"""

simple_model_constraints = """
models:
- name: simple_model
Expand Down
56 changes: 56 additions & 0 deletions tests/functional/primary_keys/test_primary_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from dbt.tests.util import get_manifest, run_dbt
from tests.functional.primary_keys.fixtures import (
invalid_model_unique_combo_of_columns,
invalid_model_unique_test,
simple_model_constraints,
simple_model_disabled_unique_test,
simple_model_sql,
Expand Down Expand Up @@ -155,3 +157,57 @@ def test_versioned_simple_combo_of_columns(self, project):
manifest = get_manifest(project.project_root)
node = manifest.nodes["model.test.simple_model"]
assert node.primary_key == ["color", "id"]


class TestInvalidModelCombinationOfColumns:
@pytest.fixture(scope="class")
def packages(self):
return {
"packages": [
{
"git": "https://github.com/dbt-labs/dbt-utils.git",
"revision": "1.1.0",
},
]
}

@pytest.fixture(scope="class")
def models(self):
return {
"simple_model.sql": simple_model_sql,
"schema.yml": invalid_model_unique_combo_of_columns,
}

def test_invalid_combo_of_columns(self, project):
run_dbt(["deps"])
run_dbt(["run"])
manifest = get_manifest(project.project_root)
node = manifest.nodes["model.test.simple_model"]
assert node.primary_key == []


class TestInvalidModelUniqueTest:
@pytest.fixture(scope="class")
def packages(self):
return {
"packages": [
{
"git": "https://github.com/dbt-labs/dbt-utils.git",
"revision": "1.1.0",
},
]
}

@pytest.fixture(scope="class")
def models(self):
return {
"simple_model.sql": simple_model_sql,
"schema.yml": invalid_model_unique_test,
}

def test_invalid_combo_of_columns(self, project):
run_dbt(["deps"])
run_dbt(["run"])
manifest = get_manifest(project.project_root)
node = manifest.nodes["model.test.simple_model"]
assert node.primary_key == []

0 comments on commit e451a37

Please sign in to comment.