From 2b23a038d475aae8d8b1ecc46b4b6d2ae363e342 Mon Sep 17 00:00:00 2001 From: aliceliu Date: Sun, 25 Feb 2024 20:29:43 -0800 Subject: [PATCH] Add method to infer primary key of model (#9650) --- .../Under the Hood-20240223-115021.yaml | 6 + core/dbt/contracts/graph/nodes.py | 54 +++++ tests/unit/fixtures.py | 81 +++++++ tests/unit/test_contracts_graph_compiled.py | 70 +----- tests/unit/test_infer_primary_key.py | 200 ++++++++++++++++++ 5 files changed, 345 insertions(+), 66 deletions(-) create mode 100644 .changes/unreleased/Under the Hood-20240223-115021.yaml create mode 100644 tests/unit/fixtures.py create mode 100644 tests/unit/test_infer_primary_key.py diff --git a/.changes/unreleased/Under the Hood-20240223-115021.yaml b/.changes/unreleased/Under the Hood-20240223-115021.yaml new file mode 100644 index 00000000000..ccc1a381124 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240223-115021.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Implement primary key inference for model nodes +time: 2024-02-23T11:50:21.257494-08:00 +custom: + Author: aliceliu + Issue: "9652" diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index 146f54f1a91..e0e26b24b56 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -492,6 +492,60 @@ def search_name(self): def materialization_enforces_constraints(self) -> bool: return self.config.materialized in ["table", "incremental"] + def infer_primary_key(self, data_tests: List["GenericTestNode"]) -> List[str]: + """ + Infers the columns that can be used as primary key of a model in the following order: + 1. Columns with primary key constraints + 2. Columns with unique and not_null data tests + 3. Columns with enabled unique or dbt_utils.unique_combination_of_columns data tests + 4. Columns with disabled unique or dbt_utils.unique_combination_of_columns data tests + """ + for constraint in self.constraints: + if constraint.type == ConstraintType.primary_key: + return constraint.columns + + for column, column_info in self.columns.items(): + for column_constraint in column_info.constraints: + if column_constraint.type == ConstraintType.primary_key: + return [column] + + columns_with_enabled_unique_tests = set() + columns_with_disabled_unique_tests = set() + columns_with_not_null_tests = set() + for test in data_tests: + columns = [] + if "column_name" in test.test_metadata.kwargs: + columns = [test.test_metadata.kwargs["column_name"]] + elif "combination_of_columns" in test.test_metadata.kwargs: + columns = test.test_metadata.kwargs["combination_of_columns"] + + for column in columns: + if test.test_metadata.name in ["unique", "unique_combination_of_columns"]: + if test.config.enabled: + columns_with_enabled_unique_tests.add(column) + else: + columns_with_disabled_unique_tests.add(column) + elif test.test_metadata.name == "not_null": + columns_with_not_null_tests.add(column) + + columns_with_unique_and_not_null_tests = [] + for column in columns_with_not_null_tests: + if ( + column in columns_with_enabled_unique_tests + or column in columns_with_disabled_unique_tests + ): + columns_with_unique_and_not_null_tests.append(column) + if columns_with_unique_and_not_null_tests: + return columns_with_unique_and_not_null_tests + + if columns_with_enabled_unique_tests: + return list(columns_with_enabled_unique_tests) + + if columns_with_disabled_unique_tests: + return list(columns_with_disabled_unique_tests) + + return [] + def same_contents(self, old, adapter_type) -> bool: return super().same_contents(old, adapter_type) and self.same_ref_representation(old) diff --git a/tests/unit/fixtures.py b/tests/unit/fixtures.py new file mode 100644 index 00000000000..6b4945911d3 --- /dev/null +++ b/tests/unit/fixtures.py @@ -0,0 +1,81 @@ +from dbt.contracts.files import FileHash +from dbt.contracts.graph.nodes import ( + DependsOn, + InjectedCTE, + ModelNode, + ModelConfig, + GenericTestNode, +) +from dbt.node_types import NodeType + +from dbt.artifacts.resources import Contract, TestConfig, TestMetadata + + +def model_node(): + return ModelNode( + package_name="test", + path="/root/models/foo.sql", + original_file_path="models/foo.sql", + language="sql", + raw_code='select * from {{ ref("other") }}', + name="foo", + resource_type=NodeType.Model, + unique_id="model.test.foo", + fqn=["test", "models", "foo"], + refs=[], + sources=[], + metrics=[], + depends_on=DependsOn(), + deferred=True, + description="", + database="test_db", + schema="test_schema", + alias="bar", + tags=[], + config=ModelConfig(), + contract=Contract(), + meta={}, + compiled=True, + extra_ctes=[InjectedCTE("whatever", "select * from other")], + extra_ctes_injected=True, + compiled_code="with whatever as (select * from other) select * from whatever", + checksum=FileHash.from_contents(""), + unrendered_config={}, + ) + + +def generic_test_node(): + return GenericTestNode( + package_name="test", + path="/root/x/path.sql", + original_file_path="/root/path.sql", + language="sql", + raw_code='select * from {{ ref("other") }}', + name="foo", + resource_type=NodeType.Test, + unique_id="model.test.foo", + fqn=["test", "models", "foo"], + refs=[], + sources=[], + metrics=[], + depends_on=DependsOn(), + deferred=False, + description="", + database="test_db", + schema="dbt_test__audit", + alias="bar", + tags=[], + config=TestConfig(severity="warn"), + contract=Contract(), + meta={}, + compiled=True, + extra_ctes=[InjectedCTE("whatever", "select * from other")], + extra_ctes_injected=True, + compiled_code="with whatever as (select * from other) select * from whatever", + column_name="id", + test_metadata=TestMetadata(namespace=None, name="foo", kwargs={}), + checksum=FileHash.from_contents(""), + unrendered_config={ + "severity": "warn", + }, + ) diff --git a/tests/unit/test_contracts_graph_compiled.py b/tests/unit/test_contracts_graph_compiled.py index 8c454d6a68a..96fbeb54090 100644 --- a/tests/unit/test_contracts_graph_compiled.py +++ b/tests/unit/test_contracts_graph_compiled.py @@ -8,11 +8,11 @@ from dbt.contracts.graph.nodes import ( DependsOn, GenericTestNode, - InjectedCTE, ModelNode, ModelConfig, ) -from dbt.artifacts.resources import Contract, TestConfig, TestMetadata +from dbt.artifacts.resources import TestConfig, TestMetadata +from tests.unit.fixtures import generic_test_node, model_node from dbt.node_types import NodeType from .utils import ( @@ -57,36 +57,7 @@ def basic_uncompiled_model(): @pytest.fixture def basic_compiled_model(): - return ModelNode( - package_name="test", - path="/root/models/foo.sql", - original_file_path="models/foo.sql", - language="sql", - raw_code='select * from {{ ref("other") }}', - name="foo", - resource_type=NodeType.Model, - unique_id="model.test.foo", - fqn=["test", "models", "foo"], - refs=[], - sources=[], - metrics=[], - depends_on=DependsOn(), - deferred=True, - description="", - database="test_db", - schema="test_schema", - alias="bar", - tags=[], - config=ModelConfig(), - contract=Contract(), - meta={}, - compiled=True, - extra_ctes=[InjectedCTE("whatever", "select * from other")], - extra_ctes_injected=True, - compiled_code="with whatever as (select * from other) select * from whatever", - checksum=FileHash.from_contents(""), - unrendered_config={}, - ) + return model_node() @pytest.fixture @@ -432,40 +403,7 @@ def basic_uncompiled_schema_test_node(): @pytest.fixture def basic_compiled_schema_test_node(): - return GenericTestNode( - package_name="test", - path="/root/x/path.sql", - original_file_path="/root/path.sql", - language="sql", - raw_code='select * from {{ ref("other") }}', - name="foo", - resource_type=NodeType.Test, - unique_id="model.test.foo", - fqn=["test", "models", "foo"], - refs=[], - sources=[], - metrics=[], - depends_on=DependsOn(), - deferred=False, - description="", - database="test_db", - schema="dbt_test__audit", - alias="bar", - tags=[], - config=TestConfig(severity="warn"), - contract=Contract(), - meta={}, - compiled=True, - extra_ctes=[InjectedCTE("whatever", "select * from other")], - extra_ctes_injected=True, - compiled_code="with whatever as (select * from other) select * from whatever", - column_name="id", - test_metadata=TestMetadata(namespace=None, name="foo", kwargs={}), - checksum=FileHash.from_contents(""), - unrendered_config={ - "severity": "warn", - }, - ) + return generic_test_node() @pytest.fixture diff --git a/tests/unit/test_infer_primary_key.py b/tests/unit/test_infer_primary_key.py new file mode 100644 index 00000000000..4afa2bf4652 --- /dev/null +++ b/tests/unit/test_infer_primary_key.py @@ -0,0 +1,200 @@ +from dbt_common.contracts.constraints import ( + ConstraintType, + ModelLevelConstraint, + ColumnLevelConstraint, +) + +from .fixtures import model_node, generic_test_node + +from dbt.contracts.graph.model_config import ( + TestConfig, +) +from dbt.contracts.graph.nodes import ( + ColumnInfo, +) +from dbt.artifacts.resources import TestMetadata + + +def test_no_primary_key(): + model = model_node() + assert model.infer_primary_key([]) == [] + + +def test_primary_key_model_constraint(): + model = model_node() + model.constraints = [ModelLevelConstraint(type=ConstraintType.primary_key, columns=["pk"])] + assertSameContents(model.infer_primary_key([]), ["pk"]) + + model.constraints = [ + ModelLevelConstraint(type=ConstraintType.primary_key, columns=["pk1", "pk2"]) + ] + assertSameContents(model.infer_primary_key([]), ["pk1", "pk2"]) + + +def test_primary_key_column_constraint(): + model = model_node() + model.columns = { + "column1": ColumnInfo( + "column1", constraints=[ColumnLevelConstraint(type=ConstraintType.primary_key)] + ), + "column2": ColumnInfo("column2"), + } + assertSameContents(model.infer_primary_key([]), ["column1"]) + + +def test_unique_non_null_single(): + model = model_node() + test1 = generic_test_node() + test1.test_metadata = TestMetadata(name="unique", kwargs={"column_name": "column1"}) + test2 = generic_test_node() + test2.test_metadata = TestMetadata(name="not_null", kwargs={"column_name": "column1"}) + test3 = generic_test_node() + test3.test_metadata = TestMetadata(name="unique", kwargs={"column_name": "column2"}) + tests = [test1, test2] + assertSameContents(model.infer_primary_key(tests), ["column1"]) + + +def test_unique_non_null_multiple(): + model = model_node() + tests = [] + for i in range(2): + for enabled in [True, False]: + test1 = generic_test_node() + test1.test_metadata = TestMetadata( + name="unique", kwargs={"column_name": "column" + str(i) + str(enabled)} + ) + test1.config = TestConfig(enabled=enabled) + test2 = generic_test_node() + test2.test_metadata = TestMetadata( + name="not_null", kwargs={"column_name": "column" + str(i) + str(enabled)} + ) + test2.config = TestConfig(enabled=enabled) + tests.extend([test1, test2]) + + assertSameContents( + model.infer_primary_key(tests), + ["column0True", "column1True", "column0False", "column1False"], + ) + + +def test_enabled_unique_single(): + model = model_node() + test1 = generic_test_node() + test1.test_metadata = TestMetadata(name="unique", kwargs={"column_name": "column1"}) + test2 = generic_test_node() + test2.config = TestConfig(enabled=False) + test2.test_metadata = TestMetadata(name="unique", kwargs={"column_name": "column3"}) + + tests = [test1, test2] + assertSameContents(model.infer_primary_key(tests), ["column1"]) + + +def test_enabled_unique_multiple(): + model = model_node() + test1 = generic_test_node() + test1.test_metadata = TestMetadata(name="unique", kwargs={"column_name": "column1"}) + test2 = generic_test_node() + test2.test_metadata = TestMetadata(name="unique", kwargs={"column_name": "column2 || column3"}) + + tests = [test1, test2] + assertSameContents(model.infer_primary_key(tests), ["column1", "column2 || column3"]) + + +def test_enabled_unique_combo_single(): + model = model_node() + test1 = generic_test_node() + test1.test_metadata = TestMetadata( + name="unique_combination_of_columns", + kwargs={"combination_of_columns": ["column1", "column2"]}, + ) + test2 = generic_test_node() + test2.config = TestConfig(enabled=False) + test2.test_metadata = TestMetadata( + name="unique_combination_of_columns", + kwargs={"combination_of_columns": ["column3", "column4"]}, + ) + + tests = [test1, test2] + assertSameContents(model.infer_primary_key(tests), ["column1", "column2"]) + + +def test_enabled_unique_combo_multiple(): + model = model_node() + test1 = generic_test_node() + test1.test_metadata = TestMetadata( + name="unique", kwargs={"combination_of_columns": ["column1", "column2"]} + ) + test2 = generic_test_node() + test2.test_metadata = TestMetadata( + name="unique", kwargs={"combination_of_columns": ["column3", "column4"]} + ) + + tests = [test1, test2] + assertSameContents( + model.infer_primary_key(tests), ["column1", "column2", "column3", "column4"] + ) + + +def test_disabled_unique_single(): + model = model_node() + test1 = generic_test_node() + test1.config = TestConfig(enabled=False) + test1.test_metadata = TestMetadata(name="unique", kwargs={"column_name": "column1"}) + test2 = generic_test_node() + test2.test_metadata = TestMetadata(name="not_null", kwargs={"column_name": "column2"}) + + tests = [test1, test2] + assertSameContents(model.infer_primary_key(tests), ["column1"]) + + +def test_disabled_unique_multiple(): + model = model_node() + test1 = generic_test_node() + test1.config = TestConfig(enabled=False) + test1.test_metadata = TestMetadata(name="unique", kwargs={"column_name": "column1"}) + test2 = generic_test_node() + test2.config = TestConfig(enabled=False) + test2.test_metadata = TestMetadata(name="unique", kwargs={"column_name": "column2 || column3"}) + + tests = [test1, test2] + assertSameContents(model.infer_primary_key(tests), ["column1", "column2 || column3"]) + + +def test_disabled_unique_combo_single(): + model = model_node() + test1 = generic_test_node() + test1.config = TestConfig(enabled=False) + test1.test_metadata = TestMetadata( + name="unique", kwargs={"combination_of_columns": ["column1", "column2"]} + ) + test2 = generic_test_node() + test2.config = TestConfig(enabled=False) + test2.test_metadata = TestMetadata( + name="random", kwargs={"combination_of_columns": ["column3", "column4"]} + ) + + tests = [test1, test2] + assertSameContents(model.infer_primary_key(tests), ["column1", "column2"]) + + +def test_disabled_unique_combo_multiple(): + model = model_node() + test1 = generic_test_node() + test1.config = TestConfig(enabled=False) + test1.test_metadata = TestMetadata( + name="unique", kwargs={"combination_of_columns": ["column1", "column2"]} + ) + test2 = generic_test_node() + test2.config = TestConfig(enabled=False) + test2.test_metadata = TestMetadata( + name="unique", kwargs={"combination_of_columns": ["column3", "column4"]} + ) + + tests = [test1, test2] + assertSameContents( + model.infer_primary_key(tests), ["column1", "column2", "column3", "column4"] + ) + + +def assertSameContents(list1, list2): + assert sorted(list1) == sorted(list2)