diff --git a/cosmos/config.py b/cosmos/config.py index 29fd131f8..439b66862 100644 --- a/cosmos/config.py +++ b/cosmos/config.py @@ -297,8 +297,7 @@ class ExecutionConfig: Contains configuration about how to execute dbt. :param execution_mode: The execution mode for dbt. Defaults to local - :param invocation_mode: The invocation mode for the dbt command. This is only configurable for ExecutionMode.LOCAL or ExecutionMode.VIRTUALENV - execution modes. + :param invocation_mode: The invocation mode for the dbt command. This is only configurable for ExecutionMode.LOCAL. :param test_indirect_selection: The mode to configure the test behavior when performing indirect selection. :param dbt_executable_path: The path to the dbt executable for runtime execution. Defaults to dbt if available on the path. :param dbt_project_path Configures the DBT project location accessible at runtime for dag execution. This is the project path in a docker container for ExecutionMode.DOCKER or ExecutionMode.KUBERNETES. Mutually Exclusive with ProjectConfig.dbt_project_path @@ -313,8 +312,6 @@ class ExecutionConfig: project_path: Path | None = field(init=False) def __post_init__(self, dbt_project_path: str | Path | None) -> None: - if self.invocation_mode and self.execution_mode not in {ExecutionMode.LOCAL, ExecutionMode.VIRTUALENV}: - raise CosmosValueError( - "ExecutionConfig.invocation_mode is only configurable for ExecutionMode.LOCAL or ExecutionMode.VIRTUALENV modes." 
- ) + if self.invocation_mode and self.execution_mode != ExecutionMode.LOCAL: + raise CosmosValueError("ExecutionConfig.invocation_mode is only configurable for ExecutionMode.LOCAL.") self.project_path = Path(dbt_project_path) if dbt_project_path else None diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 7ca42c7a2..bf5c51ab0 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -116,7 +116,7 @@ class DbtLocalBaseOperator(AbstractDbtBaseOperator): def __init__( self, profile_config: ProfileConfig, - invocation_mode: InvocationMode = InvocationMode.SUBPROCESS, + invocation_mode: InvocationMode | None = None, install_deps: bool = False, callback: Callable[[str], None] | None = None, should_store_compiled_sql: bool = True, @@ -131,13 +131,9 @@ def __init__( self.invocation_mode = invocation_mode self.invoke_dbt: Callable[..., FullOutputSubprocessResult | dbtRunnerResult] self.handle_exception: Callable[..., None] - if self.invocation_mode == InvocationMode.SUBPROCESS: - self.invoke_dbt = self.run_subprocess - self.handle_exception = self.handle_exception_subprocess - elif self.invocation_mode == InvocationMode.DBT_RUNNER: - self.invoke_dbt = self.run_dbt_runner - self.handle_exception = self.handle_exception_dbt_runner - self._dbt_runner: dbtRunner | None = None + self._dbt_runner: dbtRunner | None = None + if self.invocation_mode: + self._set_invocation_methods() kwargs.pop("full_refresh", None) # usage of this param should be implemented in child classes super().__init__(**kwargs) @@ -146,6 +142,32 @@ def subprocess_hook(self) -> FullOutputSubprocessHook: """Returns hook for running the bash command.""" return FullOutputSubprocessHook() + def _set_invocation_methods(self) -> None: + """Sets the run and exception handling methods associated with the current invocation mode. + Does nothing when the invocation mode is unset or not a known mode; an unset mode is resolved by _discover_invocation_mode. 
+ """ + if self.invocation_mode == InvocationMode.SUBPROCESS: + self.invoke_dbt = self.run_subprocess + self.handle_exception = self.handle_exception_subprocess + elif self.invocation_mode == InvocationMode.DBT_RUNNER: + self.invoke_dbt = self.run_dbt_runner + self.handle_exception = self.handle_exception_dbt_runner + + def _discover_invocation_mode(self) -> None: + """Discovers the invocation mode based on the availability of dbtRunner for import. If dbtRunner is available, it will + be used since it is faster than subprocess. If dbtRunner is not available, it will fall back to subprocess. + This method is called at runtime to work in the environment where the operator is running. + """ + try: + from dbt.cli.main import dbtRunner + except ImportError: + self.invocation_mode = InvocationMode.SUBPROCESS + logger.info("Could not import dbtRunner. Falling back to subprocess for invoking dbt.") + else: + self.invocation_mode = InvocationMode.DBT_RUNNER + logger.info("dbtRunner is available. Using dbtRunner for invoking dbt.") + self._set_invocation_methods() + def handle_exception_subprocess(self, result: FullOutputSubprocessResult) -> None: if self.skip_exit_code is not None and result.exit_code == self.skip_exit_code: raise AirflowSkipException(f"dbt command returned exit code {self.skip_exit_code}. Skipping.") @@ -249,6 +271,9 @@ def run_command( """ Copies the dbt project to a temporary directory and runs the command. 
""" + if not self.invocation_mode: + self._discover_invocation_mode() + with tempfile.TemporaryDirectory() as tmp_project_dir: logger.info( "Cloning project to writable temp directory %s from %s", @@ -480,12 +505,6 @@ def __init__( self.on_warning_callback = on_warning_callback self.extract_issues: Callable[..., tuple[list[str], list[str]]] self.parse_number_of_warnings: Callable[..., int] - if self.invocation_mode == InvocationMode.SUBPROCESS: - self.extract_issues = lambda result: extract_log_issues(result.full_output) - self.parse_number_of_warnings = parse_number_of_warnings_subprocess - elif self.invocation_mode == InvocationMode.DBT_RUNNER: - self.extract_issues = extract_dbt_runner_issues - self.parse_number_of_warnings = parse_number_of_warnings_dbt_runner def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult, context: Context) -> None: """ @@ -503,8 +522,18 @@ def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult, self.on_warning_callback and self.on_warning_callback(warning_context) + def _set_test_result_parsing_methods(self) -> None: + """Sets the extract_issues and parse_number_of_warnings methods based on the invocation mode.""" + if self.invocation_mode == InvocationMode.SUBPROCESS: + self.extract_issues = lambda result: extract_log_issues(result.full_output) + self.parse_number_of_warnings = parse_number_of_warnings_subprocess + elif self.invocation_mode == InvocationMode.DBT_RUNNER: + self.extract_issues = extract_dbt_runner_issues + self.parse_number_of_warnings = parse_number_of_warnings_dbt_runner + def execute(self, context: Context) -> None: result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags()) + self._set_test_result_parsing_methods() number_of_warnings = self.parse_number_of_warnings(result) # type: ignore if self.on_warning_callback and number_of_warnings > 0: self._handle_warnings(result, context) diff --git a/docs/configuration/execution-config.rst 
b/docs/configuration/execution-config.rst index aad40c19d..dd9758d55 100644 --- a/docs/configuration/execution-config.rst +++ b/docs/configuration/execution-config.rst @@ -7,7 +7,7 @@ It does this by exposing a ``cosmos.config.ExecutionConfig`` class that you can The ``ExecutionConfig`` class takes the following arguments: - ``execution_mode``: The way dbt is run when executing within airflow. For more information, see the `execution modes <../getting_started/execution-modes.html>`_ page. -- ``invocation_mode`` (new in v1.4): The way dbt is invoked within the execution mode. This is only configurable for ``ExecutionMode.LOCAL`` and ``ExecutionMode.VIRTUALENV``. For more information, see `invocation modes <../getting_started/execution-modes.html#invocation-modes>`_. +- ``invocation_mode`` (new in v1.4): The way dbt is invoked within the execution mode. This is only configurable for ``ExecutionMode.LOCAL``. For more information, see `invocation modes <../getting_started/execution-modes.html#invocation-modes>`_. - ``test_indirect_selection``: The mode to configure the test behavior when performing indirect selection. - ``dbt_executable_path``: The path to the dbt executable for dag generation. Defaults to dbt if available on the path. - ``dbt_project_path``: Configures the dbt project location accessible at runtime for dag execution. This is the project path in a docker container for ``ExecutionMode.DOCKER`` or ``ExecutionMode.KUBERNETES``. Mutually exclusive with ``ProjectConfig.dbt_project_path``. diff --git a/docs/getting_started/execution-modes.rst b/docs/getting_started/execution-modes.rst index f34a23c32..92c542a1d 100644 --- a/docs/getting_started/execution-modes.rst +++ b/docs/getting_started/execution-modes.rst @@ -187,12 +187,12 @@ Invocation Modes ================ .. 
versionadded:: 1.4 -For ``ExecutionMode.LOCAL`` and ``ExecutionMode.VIRTUALENV`` execution modes, Cosmos supports two invocation modes for running dbt: +For the ``ExecutionMode.LOCAL`` execution mode, Cosmos supports two invocation modes for running dbt: -1. ``InvocationMode.SUBPROCESS``: This is currently the default mode and does not need to be specified. In this mode, Cosmos runs dbt cli commands using the Python ``subprocess`` module and parses the output to capture logs and to raise exceptions. +1. ``InvocationMode.SUBPROCESS``: In this mode, Cosmos runs dbt cli commands using the Python ``subprocess`` module and parses the output to capture logs and to raise exceptions. 2. ``InvocationMode.DBT_RUNNER``: In this mode, Cosmos uses the ``dbtRunner`` available for `dbt programmatic invocations <https://docs.getdbt.com/reference/programmatic-invocations>`__ to run dbt commands. \ - In order to use this mode, dbt must be installed in the same environment, either local or virtualenv for the worker. This mode does not have the overhead of spawning new subprocesses or parsing the output of dbt commands and can be expected to be faster than ``InvocationMode.SUBPROCESS``. \ + In order to use this mode, dbt must be installed in the same environment, either local or virtualenv for the worker. This mode does not have the overhead of spawning new subprocesses or parsing the output of dbt commands and is faster than ``InvocationMode.SUBPROCESS``. \ This mode requires dbt version 1.5.0 or higher. The invocation mode can be set in the ``ExecutionConfig`` as shown below: @@ -208,3 +208,5 @@ The invocation mode can be set in the ``ExecutionConfig`` as shown below: invocation_mode=InvocationMode.DBT_RUNNER, ), ) + +If the invocation mode is not set, Cosmos checks at task runtime whether ``dbtRunner`` can be imported in the worker's environment: if it can, ``InvocationMode.DBT_RUNNER`` is used; otherwise, Cosmos falls back to ``InvocationMode.SUBPROCESS``. 
diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index db0310d67..b33510f4b 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -135,17 +135,44 @@ def test_dbt_base_operator_add_user_supplied_global_flags() -> None: (InvocationMode.DBT_RUNNER, "run_dbt_runner", "handle_exception_dbt_runner"), ], ) -def test_dbt_base_operator_invocation_methods_set(invocation_mode, invoke_dbt_method, handle_exception_method): +def test_dbt_base_operator_set_invocation_methods(invocation_mode, invoke_dbt_method, handle_exception_method): """Tests that the right methods are mapped to DbtLocalBaseOperator.invoke_dbt and - DbtLocalBaseOperator.handle_exception based on the invocation mode passed. + DbtLocalBaseOperator.handle_exception when a known invocation mode is passed. """ dbt_base_operator = ConcreteDbtLocalBaseOperator( profile_config=profile_config, task_id="my-task", project_dir="my/dir", invocation_mode=invocation_mode ) + dbt_base_operator._set_invocation_methods() assert dbt_base_operator.invoke_dbt.__name__ == invoke_dbt_method assert dbt_base_operator.handle_exception.__name__ == handle_exception_method +@pytest.mark.parametrize( + "can_import_dbt, invoke_dbt_method, handle_exception_method", + [ + (False, "run_subprocess", "handle_exception_subprocess"), + (True, "run_dbt_runner", "handle_exception_dbt_runner"), + ], +) +def test_dbt_base_operator_discover_invocation_mode(can_import_dbt, invoke_dbt_method, handle_exception_method): + """Tests that the right methods are mapped to DbtLocalBaseOperator.invoke_dbt and + DbtLocalBaseOperator.handle_exception depending on whether dbt can be imported. 
+ """ + dbt_base_operator = ConcreteDbtLocalBaseOperator( + profile_config=profile_config, task_id="my-task", project_dir="my/dir" + ) + with patch.dict(sys.modules, {"dbt.cli.main": MagicMock()} if can_import_dbt else {"dbt.cli.main": None}): + dbt_base_operator = ConcreteDbtLocalBaseOperator( + profile_config=profile_config, task_id="my-task", project_dir="my/dir" + ) + dbt_base_operator._discover_invocation_mode() + assert dbt_base_operator.invocation_mode == ( + InvocationMode.DBT_RUNNER if can_import_dbt else InvocationMode.SUBPROCESS + ) + assert dbt_base_operator.invoke_dbt.__name__ == invoke_dbt_method + assert dbt_base_operator.handle_exception.__name__ == handle_exception_method + + @pytest.mark.parametrize( "indirect_selection_type", [None, "cautious", "buildable", "empty"], @@ -245,11 +272,14 @@ def test_dbt_base_operator_run_dbt_runner_is_cached(mock_chdir): "No exception raised", ], ) -def test_dbt_base_operator_exception_handling(skip_exception, exception_code_returned, expected_exception) -> None: +def test_dbt_base_operator_exception_handling_subprocess( + skip_exception, exception_code_returned, expected_exception +) -> None: dbt_base_operator = ConcreteDbtLocalBaseOperator( profile_config=profile_config, task_id="my-task", project_dir="my/dir", + invocation_mode=InvocationMode.SUBPROCESS, ) if expected_exception: with pytest.raises(expected_exception): @@ -304,7 +334,7 @@ def test_dbt_base_operator_get_env(p_context_to_airflow_vars: MagicMock) -> None @patch("cosmos.operators.local.extract_log_issues") -def test_dbt_test_local_operator_invocation_mode_functions(mock_extract_log_issues): +def test_dbt_test_local_operator_invocation_mode_methods(mock_extract_log_issues): # test subprocess invocation mode operator = DbtTestLocalOperator( profile_config=profile_config, @@ -312,6 +342,7 @@ def test_dbt_test_local_operator_invocation_mode_functions(mock_extract_log_issu task_id="my-task", project_dir="my/dir", ) + 
operator._set_test_result_parsing_methods() assert operator.parse_number_of_warnings == parse_number_of_warnings_subprocess result = MagicMock(full_output="some output") operator.extract_issues(result) @@ -324,6 +355,7 @@ def test_dbt_test_local_operator_invocation_mode_functions(mock_extract_log_issu task_id="my-task", project_dir="my/dir", ) + operator._set_test_result_parsing_methods() assert operator.extract_issues == extract_dbt_runner_issues assert operator.parse_number_of_warnings == parse_number_of_warnings_dbt_runner @@ -519,7 +551,13 @@ def test_store_compiled_sql() -> None: ) @patch("cosmos.operators.local.DbtLocalBaseOperator.build_and_run_cmd") def test_operator_execute_with_flags(mock_build_and_run_cmd, operator_class, kwargs, expected_call_kwargs): - task = operator_class(profile_config=profile_config, task_id="my-task", project_dir="my/dir", **kwargs) + task = operator_class( + profile_config=profile_config, + task_id="my-task", + project_dir="my/dir", + invocation_mode=InvocationMode.DBT_RUNNER, + **kwargs, + ) task.execute(context={}) mock_build_and_run_cmd.assert_called_once_with(**expected_call_kwargs) @@ -548,6 +586,7 @@ def test_operator_execute_without_flags(mock_build_and_run_cmd, operator_class): profile_config=profile_config, task_id="my-task", project_dir="my/dir", + invocation_mode=InvocationMode.DBT_RUNNER, **operator_class_kwargs.get(operator_class, {}), ) task.execute(context={}) @@ -616,8 +655,15 @@ def test_dbt_docs_gcs_local_operator(): @patch("cosmos.operators.local.DbtLocalBaseOperator.handle_exception_subprocess") @patch("cosmos.config.ProfileConfig.ensure_profile") @patch("cosmos.operators.local.DbtLocalBaseOperator.run_subprocess") +@patch("cosmos.operators.local.DbtLocalBaseOperator.run_dbt_runner") +@pytest.mark.parametrize("invocation_mode", [InvocationMode.SUBPROCESS, InvocationMode.DBT_RUNNER]) def test_operator_execute_deps_parameters( - mock_build_and_run_cmd, mock_ensure_profile, mock_exception_handling, 
mock_store_compiled_sql + mock_dbt_runner, + mock_subprocess, + mock_ensure_profile, + mock_exception_handling, + mock_store_compiled_sql, + invocation_mode, ): expected_call_kwargs = [ "/usr/local/bin/dbt", @@ -636,10 +682,14 @@ def test_operator_execute_deps_parameters( install_deps=True, emit_datasets=False, dbt_executable_path="/usr/local/bin/dbt", + invocation_mode=invocation_mode, ) mock_ensure_profile.return_value.__enter__.return_value = (Path("/path/to/profile"), {"ENV_VAR": "value"}) task.execute(context={"task_instance": MagicMock()}) - assert mock_build_and_run_cmd.call_args_list[0].kwargs["command"] == expected_call_kwargs + if invocation_mode == InvocationMode.SUBPROCESS: + assert mock_subprocess.call_args_list[0].kwargs["command"] == expected_call_kwargs + elif invocation_mode == InvocationMode.DBT_RUNNER: + assert mock_dbt_runner.call_args_list[0].kwargs["command"] == expected_call_kwargs def test_dbt_docs_local_operator_with_static_flag(): diff --git a/tests/operators/test_virtualenv.py b/tests/operators/test_virtualenv.py index 9d180b3e2..036f162de 100644 --- a/tests/operators/test_virtualenv.py +++ b/tests/operators/test_virtualenv.py @@ -7,6 +7,7 @@ from cosmos.config import ProfileConfig from cosmos.profiles import PostgresUserPasswordProfileMapping +from cosmos.constants import InvocationMode profile_config = ProfileConfig( profile_name="default", @@ -53,6 +54,7 @@ def test_run_command( py_system_site_packages=False, py_requirements=["dbt-postgres==1.6.0b1"], emit_datasets=False, + invocation_mode=InvocationMode.SUBPROCESS, ) assert venv_operator._venv_tmp_dir is None # Otherwise we are creating empty directories during DAG parsing time # and not deleting them diff --git a/tests/test_config.py b/tests/test_config.py index d7c456938..b93ad2627 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -203,7 +203,7 @@ def test_render_config_env_vars_deprecated(): "execution_mode, expectation", [ (ExecutionMode.LOCAL, does_not_raise()), - 
(ExecutionMode.VIRTUALENV, does_not_raise()), + (ExecutionMode.VIRTUALENV, pytest.raises(CosmosValueError)), (ExecutionMode.KUBERNETES, pytest.raises(CosmosValueError)), (ExecutionMode.DOCKER, pytest.raises(CosmosValueError)), (ExecutionMode.AZURE_CONTAINER_INSTANCE, pytest.raises(CosmosValueError)),