diff --git a/tests/unit/task/test_run.py b/tests/unit/task/test_run.py index 24788cf2a6b..b33e6f57ffe 100644 --- a/tests/unit/task/test_run.py +++ b/tests/unit/task/test_run.py @@ -236,25 +236,39 @@ class Relation: assert model_runner._is_incremental(model) == expectation @pytest.mark.parametrize( - "has_relation,concurrent_batches,has_this,expectation", + "adapter_microbatch_concurrency,has_relation,concurrent_batches,has_this,expectation", [ - (True, None, False, True), - (True, None, True, False), - (True, True, False, True), - (True, True, True, True), - (True, False, False, False), - (True, False, True, False), - (False, None, False, False), - (False, None, True, False), - (False, True, False, False), - (False, True, True, False), - (False, False, False, False), - (False, False, True, False), + (True, True, None, False, True), + (True, True, None, True, False), + (True, True, True, False, True), + (True, True, True, True, True), + (True, True, False, False, False), + (True, True, False, True, False), + (True, False, None, False, False), + (True, False, None, True, False), + (True, False, True, False, False), + (True, False, True, True, False), + (True, False, False, False, False), + (True, False, False, True, False), + (False, True, None, False, False), + (False, True, None, True, False), + (False, True, True, False, False), + (False, True, True, True, False), + (False, True, False, False, False), + (False, True, False, True, False), + (False, False, None, False, False), + (False, False, None, True, False), + (False, False, True, False, False), + (False, False, True, True, False), + (False, False, False, False, False), + (False, False, False, True, False), ], ) def test__should_run_in_parallel( self, + mocker: MockerFixture, model_runner: MicrobatchModelRunner, + adapter_microbatch_concurrency: bool, has_relation: bool, concurrent_batches: Optional[bool], has_this: bool, @@ -262,6 +276,8 @@ def test__should_run_in_parallel( ) -> None: model_runner.node._has_this = has_this model_runner.node.config = ModelConfig(concurrent_batches=concurrent_batches) + mocked_supports = mocker.patch.object(model_runner.adapter, "supports") + mocked_supports.return_value = adapter_microbatch_concurrency # Assert result of _should_run_in_parallel assert model_runner._should_run_in_parallel(has_relation) == expectation