Skip to content

Commit

Permalink
Add method to infer primary key of model (#9650)
Browse files Browse the repository at this point in the history
  • Loading branch information
aliceliu authored Feb 26, 2024
1 parent d1ebf9d commit 2b23a03
Show file tree
Hide file tree
Showing 5 changed files with 345 additions and 66 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240223-115021.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Implement primary key inference for model nodes
time: 2024-02-23T11:50:21.257494-08:00
custom:
Author: aliceliu
Issue: "9652"
54 changes: 54 additions & 0 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,60 @@ def search_name(self):
def materialization_enforces_constraints(self) -> bool:
    """Return True when the configured materialization is "table" or "incremental"."""
    enforcing_materializations = ("table", "incremental")
    materialized = self.config.materialized
    return materialized in enforcing_materializations

def infer_primary_key(self, data_tests: List["GenericTestNode"]) -> List[str]:
    """
    Infer the columns that can be used as the primary key of this model.

    Candidates are considered in the following order; the first non-empty
    result is returned:
    1. Columns of a model-level primary key constraint, otherwise the first
       column that carries a column-level primary key constraint
    2. Columns with a not_null data test that also have a unique or
       dbt_utils.unique_combination_of_columns data test (enabled or disabled)
    3. Columns with enabled unique or dbt_utils.unique_combination_of_columns data tests
    4. Columns with disabled unique or dbt_utils.unique_combination_of_columns data tests

    :param data_tests: generic data tests to inspect. For each test,
        ``test_metadata.name``, ``test_metadata.kwargs`` and ``config.enabled``
        are read.
    :return: column names without duplicates, in the order they were first
        seen, or an empty list when nothing can be inferred.
    """
    for constraint in self.constraints:
        if constraint.type == ConstraintType.primary_key:
            return constraint.columns

    for column, column_info in self.columns.items():
        for column_constraint in column_info.constraints:
            if column_constraint.type == ConstraintType.primary_key:
                return [column]

    # Lists rather than sets: the result keeps first-seen order (so the declared
    # order of a column combination survives) and is stable between runs.
    enabled_unique: List[str] = []
    disabled_unique: List[str] = []
    not_null: List[str] = []

    for test in data_tests:
        kwargs = test.test_metadata.kwargs
        column_name = kwargs.get("column_name")
        combination = kwargs.get("combination_of_columns")

        # Test kwargs are free-form (user- or package-defined), so only accept
        # values that are really column names. Anything else (None, dicts, a
        # bare string where a list is expected, ...) is ignored instead of
        # leaking into the result or raising.
        columns: List[str] = []
        if isinstance(column_name, str):
            columns = [column_name]
        elif isinstance(combination, (list, tuple)):
            columns = [item for item in combination if isinstance(item, str)]
        if not columns:
            continue

        test_name = test.test_metadata.name
        if test_name in ("unique", "unique_combination_of_columns"):
            bucket = enabled_unique if test.config.enabled else disabled_unique
        elif test_name == "not_null":
            bucket = not_null
        else:
            continue

        for column in columns:
            if column not in bucket:
                bucket.append(column)

    unique_and_not_null = [
        column
        for column in not_null
        if column in enabled_unique or column in disabled_unique
    ]
    if unique_and_not_null:
        return unique_and_not_null

    if enabled_unique:
        return enabled_unique

    if disabled_unique:
        return disabled_unique

    return []

def same_contents(self, old, adapter_type) -> bool:
    """Combine the parent class's ``same_contents`` check with ``same_ref_representation``."""
    parent_result = super().same_contents(old, adapter_type)
    # Short-circuit: the ref representation is only compared when the parent check passes.
    return parent_result and self.same_ref_representation(old)

Expand Down
81 changes: 81 additions & 0 deletions tests/unit/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from dbt.contracts.files import FileHash
from dbt.contracts.graph.nodes import (
DependsOn,
InjectedCTE,
ModelNode,
ModelConfig,
GenericTestNode,
)
from dbt.node_types import NodeType

from dbt.artifacts.resources import Contract, TestConfig, TestMetadata


def model_node():
    """Return a fresh, compiled ``ModelNode`` named "foo" for use as a unit-test fixture."""
    injected_cte = InjectedCTE("whatever", "select * from other")
    fields = {
        "package_name": "test",
        "path": "/root/models/foo.sql",
        "original_file_path": "models/foo.sql",
        "language": "sql",
        "raw_code": 'select * from {{ ref("other") }}',
        "name": "foo",
        "resource_type": NodeType.Model,
        "unique_id": "model.test.foo",
        "fqn": ["test", "models", "foo"],
        "refs": [],
        "sources": [],
        "metrics": [],
        "depends_on": DependsOn(),
        "deferred": True,
        "description": "",
        "database": "test_db",
        "schema": "test_schema",
        "alias": "bar",
        "tags": [],
        "config": ModelConfig(),
        "contract": Contract(),
        "meta": {},
        "compiled": True,
        "extra_ctes": [injected_cte],
        "extra_ctes_injected": True,
        "compiled_code": "with whatever as (select * from other) select * from whatever",
        "checksum": FileHash.from_contents(""),
        "unrendered_config": {},
    }
    return ModelNode(**fields)


def generic_test_node():
    """Return a fresh, compiled ``GenericTestNode`` on column "id" with severity "warn"."""
    injected_cte = InjectedCTE("whatever", "select * from other")
    metadata = TestMetadata(namespace=None, name="foo", kwargs={})
    fields = {
        "package_name": "test",
        "path": "/root/x/path.sql",
        "original_file_path": "/root/path.sql",
        "language": "sql",
        "raw_code": 'select * from {{ ref("other") }}',
        "name": "foo",
        "resource_type": NodeType.Test,
        "unique_id": "model.test.foo",
        "fqn": ["test", "models", "foo"],
        "refs": [],
        "sources": [],
        "metrics": [],
        "depends_on": DependsOn(),
        "deferred": False,
        "description": "",
        "database": "test_db",
        "schema": "dbt_test__audit",
        "alias": "bar",
        "tags": [],
        "config": TestConfig(severity="warn"),
        "contract": Contract(),
        "meta": {},
        "compiled": True,
        "extra_ctes": [injected_cte],
        "extra_ctes_injected": True,
        "compiled_code": "with whatever as (select * from other) select * from whatever",
        "column_name": "id",
        "test_metadata": metadata,
        "checksum": FileHash.from_contents(""),
        "unrendered_config": {
            "severity": "warn",
        },
    }
    return GenericTestNode(**fields)
70 changes: 4 additions & 66 deletions tests/unit/test_contracts_graph_compiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from dbt.contracts.graph.nodes import (
DependsOn,
GenericTestNode,
InjectedCTE,
ModelNode,
ModelConfig,
)
from dbt.artifacts.resources import Contract, TestConfig, TestMetadata
from dbt.artifacts.resources import TestConfig, TestMetadata
from tests.unit.fixtures import generic_test_node, model_node
from dbt.node_types import NodeType

from .utils import (
Expand Down Expand Up @@ -57,36 +57,7 @@ def basic_uncompiled_model():

@pytest.fixture
def basic_compiled_model():
return ModelNode(
package_name="test",
path="/root/models/foo.sql",
original_file_path="models/foo.sql",
language="sql",
raw_code='select * from {{ ref("other") }}',
name="foo",
resource_type=NodeType.Model,
unique_id="model.test.foo",
fqn=["test", "models", "foo"],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
deferred=True,
description="",
database="test_db",
schema="test_schema",
alias="bar",
tags=[],
config=ModelConfig(),
contract=Contract(),
meta={},
compiled=True,
extra_ctes=[InjectedCTE("whatever", "select * from other")],
extra_ctes_injected=True,
compiled_code="with whatever as (select * from other) select * from whatever",
checksum=FileHash.from_contents(""),
unrendered_config={},
)
return model_node()


@pytest.fixture
Expand Down Expand Up @@ -432,40 +403,7 @@ def basic_uncompiled_schema_test_node():

@pytest.fixture
def basic_compiled_schema_test_node():
return GenericTestNode(
package_name="test",
path="/root/x/path.sql",
original_file_path="/root/path.sql",
language="sql",
raw_code='select * from {{ ref("other") }}',
name="foo",
resource_type=NodeType.Test,
unique_id="model.test.foo",
fqn=["test", "models", "foo"],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
deferred=False,
description="",
database="test_db",
schema="dbt_test__audit",
alias="bar",
tags=[],
config=TestConfig(severity="warn"),
contract=Contract(),
meta={},
compiled=True,
extra_ctes=[InjectedCTE("whatever", "select * from other")],
extra_ctes_injected=True,
compiled_code="with whatever as (select * from other) select * from whatever",
column_name="id",
test_metadata=TestMetadata(namespace=None, name="foo", kwargs={}),
checksum=FileHash.from_contents(""),
unrendered_config={
"severity": "warn",
},
)
return generic_test_node()


@pytest.fixture
Expand Down
Loading

0 comments on commit 2b23a03

Please sign in to comment.