diff --git a/cosmos/profiles/__init__.py b/cosmos/profiles/__init__.py index dae6e2c04..e75b6c25e 100644 --- a/cosmos/profiles/__init__.py +++ b/cosmos/profiles/__init__.py @@ -20,6 +20,7 @@ from .trino.certificate import TrinoCertificateProfileMapping from .trino.jwt import TrinoJWTProfileMapping from .trino.ldap import TrinoLDAPProfileMapping +from .vertica.user_pass import VerticaUserPasswordProfileMapping profile_mappings: list[Type[BaseProfileMapping]] = [ AthenaAccessKeyProfileMapping, @@ -36,6 +37,7 @@ TrinoLDAPProfileMapping, TrinoCertificateProfileMapping, TrinoJWTProfileMapping, + VerticaUserPasswordProfileMapping, ] @@ -72,4 +74,5 @@ def get_automatic_profile_mapping( "TrinoLDAPProfileMapping", "TrinoCertificateProfileMapping", "TrinoJWTProfileMapping", + "VerticaUserPasswordProfileMapping", ] diff --git a/cosmos/profiles/vertica/__init__.py b/cosmos/profiles/vertica/__init__.py new file mode 100644 index 000000000..4a88f2edd --- /dev/null +++ b/cosmos/profiles/vertica/__init__.py @@ -0,0 +1,5 @@ +"Vertica Airflow connection -> dbt profile mappings" + +from .user_pass import VerticaUserPasswordProfileMapping + +__all__ = ["VerticaUserPasswordProfileMapping"] diff --git a/cosmos/profiles/vertica/user_pass.py b/cosmos/profiles/vertica/user_pass.py new file mode 100644 index 000000000..494185e05 --- /dev/null +++ b/cosmos/profiles/vertica/user_pass.py @@ -0,0 +1,76 @@ +"Maps Airflow Vertica connections using user + password authentication to dbt profiles." +from __future__ import annotations + +from typing import Any + +from ..base import BaseProfileMapping + + +class VerticaUserPasswordProfileMapping(BaseProfileMapping): + """ + Maps Airflow Vertica connections using user + password authentication to dbt profiles. + https://docs.getdbt.com/reference/warehouse-setups/vertica-setup + https://airflow.apache.org/docs/apache-airflow-providers-vertica/stable/connections/vertica.html + """ + + airflow_connection_type: str = "vertica" + dbt_profile_type: str = "vertica" + + required_fields = [ + "host", + "user", + "password", + "database", + "schema", + ] + secret_fields = [ + "password", + ] + airflow_param_mapping = { + "host": "host", + "user": "login", + "password": "password", + "port": "port", + "schema": "schema", + "database": "extra.database", + "autocommit": "extra.autocommit", + "backup_server_node": "extra.backup_server_node", + "binary_transfer": "extra.binary_transfer", + "connection_load_balance": "extra.connection_load_balance", + "connection_timeout": "extra.connection_timeout", + "disable_copy_local": "extra.disable_copy_local", + "kerberos_host_name": "extra.kerberos_host_name", + "kerberos_service_name": "extra.kerberos_service_name", + "log_level": "extra.log_level", + "log_path": "extra.log_path", + "oauth_access_token": "extra.oauth_access_token", + "request_complex_types": "extra.request_complex_types", + "session_label": "extra.session_label", + "ssl": "extra.ssl", + "unicode_error": "extra.unicode_error", + "use_prepared_statements": "extra.use_prepared_statements", + "workload": "extra.workload", + } + + @property + def profile(self) -> dict[str, Any | None]: + "Gets profile. The password is stored in an environment variable." + profile = { + "port": 5433, + **self.mapped_params, + **self.profile_args, + # password should always get set as env var + "password": self.get_env_var_format("password"), + } + + return self.filter_null(profile) + + @property + def mock_profile(self) -> dict[str, Any | None]: + "Gets mock profile. Defaults port to 5433." + parent_mock = super().mock_profile + + return { + "port": 5433, + **parent_mock, + } diff --git a/pyproject.toml b/pyproject.toml index 512af2bbe..ef1fba2f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ dbt-all = [ "dbt-redshift", "dbt-snowflake", "dbt-spark", + "dbt-vertica", ] dbt-athena = [ "dbt-athena-community", @@ -80,6 +81,9 @@ dbt-snowflake = [ dbt-spark = [ "dbt-spark", ] +dbt-vertica = [ + "dbt-vertica<=1.5.4", +] openlineage = [ "openlineage-integration-common", "openlineage-airflow", @@ -165,18 +169,18 @@ test = 'pytest -vv --durations=0 . -m "not integration" --ignore=tests/test_exam test-cov = """pytest -vv --cov=cosmos --cov-report=term-missing --cov-report=xml --durations=0 -m "not integration" --ignore=tests/test_example_dags.py --ignore=tests/test_example_dags_no_connections.py""" # we install using the following workaround to overcome installation conflicts, such as: # apache-airflow 2.3.0 and dbt-core [0.13.0 - 1.5.2] and jinja2>=3.0.0 because these package versions have conflicting dependencies -test-integration-setup = """pip uninstall -y dbt-core dbt-databricks dbt-sqlite dbt-postgres dbt-sqlite; \ +test-integration-setup = """pip uninstall dbt-postgres dbt-databricks dbt-vertica; \ rm -rf airflow.*; \ airflow db init; \ -pip install 'dbt-core' 'dbt-databricks' 'dbt-postgres' 'openlineage-airflow'""" -test-integration = """pytest -vv \ +pip install 'dbt-core' 'dbt-databricks' 'dbt-postgres' 'dbt-vertica' 'openlineage-airflow'""" +test-integration = """rm -rf dbt/jaffle_shop/dbt_packages; +pytest -vv \ --cov=cosmos \ --cov-report=term-missing \ --cov-report=xml \ --durations=0 \ -m integration \ --k 'not (sqlite or example_cosmos_sources or example_cosmos_python_models or example_virtualenv or cosmos_manifest_example)' -""" +-k 'not (sqlite or example_cosmos_sources or example_cosmos_python_models or example_virtualenv or cosmos_manifest_example)'""" test-integration-expensive = """pytest -vv \ --cov=cosmos \ --cov-report=term-missing \ diff --git a/tests/profiles/vertica/test_vertica_user_pass.py b/tests/profiles/vertica/test_vertica_user_pass.py new file mode 100644 index 000000000..953a3c553 --- /dev/null +++ b/tests/profiles/vertica/test_vertica_user_pass.py @@ -0,0 +1,191 @@ +"Tests for the vertica profile." + +from unittest.mock import patch + +import pytest +from airflow.models.connection import Connection + +from cosmos.profiles import get_automatic_profile_mapping +from cosmos.profiles.vertica.user_pass import ( + VerticaUserPasswordProfileMapping, +) + + +@pytest.fixture() +def mock_vertica_conn(): # type: ignore + """ + Sets the connection as an environment variable. + """ + conn = Connection( + conn_id="my_vertica_connection", + conn_type="vertica", + host="my_host", + login="my_user", + password="my_password", + port=5433, + schema="my_schema", + extra='{"database": "my_database"}', + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + + +@pytest.fixture() +def mock_vertica_conn_custom_port(): # type: ignore + """ + Sets the connection as an environment variable. + """ + conn = Connection( + conn_id="my_vertica_connection", + conn_type="vertica", + host="my_host", + login="my_user", + password="my_password", + port=7472, + schema="my_schema", + extra='{"database": "my_database"}', + ) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + yield conn + + +def test_connection_claiming() -> None: + """ + Tests that the vertica profile mapping claims the correct connection type. + """ + # should only claim when: + # - conn_type == vertica + # and the following exist: + # - host + # - user + # - password + # - port + # - database or database + # - schema + potential_values = { + "conn_type": "vertica", + "host": "my_host", + "login": "my_user", + "password": "my_password", + "schema": "my_schema", + "extra": '{"database": "my_database"}', + } + + # if we're missing any of the values, it shouldn't claim + for key in potential_values: + values = potential_values.copy() + del values[key] + conn = Connection(**values) # type: ignore + + print("testing with", values) + + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = VerticaUserPasswordProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # also test when there's no database + conn = Connection(**potential_values) # type: ignore + conn.extra = "" + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = VerticaUserPasswordProfileMapping(conn) + assert not profile_mapping.can_claim_connection() + + # if we have them all, it should claim + conn = Connection(**potential_values) # type: ignore + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + profile_mapping = VerticaUserPasswordProfileMapping(conn) + assert profile_mapping.can_claim_connection() + + +def test_profile_mapping_selected( + mock_vertica_conn: Connection, +) -> None: + """ + Tests that the correct profile mapping is selected. + """ + profile_mapping = get_automatic_profile_mapping( + mock_vertica_conn.conn_id, + {"schema": "my_schema"}, + ) + assert isinstance(profile_mapping, VerticaUserPasswordProfileMapping) + + +def test_mock_profile() -> None: + """ + Tests that the mock profile port value get set correctly. + """ + profile = VerticaUserPasswordProfileMapping("mock_conn_id") + assert profile.mock_profile.get("port") == 5433 + + +def test_profile_mapping_keeps_custom_port(mock_vertica_conn_custom_port: Connection) -> None: + profile = VerticaUserPasswordProfileMapping(mock_vertica_conn_custom_port.conn_id, {"schema": "my_schema"}) + assert profile.profile["port"] == 7472 + + +def test_profile_args( + mock_vertica_conn: Connection, +) -> None: + """ + Tests that the profile values get set correctly. + """ + profile_mapping = get_automatic_profile_mapping( + mock_vertica_conn.conn_id, + profile_args={"schema": "my_schema"}, + ) + assert profile_mapping.profile_args == { + "schema": "my_schema", + } + + assert profile_mapping.profile == { + "type": mock_vertica_conn.conn_type, + "host": mock_vertica_conn.host, + "user": mock_vertica_conn.login, + "password": "{{ env_var('COSMOS_CONN_VERTICA_PASSWORD') }}", + "port": mock_vertica_conn.port, + "schema": "my_schema", + "database": mock_vertica_conn.extra_dejson.get("database"), + } + + +def test_profile_args_overrides( + mock_vertica_conn: Connection, +) -> None: + """ + Tests that you can override the profile values. + """ + profile_mapping = get_automatic_profile_mapping( + mock_vertica_conn.conn_id, + profile_args={"schema": "my_schema", "database": "my_db_override"}, + ) + assert profile_mapping.profile_args == { + "schema": "my_schema", + "database": "my_db_override", + } + + assert profile_mapping.profile == { + "type": mock_vertica_conn.conn_type, + "host": mock_vertica_conn.host, + "user": mock_vertica_conn.login, + "password": "{{ env_var('COSMOS_CONN_VERTICA_PASSWORD') }}", + "port": mock_vertica_conn.port, + "database": "my_db_override", + "schema": "my_schema", + } + + +def test_profile_env_vars( + mock_vertica_conn: Connection, +) -> None: + """ + Tests that the environment variables get set correctly. + """ + profile_mapping = get_automatic_profile_mapping( + mock_vertica_conn.conn_id, + profile_args={"schema": "my_schema"}, + ) + assert profile_mapping.env_vars == { + "COSMOS_CONN_VERTICA_PASSWORD": mock_vertica_conn.password, + }