Skip to content

Commit

Permalink
Call adapter.pre_model_hook + adapter.post_model_hook within tests (#…
Browse files Browse the repository at this point in the history
…10258)

* init push for issue 10198

* add changelog

* add unit tests based on michelle example

* add data_tests, and post_hook unit tests

* pull creating macro_func out of try call

* revert last commit

* pull macro_func definition back out of try

* update code formatting
  • Loading branch information
McKnight-42 authored Jun 14, 2024
1 parent 100352d commit 27b2f96
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 3 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240606-112334.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: add pre_model and post_model hook calls to data and unit tests to be able to provide extra config options
time: 2024-06-06T11:23:34.758675-05:00
custom:
Author: McKnight-42
Issue: "10198"
16 changes: 13 additions & 3 deletions core/dbt/task/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def before_execute(self):
def execute_data_test(self, data_test: TestNode, manifest: Manifest) -> TestResultData:
context = generate_runtime_model_context(data_test, self.config, manifest)

hook_ctx = self.adapter.pre_model_hook(context)

materialization_macro = manifest.find_materialization_macro_by_name(
self.config.project_name, data_test.get_materialization(), self.adapter.type()
)
Expand All @@ -142,8 +144,12 @@ def execute_data_test(self, data_test: TestNode, manifest: Manifest) -> TestResu

# generate materialization macro
macro_func = MacroGenerator(materialization_macro, context)
# execute materialization macro
macro_func()
try:
# execute materialization macro
macro_func()
finally:
self.adapter.post_model_hook(context, hook_ctx)

