Skip to content

Commit

Permalink
add invocation mode discovery if none selected
Browse files Browse the repository at this point in the history
  • Loading branch information
jbandoro committed Feb 17, 2024
1 parent 7a3b182 commit fd92032
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 32 deletions.
9 changes: 3 additions & 6 deletions cosmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,7 @@ class ExecutionConfig:
Contains configuration about how to execute dbt.
:param execution_mode: The execution mode for dbt. Defaults to local
:param invocation_mode: The invocation mode for the dbt command. This is only configurable for ExecutionMode.LOCAL or ExecutionMode.VIRTUALENV
execution modes.
:param invocation_mode: The invocation mode for the dbt command. This is only configurable for ExecutionMode.LOCAL.
:param test_indirect_selection: The mode to configure the test behavior when performing indirect selection.
:param dbt_executable_path: The path to the dbt executable for runtime execution. Defaults to dbt if available on the path.
:param dbt_project_path Configures the DBT project location accessible at runtime for dag execution. This is the project path in a docker container for ExecutionMode.DOCKER or ExecutionMode.KUBERNETES. Mutually Exclusive with ProjectConfig.dbt_project_path
Expand All @@ -313,8 +312,6 @@ class ExecutionConfig:
project_path: Path | None = field(init=False)

def __post_init__(self, dbt_project_path: str | Path | None) -> None:
if self.invocation_mode and self.execution_mode not in {ExecutionMode.LOCAL, ExecutionMode.VIRTUALENV}:
raise CosmosValueError(
"ExecutionConfig.invocation_mode is only configurable for ExecutionMode.LOCAL or ExecutionMode.VIRTUALENV modes."
)
if self.invocation_mode and self.execution_mode != ExecutionMode.LOCAL:
raise CosmosValueError("ExecutionConfig.invocation_mode is only configurable for ExecutionMode.LOCAL.")
self.project_path = Path(dbt_project_path) if dbt_project_path else None
57 changes: 43 additions & 14 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class DbtLocalBaseOperator(AbstractDbtBaseOperator):
def __init__(
self,
profile_config: ProfileConfig,
invocation_mode: InvocationMode = InvocationMode.SUBPROCESS,
invocation_mode: InvocationMode | None = None,
install_deps: bool = False,
callback: Callable[[str], None] | None = None,
should_store_compiled_sql: bool = True,
Expand All @@ -131,13 +131,9 @@ def __init__(
self.invocation_mode = invocation_mode
self.invoke_dbt: Callable[..., FullOutputSubprocessResult | dbtRunnerResult]
self.handle_exception: Callable[..., None]
if self.invocation_mode == InvocationMode.SUBPROCESS:
self.invoke_dbt = self.run_subprocess
self.handle_exception = self.handle_exception_subprocess
elif self.invocation_mode == InvocationMode.DBT_RUNNER:
self.invoke_dbt = self.run_dbt_runner
self.handle_exception = self.handle_exception_dbt_runner
self._dbt_runner: dbtRunner | None = None
self._dbt_runner: dbtRunner | None = None
if self.invocation_mode:
self._set_invocation_methods()
kwargs.pop("full_refresh", None) # usage of this param should be implemented in child classes
super().__init__(**kwargs)

Expand All @@ -146,6 +142,32 @@ def subprocess_hook(self) -> FullOutputSubprocessHook:
"""Returns hook for running the bash command."""
return FullOutputSubprocessHook()

def _set_invocation_methods(self) -> None:
"""Checks if the invocation mode is provided, then sets the associated run and exception handling methods.
If the invocation mode is not set, will try to import dbtRunner and fall back to subprocess.
"""
if self.invocation_mode == InvocationMode.SUBPROCESS:
self.invoke_dbt = self.run_subprocess
self.handle_exception = self.handle_exception_subprocess
elif self.invocation_mode == InvocationMode.DBT_RUNNER:
self.invoke_dbt = self.run_dbt_runner
self.handle_exception = self.handle_exception_dbt_runner

def _discover_invocation_mode(self) -> None:
"""Discovers the invocation mode based on the availability of dbtRunner for import. If dbtRunner is available, it will
be used since it is faster than subprocess. If dbtRunner is not available, it will fall back to subprocess.
This method is called at runtime to work in the environment where the operator is running.
"""
try:
from dbt.cli.main import dbtRunner
except ImportError:
self.invocation_mode = InvocationMode.SUBPROCESS
logger.info("Could not import dbtRunner. Falling back to subprocess for invoking dbt.")
else:
self.invocation_mode = InvocationMode.DBT_RUNNER
logger.info("dbtRunner is available. Using dbtRunner for invoking dbt.")
self._set_invocation_methods()

def handle_exception_subprocess(self, result: FullOutputSubprocessResult) -> None:
if self.skip_exit_code is not None and result.exit_code == self.skip_exit_code:
raise AirflowSkipException(f"dbt command returned exit code {self.skip_exit_code}. Skipping.")
Expand Down Expand Up @@ -249,6 +271,9 @@ def run_command(
"""
Copies the dbt project to a temporary directory and runs the command.
"""
if not self.invocation_mode:
self._discover_invocation_mode()

with tempfile.TemporaryDirectory() as tmp_project_dir:
logger.info(
"Cloning project to writable temp directory %s from %s",
Expand Down Expand Up @@ -480,12 +505,6 @@ def __init__(
self.on_warning_callback = on_warning_callback
self.extract_issues: Callable[..., tuple[list[str], list[str]]]
self.parse_number_of_warnings: Callable[..., int]
if self.invocation_mode == InvocationMode.SUBPROCESS:
self.extract_issues = lambda result: extract_log_issues(result.full_output)
self.parse_number_of_warnings = parse_number_of_warnings_subprocess
elif self.invocation_mode == InvocationMode.DBT_RUNNER:
self.extract_issues = extract_dbt_runner_issues
self.parse_number_of_warnings = parse_number_of_warnings_dbt_runner

def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult, context: Context) -> None:
"""
Expand All @@ -503,8 +522,18 @@ def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult,

self.on_warning_callback and self.on_warning_callback(warning_context)

def _set_test_result_parsing_methods(self) -> None:
"""Sets the extract_issues and parse_number_of_warnings methods based on the invocation mode."""
if self.invocation_mode == InvocationMode.SUBPROCESS:
self.extract_issues = lambda result: extract_log_issues(result.full_output)
self.parse_number_of_warnings = parse_number_of_warnings_subprocess
elif self.invocation_mode == InvocationMode.DBT_RUNNER:
self.extract_issues = extract_dbt_runner_issues
self.parse_number_of_warnings = parse_number_of_warnings_dbt_runner

def execute(self, context: Context) -> None:
result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags())
self._set_test_result_parsing_methods()
number_of_warnings = self.parse_number_of_warnings(result) # type: ignore
if self.on_warning_callback and number_of_warnings > 0:
self._handle_warnings(result, context)
Expand Down
2 changes: 1 addition & 1 deletion docs/configuration/execution-config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ It does this by exposing a ``cosmos.config.ExecutionConfig`` class that you can
The ``ExecutionConfig`` class takes the following arguments:

- ``execution_mode``: The way dbt is run when executing within airflow. For more information, see the `execution modes <../getting_started/execution-modes.html>`_ page.
- ``invocation_mode`` (new in v1.4): The way dbt is invoked within the execution mode. This is only configurable for ``ExecutionMode.LOCAL`` and ``ExecutionMode.VIRTUALENV``. For more information, see `invocation modes <../getting_started/execution-modes.html#invocation-modes>`_.
- ``invocation_mode`` (new in v1.4): The way dbt is invoked within the execution mode. This is only configurable for ``ExecutionMode.LOCAL``. For more information, see `invocation modes <../getting_started/execution-modes.html#invocation-modes>`_.
- ``test_indirect_selection``: The mode to configure the test behavior when performing indirect selection.
- ``dbt_executable_path``: The path to the dbt executable for dag generation. Defaults to dbt if available on the path.
- ``dbt_project_path``: Configures the dbt project location accessible at runtime for dag execution. This is the project path in a docker container for ``ExecutionMode.DOCKER`` or ``ExecutionMode.KUBERNETES``. Mutually exclusive with ``ProjectConfig.dbt_project_path``.
8 changes: 5 additions & 3 deletions docs/getting_started/execution-modes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,12 @@ Invocation Modes
================
.. versionadded:: 1.4

For ``ExecutionMode.LOCAL`` and ``ExecutionMode.VIRTUALENV`` execution modes, Cosmos supports two invocation modes for running dbt:
For ``ExecutionMode.LOCAL`` execution mode, Cosmos supports two invocation modes for running dbt:

1. ``InvocationMode.SUBPROCESS``: This is currently the default mode and does not need to be specified. In this mode, Cosmos runs dbt cli commands using the Python ``subprocess`` module and parses the output to capture logs and to raise exceptions.
1. ``InvocationMode.SUBPROCESS``: In this mode, Cosmos runs dbt cli commands using the Python ``subprocess`` module and parses the output to capture logs and to raise exceptions.

2. ``InvocationMode.DBT_RUNNER``: In this mode, Cosmos uses the ``dbtRunner`` available for `dbt programmatic invocations <https://docs.getdbt.com/reference/programmatic-invocations>`__ to run dbt commands. \
In order to use this mode, dbt must be installed in the same environment, either local or virtualenv for the worker. This mode does not have the overhead of spawning new subprocesses or parsing the output of dbt commands and can be expected to be faster than ``InvocationMode.SUBPROCESS``. \
In order to use this mode, dbt must be installed in the same environment, either local or virtualenv for the worker. This mode does not have the overhead of spawning new subprocesses or parsing the output of dbt commands and is faster than ``InvocationMode.SUBPROCESS``. \
This mode requires dbt version 1.5.0 or higher.

The invocation mode can be set in the ``ExecutionConfig`` as shown below:
Expand All @@ -208,3 +208,5 @@ The invocation mode can be set in the ``ExecutionConfig`` as shown below:
invocation_mode=InvocationMode.DBT_RUNNER,
),
)
If the invocation mode is not set, Cosmos will attempt to use ``InvocationMode.DBT_RUNNER`` if dbt is installed in the same environment as the worker, otherwise it will default to ``InvocationMode.SUBPROCESS``.
64 changes: 57 additions & 7 deletions tests/operators/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,44 @@ def test_dbt_base_operator_add_user_supplied_global_flags() -> None:
(InvocationMode.DBT_RUNNER, "run_dbt_runner", "handle_exception_dbt_runner"),
],
)
def test_dbt_base_operator_invocation_methods_set(invocation_mode, invoke_dbt_method, handle_exception_method):
def test_dbt_base_operator_set_invocation_methods(invocation_mode, invoke_dbt_method, handle_exception_method):
"""Tests that the right methods are mapped to DbtLocalBaseOperator.invoke_dbt and
DbtLocalBaseOperator.handle_exception based on the invocation mode passed.
DbtLocalBaseOperator.handle_exception when a known invocation mode passed.
"""
dbt_base_operator = ConcreteDbtLocalBaseOperator(
profile_config=profile_config, task_id="my-task", project_dir="my/dir", invocation_mode=invocation_mode
)
dbt_base_operator._set_invocation_methods()
assert dbt_base_operator.invoke_dbt.__name__ == invoke_dbt_method
assert dbt_base_operator.handle_exception.__name__ == handle_exception_method


@pytest.mark.parametrize(
"can_import_dbt, invoke_dbt_method, handle_exception_method",
[
(False, "run_subprocess", "handle_exception_subprocess"),
(True, "run_dbt_runner", "handle_exception_dbt_runner"),
],
)
def test_dbt_base_operator_discover_invocation_mode(can_import_dbt, invoke_dbt_method, handle_exception_method):
"""Tests that the right methods are mapped to DbtLocalBaseOperator.invoke_dbt and
DbtLocalBaseOperator.handle_exception if dbt can be imported or not.
"""
dbt_base_operator = ConcreteDbtLocalBaseOperator(
profile_config=profile_config, task_id="my-task", project_dir="my/dir"
)
with patch.dict(sys.modules, {"dbt.cli.main": MagicMock()} if can_import_dbt else {"dbt.cli.main": None}):
dbt_base_operator = ConcreteDbtLocalBaseOperator(
profile_config=profile_config, task_id="my-task", project_dir="my/dir"
)
dbt_base_operator._discover_invocation_mode()
assert dbt_base_operator.invocation_mode == (
InvocationMode.DBT_RUNNER if can_import_dbt else InvocationMode.SUBPROCESS
)
assert dbt_base_operator.invoke_dbt.__name__ == invoke_dbt_method
assert dbt_base_operator.handle_exception.__name__ == handle_exception_method


@pytest.mark.parametrize(
"indirect_selection_type",
[None, "cautious", "buildable", "empty"],
Expand Down Expand Up @@ -245,11 +272,14 @@ def test_dbt_base_operator_run_dbt_runner_is_cached(mock_chdir):
"No exception raised",
],
)
def test_dbt_base_operator_exception_handling(skip_exception, exception_code_returned, expected_exception) -> None:
def test_dbt_base_operator_exception_handling_subprocess(
skip_exception, exception_code_returned, expected_exception
) -> None:
dbt_base_operator = ConcreteDbtLocalBaseOperator(
profile_config=profile_config,
task_id="my-task",
project_dir="my/dir",
invocation_mode=InvocationMode.SUBPROCESS,
)
if expected_exception:
with pytest.raises(expected_exception):
Expand Down Expand Up @@ -304,14 +334,15 @@ def test_dbt_base_operator_get_env(p_context_to_airflow_vars: MagicMock) -> None


@patch("cosmos.operators.local.extract_log_issues")
def test_dbt_test_local_operator_invocation_mode_functions(mock_extract_log_issues):
def test_dbt_test_local_operator_invocation_mode_methods(mock_extract_log_issues):
# test subprocess invocation mode
operator = DbtTestLocalOperator(
profile_config=profile_config,
invocation_mode=InvocationMode.SUBPROCESS,
task_id="my-task",
project_dir="my/dir",
)
operator._set_test_result_parsing_methods()
assert operator.parse_number_of_warnings == parse_number_of_warnings_subprocess
result = MagicMock(full_output="some output")
operator.extract_issues(result)
Expand All @@ -324,6 +355,7 @@ def test_dbt_test_local_operator_invocation_mode_functions(mock_extract_log_issu
task_id="my-task",
project_dir="my/dir",
)
operator._set_test_result_parsing_methods()
assert operator.extract_issues == extract_dbt_runner_issues
assert operator.parse_number_of_warnings == parse_number_of_warnings_dbt_runner

Expand Down Expand Up @@ -519,7 +551,13 @@ def test_store_compiled_sql() -> None:
)
@patch("cosmos.operators.local.DbtLocalBaseOperator.build_and_run_cmd")
def test_operator_execute_with_flags(mock_build_and_run_cmd, operator_class, kwargs, expected_call_kwargs):
task = operator_class(profile_config=profile_config, task_id="my-task", project_dir="my/dir", **kwargs)
task = operator_class(
profile_config=profile_config,
task_id="my-task",
project_dir="my/dir",
invocation_mode=InvocationMode.DBT_RUNNER,
**kwargs,
)
task.execute(context={})
mock_build_and_run_cmd.assert_called_once_with(**expected_call_kwargs)

Expand Down Expand Up @@ -548,6 +586,7 @@ def test_operator_execute_without_flags(mock_build_and_run_cmd, operator_class):
profile_config=profile_config,
task_id="my-task",
project_dir="my/dir",
invocation_mode=InvocationMode.DBT_RUNNER,
**operator_class_kwargs.get(operator_class, {}),
)
task.execute(context={})
Expand Down Expand Up @@ -616,8 +655,15 @@ def test_dbt_docs_gcs_local_operator():
@patch("cosmos.operators.local.DbtLocalBaseOperator.handle_exception_subprocess")
@patch("cosmos.config.ProfileConfig.ensure_profile")
@patch("cosmos.operators.local.DbtLocalBaseOperator.run_subprocess")
@patch("cosmos.operators.local.DbtLocalBaseOperator.run_dbt_runner")
@pytest.mark.parametrize("invocation_mode", [InvocationMode.SUBPROCESS, InvocationMode.DBT_RUNNER])
def test_operator_execute_deps_parameters(
mock_build_and_run_cmd, mock_ensure_profile, mock_exception_handling, mock_store_compiled_sql
mock_dbt_runner,
mock_subprocess,
mock_ensure_profile,
mock_exception_handling,
mock_store_compiled_sql,
invocation_mode,
):
expected_call_kwargs = [
"/usr/local/bin/dbt",
Expand All @@ -636,10 +682,14 @@ def test_operator_execute_deps_parameters(
install_deps=True,
emit_datasets=False,
dbt_executable_path="/usr/local/bin/dbt",
invocation_mode=invocation_mode,
)
mock_ensure_profile.return_value.__enter__.return_value = (Path("/path/to/profile"), {"ENV_VAR": "value"})
task.execute(context={"task_instance": MagicMock()})
assert mock_build_and_run_cmd.call_args_list[0].kwargs["command"] == expected_call_kwargs
if invocation_mode == InvocationMode.SUBPROCESS:
assert mock_subprocess.call_args_list[0].kwargs["command"] == expected_call_kwargs
elif invocation_mode == InvocationMode.DBT_RUNNER:
mock_dbt_runner.all_args_list[0].kwargs["command"] == expected_call_kwargs


def test_dbt_docs_local_operator_with_static_flag():
Expand Down
2 changes: 2 additions & 0 deletions tests/operators/test_virtualenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from cosmos.config import ProfileConfig

from cosmos.profiles import PostgresUserPasswordProfileMapping
from cosmos.constants import InvocationMode

profile_config = ProfileConfig(
profile_name="default",
Expand Down Expand Up @@ -53,6 +54,7 @@ def test_run_command(
py_system_site_packages=False,
py_requirements=["dbt-postgres==1.6.0b1"],
emit_datasets=False,
invocation_mode=InvocationMode.SUBPROCESS,
)
assert venv_operator._venv_tmp_dir is None # Otherwise we are creating empty directories during DAG parsing time
# and not deleting them
Expand Down
2 changes: 1 addition & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def test_render_config_env_vars_deprecated():
"execution_mode, expectation",
[
(ExecutionMode.LOCAL, does_not_raise()),
(ExecutionMode.VIRTUALENV, does_not_raise()),
(ExecutionMode.VIRTUALENV, pytest.raises(CosmosValueError)),
(ExecutionMode.KUBERNETES, pytest.raises(CosmosValueError)),
(ExecutionMode.DOCKER, pytest.raises(CosmosValueError)),
(ExecutionMode.AZURE_CONTAINER_INSTANCE, pytest.raises(CosmosValueError)),
Expand Down

0 comments on commit fd92032

Please sign in to comment.