diff --git a/.changes/unreleased/Fixes-20241021-093047.yaml b/.changes/unreleased/Fixes-20241021-093047.yaml new file mode 100644 index 00000000000..4d691aa8837 --- /dev/null +++ b/.changes/unreleased/Fixes-20241021-093047.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: unit tests with versioned refs +time: 2024-10-21T09:30:47.949023898-03:00 +custom: + Author: devmessias + Issue: 10880 10528 10623 diff --git a/core/dbt/parser/unit_tests.py b/core/dbt/parser/unit_tests.py index c227319f7e9..38a9c81fb3d 100644 --- a/core/dbt/parser/unit_tests.py +++ b/core/dbt/parser/unit_tests.py @@ -110,7 +110,6 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): # unit_test_node now has a populated refs/sources self.unit_test_manifest.nodes[unit_test_node.unique_id] = unit_test_node - # Now create input_nodes for the test inputs """ given: @@ -132,7 +131,6 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): given.input, tested_node, test_case.name ) input_name = original_input_node.name - common_fields = { "resource_type": NodeType.Model, # root directory for input and output fixtures @@ -149,23 +147,25 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition): "name": input_name, "path": f"{input_name}.sql", } + resource_type = original_input_node.resource_type - if original_input_node.resource_type in ( + if resource_type in ( NodeType.Model, NodeType.Seed, NodeType.Snapshot, ): + input_node = ModelNode( **common_fields, defer_relation=original_input_node.defer_relation, ) - if ( - original_input_node.resource_type == NodeType.Model - and original_input_node.version - ): - input_node.version = original_input_node.version + if resource_type == NodeType.Model: + if original_input_node.version: + input_node.version = original_input_node.version + if original_input_node.latest_version: + input_node.latest_version = original_input_node.latest_version - elif original_input_node.resource_type == NodeType.Source: + elif resource_type == NodeType.Source: # We are reusing the database/schema/identifier from the original source, # but that shouldn't matter since this acts as an ephemeral model which just # wraps a CTE around the unit test node. diff --git a/tests/functional/unit_testing/test_unit_testing.py b/tests/functional/unit_testing/test_unit_testing.py index 160f528787d..cd9083686da 100644 --- a/tests/functional/unit_testing/test_unit_testing.py +++ b/tests/functional/unit_testing/test_unit_testing.py @@ -309,6 +309,146 @@ def test_basic(self, project): assert len(results) == 2 +schema_ref_with_versioned_model = """ +models: + - name: source + latest_version: {latest_version} + versions: + - v: 1 + - v: 2 + - name: model_to_test +unit_tests: + - name: ref_versioned + model: 'model_to_test' + given: + - input: {input} + rows: + - {{result: 3}} + expect: + rows: + - {{result: 3}} + +""" + + +class TestUnitTestRefWithVersion: + @pytest.fixture(scope="class") + def models(self): + return { + "model_to_test.sql": "select result from {{ ref('source')}}", + "source.sql": "select 2 as result", + "source_v2.sql": "select 2 as result", + "schema.yml": schema_ref_with_versioned_model.format( + **{"latest_version": 1, "input": "ref('source')"} + ), + } + + def test_basic(self, project): + results = run_dbt(["run"]) + + results = run_dbt(["test", "--select", "model_to_test"], expect_pass=True) + assert len(results) == 1 + + +class TestUnitTestRefMissingVersionModel: + @pytest.fixture(scope="class") + def models(self): + return { + "model_to_test.sql": "select result from {{ ref('source')}}", + "source_v1.sql": "select 2 as result", + "source_v2.sql": "select 2 as result", + "schema.yml": schema_ref_with_versioned_model.format( + **{"latest_version": 1, "input": "ref('source', v=1)"} + ), + } + + def test_basic(self, project): + results = run_dbt(["run"]) + + results = run_dbt(["test", "--select", "model_to_test"], expect_pass=True) + assert len(results) == 1 + + +class TestUnitTestRefWithMissingVersionRef: + @pytest.fixture(scope="class") + def models(self): + return { + "model_to_test.sql": "select result from {{ ref('source', v=1)}}", + "source_v1.sql": "select 2 as result", + "source_v2.sql": "select 2 as result", + "schema.yml": schema_ref_with_versioned_model.format( + **{"latest_version": 1, "input": "ref('source')"} + ), + } + + def test_basic(self, project): + results = run_dbt(["run"]) + + results = run_dbt(["test", "--select", "model_to_test"], expect_pass=True) + assert len(results) == 1 + + +class TestUnitTestRefWithVersionLatestSecond: + @pytest.fixture(scope="class") + def models(self): + return { + "model_to_test.sql": "select result from {{ ref('source')}}", + "source_v1.sql": "select 2 as result", + "source_v2.sql": "select 2 as result", + "schema.yml": schema_ref_with_versioned_model.format( + **{"latest_version": 2, "input": "ref('source')"} + ), + } + + def test_basic(self, project): + results = run_dbt(["run"]) + + results = run_dbt(["test", "--select", "model_to_test"], expect_pass=True) + assert len(results) == 1 + + +class TestUnitTestRefWithVersionMissingRefTest: + @pytest.fixture(scope="class") + def models(self): + return { + "model_to_test.sql": "select result from {{ ref('source', v=2)}}", + "source_v1.sql": "select 2 as result", + "source_v2.sql": "select 2 as result", + "schema.yml": schema_ref_with_versioned_model.format( + **{"latest_version": 1, "input": "ref('source')"} + ), + } + + def test_basic(self, project): + results = run_dbt(["run"]) + + assert len(results) == 3 + # run_dbt(["test", "--select", "model_to_test"], expect_pass=False) + exec_result, _ = run_dbt_and_capture( + ["test", "--select", "model_to_test"], expect_pass=False + ) + msg_error = exec_result[0].message + assert msg_error.lower().lstrip().startswith("compilation error") + + +class TestUnitTestRefWithVersionDiffLatest: + @pytest.fixture(scope="class") + def models(self): + return { + "model_to_test.sql": "select result from {{ ref('source', v=2)}}", + "source_v1.sql": "select 2 as result", + "source_v2.sql": "select 2 as result", + "schema.yml": schema_ref_with_versioned_model.format( + **{"latest_version": 1, "input": "ref('source', v=2)"} + ), + } + + def test_basic(self, project): + results = run_dbt(["run"]) + assert len(results) == 3 + run_dbt(["test", "--select", "model_to_test"], expect_pass=True) + + class TestUnitTestExplicitSeed: @pytest.fixture(scope="class") def seeds(self):