Skip to content

Commit

Permalink
fix: unit tests with versioned refs
Browse files Browse the repository at this point in the history
missing latest prop
  • Loading branch information
devmessias committed Nov 11, 2024
1 parent 89caa33 commit 3bb686e
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 15 deletions.
16 changes: 9 additions & 7 deletions core/dbt/parser/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
# unit_test_node now has a populated refs/sources

self.unit_test_manifest.nodes[unit_test_node.unique_id] = unit_test_node

# Now create input_nodes for the test inputs
"""
given:
Expand All @@ -132,7 +131,6 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
given.input, tested_node, test_case.name
)
input_name = original_input_node.name

common_fields = {
"resource_type": NodeType.Model,
# root directory for input and output fixtures
Expand All @@ -149,21 +147,25 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
"name": input_name,
"path": f"{input_name}.sql",
}
resource_type = original_input_node.resource_type

if original_input_node.resource_type in (
if resource_type in (
NodeType.Model,
NodeType.Seed,
NodeType.Snapshot,
):

input_node = ModelNode(
**common_fields,
defer_relation=original_input_node.defer_relation,
)
if (
original_input_node.resource_type == NodeType.Model
and original_input_node.version
):

if not resource_type == NodeType.Model:
continue
if original_input_node.version:
input_node.version = original_input_node.version
if original_input_node.latest_version:
input_node.latest_version = original_input_node.latest_version

elif original_input_node.resource_type == NodeType.Source:
# We are reusing the database/schema/identifier from the original source,
Expand Down
133 changes: 125 additions & 8 deletions tests/functional/unit_testing/test_unit_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,22 +291,139 @@ def test_basic(self, project):
assert len(results) == 2


class TestUnitTestIncrementalModelWithVersion:
schema_ref_with_version = """
models:
- name: source
latest_version: {latest_version}
versions:
- v: 1
- v: 2
- name: model_to_test
unit_tests:
- name: ref_versioned
model: 'model_to_test'
given:
- input: {input}
rows:
- {{result: 3}}
expect:
rows:
- {{result: 3}}
"""


class TestUnitTestRefWithVersion:
@pytest.fixture(scope="class")
def models(self):
return {
"my_incremental_model.sql": my_incremental_model_sql,
"events.sql": event_sql,
"schema.yml": my_incremental_model_versioned_yml + test_my_model_incremental_yml_basic,
"model_to_test.sql": "select result from {{ ref('source')}}",
"source.sql": "select 2 as result",
"source_v2.sql": "select 2 as result",
"schema.yml": schema_ref_with_version.format(
**{"latest_version": 1, "input": "ref('source')"}
),
}

def test_basic(self, project):
results = run_dbt(["run"])
assert len(results) == 2

# Select by model name
results = run_dbt(["test", "--select", "my_incremental_model"], expect_pass=True)
assert len(results) == 2
results = run_dbt(["test", "--select", "model_to_test"], expect_pass=True)
assert len(results) == 1


class TestUnitTestRefMissingVersionModel:
@pytest.fixture(scope="class")
def models(self):
return {
"model_to_test.sql": "select result from {{ ref('source')}}",
"source_v1.sql": "select 2 as result",
"source_v2.sql": "select 2 as result",
"schema.yml": schema_ref_with_version.format(
**{"latest_version": 1, "input": "ref('source', v=1)"}
),
}

def test_basic(self, project):
results = run_dbt(["run"])

results = run_dbt(["test", "--select", "model_to_test"], expect_pass=True)
assert len(results) == 1


class TestUnitTestRefWithMissingVersionRef:
@pytest.fixture(scope="class")
def models(self):
return {
"model_to_test.sql": "select result from {{ ref('source', v=1)}}",
"source_v1.sql": "select 2 as result",
"source_v2.sql": "select 2 as result",
"schema.yml": schema_ref_with_version.format(
**{"latest_version": 1, "input": "ref('source')"}
),
}

def test_basic(self, project):
results = run_dbt(["run"])

results = run_dbt(["test", "--select", "model_to_test"], expect_pass=True)
assert len(results) == 1


class TestUnitTestRefWithVersionLatestSecond:
@pytest.fixture(scope="class")
def models(self):
return {
"model_to_test.sql": "select result from {{ ref('source')}}",
"source_v1.sql": "select 2 as result",
"source_v2.sql": "select 2 as result",
"schema.yml": schema_ref_with_version.format(
**{"latest_version": 2, "input": "ref('source')"}
),
}

def test_basic(self, project):
results = run_dbt(["run"])

results = run_dbt(["test", "--select", "model_to_test"], expect_pass=True)
assert len(results) == 1


class TestUnitTestRefWithVersionMissingRefTest:
@pytest.fixture(scope="class")
def models(self):
return {
"model_to_test.sql": "select result from {{ ref('source', v=2)}}",
"source_v1.sql": "select 2 as result",
"source_v2.sql": "select 2 as result",
"schema.yml": schema_ref_with_version.format(
**{"latest_version": 1, "input": "ref('source')"}
),
}

def test_basic(self, project):
results = run_dbt(["run"])
assert len(results) == 3
# TODO: How to capture an compilation Error? pytest.raises(CompilationError) not working
run_dbt(["test", "--select", "model_to_test"], expect_pass=False)


class TestUnitTestRefWithVersionDiffLatest:
@pytest.fixture(scope="class")
def models(self):
return {
"model_to_test.sql": "select result from {{ ref('source', v=2)}}",
"source_v1.sql": "select 2 as result",
"source_v2.sql": "select 2 as result",
"schema.yml": schema_ref_with_version.format(
**{"latest_version": 1, "input": "ref('source', v=2)"}
),
}

def test_basic(self, project):
results = run_dbt(["run"])
assert len(results) == 3
run_dbt(["test", "--select", "model_to_test"], expect_pass=True)


class TestUnitTestExplicitSeed:
Expand Down

0 comments on commit 3bb686e

Please sign in to comment.