Skip to content

Commit

Permalink
fix: unit tests with versioned refs (#10889)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmessias authored Nov 14, 2024
1 parent 2c43af8 commit 1625eb0
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 9 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20241021-093047.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: unit tests with versioned refs
time: 2024-10-21T09:30:47.949023898-03:00
custom:
Author: devmessias
Issue: 10880 10528 10623
18 changes: 9 additions & 9 deletions core/dbt/parser/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
# unit_test_node now has a populated refs/sources

self.unit_test_manifest.nodes[unit_test_node.unique_id] = unit_test_node

# Now create input_nodes for the test inputs
"""
given:
Expand All @@ -132,7 +131,6 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
given.input, tested_node, test_case.name
)
input_name = original_input_node.name

common_fields = {
"resource_type": NodeType.Model,
# root directory for input and output fixtures
Expand All @@ -149,23 +147,25 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
"name": input_name,
"path": f"{input_name}.sql",
}
resource_type = original_input_node.resource_type

if original_input_node.resource_type in (
if resource_type in (
NodeType.Model,
NodeType.Seed,
NodeType.Snapshot,
):

input_node = ModelNode(
**common_fields,
defer_relation=original_input_node.defer_relation,
)
if (
original_input_node.resource_type == NodeType.Model
and original_input_node.version
):
input_node.version = original_input_node.version
if resource_type == NodeType.Model:
if original_input_node.version:
input_node.version = original_input_node.version
if original_input_node.latest_version:
input_node.latest_version = original_input_node.latest_version

elif original_input_node.resource_type == NodeType.Source:
elif resource_type == NodeType.Source:
# We are reusing the database/schema/identifier from the original source,
# but that shouldn't matter since this acts as an ephemeral model which just
# wraps a CTE around the unit test node.
Expand Down
140 changes: 140 additions & 0 deletions tests/functional/unit_testing/test_unit_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,146 @@ def test_basic(self, project):
assert len(results) == 2


schema_ref_with_versioned_model = """
models:
- name: source
latest_version: {latest_version}
versions:
- v: 1
- v: 2
- name: model_to_test
unit_tests:
- name: ref_versioned
model: 'model_to_test'
given:
- input: {input}
rows:
- {{result: 3}}
expect:
rows:
- {{result: 3}}
"""


class TestUnitTestRefWithVersion:
@pytest.fixture(scope="class")
def models(self):
return {
"model_to_test.sql": "select result from {{ ref('source')}}",
"source.sql": "select 2 as result",
"source_v2.sql": "select 2 as result",
"schema.yml": schema_ref_with_versioned_model.format(
**{"latest_version": 1, "input": "ref('source')"}
),
}

def test_basic(self, project):
results = run_dbt(["run"])

results = run_dbt(["test", "--select", "model_to_test"], expect_pass=True)
assert len(results) == 1


class TestUnitTestRefMissingVersionModel:
@pytest.fixture(scope="class")
def models(self):
return {
"model_to_test.sql": "select result from {{ ref('source')}}",
"source_v1.sql": "select 2 as result",
"source_v2.sql": "select 2 as result",
"schema.yml": schema_ref_with_versioned_model.format(
**{"latest_version": 1, "input": "ref('source', v=1)"}
),
}

def test_basic(self, project):
results = run_dbt(["run"])

results = run_dbt(["test", "--select", "model_to_test"], expect_pass=True)
assert len(results) == 1


class TestUnitTestRefWithMissingVersionRef:
@pytest.fixture(scope="class")
def models(self):
return {
"model_to_test.sql": "select result from {{ ref('source', v=1)}}",
"source_v1.sql": "select 2 as result",
"source_v2.sql": "select 2 as result",
"schema.yml": schema_ref_with_versioned_model.format(
**{"latest_version": 1, "input": "ref('source')"}
),
}

def test_basic(self, project):
results = run_dbt(["run"])

results = run_dbt(["test", "--select", "model_to_test"], expect_pass=True)
assert len(results) == 1


class TestUnitTestRefWithVersionLatestSecond:
@pytest.fixture(scope="class")
def models(self):
return {
"model_to_test.sql": "select result from {{ ref('source')}}",
"source_v1.sql": "select 2 as result",
"source_v2.sql": "select 2 as result",
"schema.yml": schema_ref_with_versioned_model.format(
**{"latest_version": 2, "input": "ref('source')"}
),
}

def test_basic(self, project):
results = run_dbt(["run"])

results = run_dbt(["test", "--select", "model_to_test"], expect_pass=True)
assert len(results) == 1


class TestUnitTestRefWithVersionMissingRefTest:
@pytest.fixture(scope="class")
def models(self):
return {
"model_to_test.sql": "select result from {{ ref('source', v=2)}}",
"source_v1.sql": "select 2 as result",
"source_v2.sql": "select 2 as result",
"schema.yml": schema_ref_with_versioned_model.format(
**{"latest_version": 1, "input": "ref('source')"}
),
}

def test_basic(self, project):
results = run_dbt(["run"])

assert len(results) == 3
# run_dbt(["test", "--select", "model_to_test"], expect_pass=False)
exec_result, _ = run_dbt_and_capture(
["test", "--select", "model_to_test"], expect_pass=False
)
msg_error = exec_result[0].message
assert msg_error.lower().lstrip().startswith("compilation error")


class TestUnitTestRefWithVersionDiffLatest:
@pytest.fixture(scope="class")
def models(self):
return {
"model_to_test.sql": "select result from {{ ref('source', v=2)}}",
"source_v1.sql": "select 2 as result",
"source_v2.sql": "select 2 as result",
"schema.yml": schema_ref_with_versioned_model.format(
**{"latest_version": 1, "input": "ref('source', v=2)"}
),
}

def test_basic(self, project):
results = run_dbt(["run"])
assert len(results) == 3
run_dbt(["test", "--select", "model_to_test"], expect_pass=True)


class TestUnitTestExplicitSeed:
@pytest.fixture(scope="class")
def seeds(self):
Expand Down

0 comments on commit 1625eb0

Please sign in to comment.