# load results from context
# could eventually be returned directly by materialization
result = context["load_result"]("main")
Expand Down Expand Up @@ -198,6 +204,8 @@ def execute_unit_test(
# materialization, not compile the node.compiled_code
context = generate_runtime_model_context(unit_test_node, self.config, unit_test_manifest)

hook_ctx = self.adapter.pre_model_hook(context)

materialization_macro = unit_test_manifest.find_materialization_macro_by_name(
self.config.project_name, unit_test_node.get_materialization(), self.adapter.type()
)
Expand All @@ -215,14 +223,16 @@ def execute_unit_test(

# generate materialization macro
macro_func = MacroGenerator(materialization_macro, context)
# execute materialization macro
try:
# execute materialization macro
macro_func()
except DbtBaseException as e:
raise DbtRuntimeError(
f"An error occurred during execution of unit test '{unit_test_def.name}'. "
f"There may be an error in the unit test definition: check the data types.\n {e}"
)
finally:
self.adapter.post_model_hook(context, hook_ctx)

# load results from context
# could eventually be returned directly by materialization
Expand Down
111 changes: 111 additions & 0 deletions tests/functional/data_tests/test_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from unittest import mock

import pytest

from dbt.tests.util import run_dbt, run_dbt_and_capture
from dbt_common.exceptions import CompilationError

orders_csv = """order_id,order_date,customer_id
1,2024-06-01,1001
2,2024-06-02,1002
3,2024-06-03,1003
4,2024-06-04,1004
"""


orders_model_sql = """
with source as (
select
order_id,
order_date,
customer_id
from {{ ref('seed_orders') }}
),
final as (
select
order_id,
order_date,
customer_id
from source
)
select * from final
"""


orders_test_sql = """
select *
from {{ ref('orders') }}
where order_id is null
"""


class BaseSingularTestHooks:
@pytest.fixture(scope="class")
def seeds(self):
return {"seed_orders.csv": orders_csv}

@pytest.fixture(scope="class")
def models(self):
return {"orders.sql": orders_model_sql}

@pytest.fixture(scope="class")
def tests(self):
return {"orders_test.sql": orders_test_sql}


class TestSingularTestPreHook(BaseSingularTestHooks):
def test_data_test_runs_adapter_pre_hook_pass(self, project):
results = run_dbt(["seed"])
assert len(results) == 1

results = run_dbt(["run"])
assert len(results) == 1

mock_pre_model_hook = mock.Mock()
with mock.patch.object(type(project.adapter), "pre_model_hook", mock_pre_model_hook):
results = run_dbt(["test"], expect_pass=True)
assert len(results) == 1
mock_pre_model_hook.assert_called_once()

def test_data_test_runs_adapter_pre_hook_fails(self, project):
results = run_dbt(["seed"])
assert len(results) == 1

results = run_dbt(["run"])
assert len(results) == 1

mock_pre_model_hook = mock.Mock()
mock_pre_model_hook.side_effect = CompilationError("exception from adapter.pre_model_hook")
with mock.patch.object(type(project.adapter), "pre_model_hook", mock_pre_model_hook):
(_, log_output) = run_dbt_and_capture(["test"], expect_pass=False)
assert "exception from adapter.pre_model_hook" in log_output


class TestSingularTestPostHook(BaseSingularTestHooks):
def test_data_test_runs_adapter_post_hook_pass(self, project):
results = run_dbt(["seed"])
assert len(results) == 1

results = run_dbt(["run"])
assert len(results) == 1

mock_post_model_hook = mock.Mock()
with mock.patch.object(type(project.adapter), "post_model_hook", mock_post_model_hook):
results = run_dbt(["test"], expect_pass=True)
assert len(results) == 1
mock_post_model_hook.assert_called_once()

def test_data_test_runs_adapter_post_hook_fails(self, project):
results = run_dbt(["seed"])
assert len(results) == 1

results = run_dbt(["run"])
assert len(results) == 1

mock_post_model_hook = mock.Mock()
mock_post_model_hook.side_effect = CompilationError(
"exception from adapter.post_model_hook"
)
with mock.patch.object(type(project.adapter), "post_model_hook", mock_post_model_hook):
(_, log_output) = run_dbt_and_capture(["test"], expect_pass=False)
assert "exception from adapter.post_model_hook" in log_output
17 changes: 17 additions & 0 deletions tests/functional/unit_testing/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,23 @@
tags: test_this
"""

test_my_model_pass_yml = """
unit_tests:
- name: test_my_model
model: my_model
given:
- input: ref('my_model_a')
rows:
- {id: 1, a: 1}
- input: ref('my_model_b')
rows:
- {id: 1, b: 2}
- {id: 2, b: 2}
expect:
rows:
- {c: 3}
"""


test_my_model_simple_fixture_yml = """
unit_tests:
Expand Down
75 changes: 75 additions & 0 deletions tests/functional/unit_testing/test_ut_adapter_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from unittest import mock

import pytest

from dbt.tests.util import run_dbt, run_dbt_and_capture
from dbt_common.exceptions import CompilationError
from tests.functional.unit_testing.fixtures import (
my_model_a_sql,
my_model_b_sql,
my_model_sql,
test_my_model_pass_yml,
)


class BaseUnitTestAdapterHook:
@pytest.fixture(scope="class")
def models(self):
return {
"my_model.sql": my_model_sql,
"my_model_a.sql": my_model_a_sql,
"my_model_b.sql": my_model_b_sql,
"test_my_model.yml": test_my_model_pass_yml,
}


class TestUnitTestAdapterPreHook(BaseUnitTestAdapterHook):
def test_unit_test_runs_adapter_pre_hook_passes(self, project):
results = run_dbt(["run"])
assert len(results) == 3

mock_pre_model_hook = mock.Mock()
with mock.patch.object(type(project.adapter), "pre_model_hook", mock_pre_model_hook):
results = run_dbt(["test", "--select", "test_name:test_my_model"], expect_pass=True)

assert len(results) == 1
mock_pre_model_hook.assert_called_once()

def test_unit_test_runs_adapter_pre_hook_fails(self, project):
results = run_dbt(["run"])
assert len(results) == 3

mock_pre_model_hook = mock.Mock()
mock_pre_model_hook.side_effect = CompilationError("exception from adapter.pre_model_hook")
with mock.patch.object(type(project.adapter), "pre_model_hook", mock_pre_model_hook):
(_, log_output) = run_dbt_and_capture(
["test", "--select", "test_name:test_my_model"], expect_pass=False
)
assert "exception from adapter.pre_model_hook" in log_output


class TestUnitTestAdapterPostHook(BaseUnitTestAdapterHook):
def test_unit_test_runs_adapter_post_hook_pass(self, project):
results = run_dbt(["run"])
assert len(results) == 3

mock_post_model_hook = mock.Mock()
with mock.patch.object(type(project.adapter), "post_model_hook", mock_post_model_hook):
results = run_dbt(["test", "--select", "test_name:test_my_model"], expect_pass=True)

assert len(results) == 1
mock_post_model_hook.assert_called_once()

def test_unit_test_runs_adapter_post_hook_fails(self, project):
results = run_dbt(["run"])
assert len(results) == 3

mock_post_model_hook = mock.Mock()
mock_post_model_hook.side_effect = CompilationError(
"exception from adapter.post_model_hook"
)
with mock.patch.object(type(project.adapter), "post_model_hook", mock_post_model_hook):
(_, log_output) = run_dbt_and_capture(
["test", "--select", "test_name:test_my_model"], expect_pass=False
)
assert "exception from adapter.post_model_hook" in log_output

0 comments on commit 27b2f96

Please sign in to comment.