Skip to content

Commit

Permalink
Merge branch 'main' into feature/cache-virtualenv
Browse files Browse the repository at this point in the history
  • Loading branch information
tatiana authored Oct 18, 2023
2 parents ad024df + c82e6bc commit c31b3e9
Show file tree
Hide file tree
Showing 5 changed files with 284 additions and 5 deletions.
3 changes: 3 additions & 0 deletions cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .trino.certificate import TrinoCertificateProfileMapping
from .trino.jwt import TrinoJWTProfileMapping
from .trino.ldap import TrinoLDAPProfileMapping
from .vertica.user_pass import VerticaUserPasswordProfileMapping

profile_mappings: list[Type[BaseProfileMapping]] = [
AthenaAccessKeyProfileMapping,
Expand All @@ -36,6 +37,7 @@
TrinoLDAPProfileMapping,
TrinoCertificateProfileMapping,
TrinoJWTProfileMapping,
VerticaUserPasswordProfileMapping,
]


Expand Down Expand Up @@ -72,4 +74,5 @@ def get_automatic_profile_mapping(
"TrinoLDAPProfileMapping",
"TrinoCertificateProfileMapping",
"TrinoJWTProfileMapping",
"VerticaUserPasswordProfileMapping",
]
5 changes: 5 additions & 0 deletions cosmos/profiles/vertica/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"Vertica Airflow connection -> dbt profile mappings"

from .user_pass import VerticaUserPasswordProfileMapping

__all__ = ["VerticaUserPasswordProfileMapping"]
76 changes: 76 additions & 0 deletions cosmos/profiles/vertica/user_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"Maps Airflow Vertica connections using user + password authentication to dbt profiles."
from __future__ import annotations

from typing import Any

from ..base import BaseProfileMapping


class VerticaUserPasswordProfileMapping(BaseProfileMapping):
"""
Maps Airflow Vertica connections using user + password authentication to dbt profiles.
https://docs.getdbt.com/reference/warehouse-setups/vertica-setup
https://airflow.apache.org/docs/apache-airflow-providers-vertica/stable/connections/vertica.html
"""

airflow_connection_type: str = "vertica"
dbt_profile_type: str = "vertica"

required_fields = [
"host",
"user",
"password",
"database",
"schema",
]
secret_fields = [
"password",
]
airflow_param_mapping = {
"host": "host",
"user": "login",
"password": "password",
"port": "port",
"schema": "schema",
"database": "extra.database",
"autocommit": "extra.autocommit",
"backup_server_node": "extra.backup_server_node",
"binary_transfer": "extra.binary_transfer",
"connection_load_balance": "extra.connection_load_balance",
"connection_timeout": "extra.connection_timeout",
"disable_copy_local": "extra.disable_copy_local",
"kerberos_host_name": "extra.kerberos_host_name",
"kerberos_service_name": "extra.kerberos_service_name",
"log_level": "extra.log_level",
"log_path": "extra.log_path",
"oauth_access_token": "extra.oauth_access_token",
"request_complex_types": "extra.request_complex_types",
"session_label": "extra.session_label",
"ssl": "extra.ssl",
"unicode_error": "extra.unicode_error",
"use_prepared_statements": "extra.use_prepared_statements",
"workload": "extra.workload",
}

@property
def profile(self) -> dict[str, Any | None]:
"Gets profile. The password is stored in an environment variable."
profile = {
"port": 5433,
**self.mapped_params,
**self.profile_args,
# password should always get set as env var
"password": self.get_env_var_format("password"),
}

return self.filter_null(profile)

@property
def mock_profile(self) -> dict[str, Any | None]:
"Gets mock profile. Defaults port to 5433."
parent_mock = super().mock_profile

return {
"port": 5433,
**parent_mock,
}
14 changes: 9 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ dbt-all = [
"dbt-redshift",
"dbt-snowflake",
"dbt-spark",
"dbt-vertica",
]
dbt-athena = [
"dbt-athena-community",
Expand All @@ -80,6 +81,9 @@ dbt-snowflake = [
dbt-spark = [
"dbt-spark",
]
dbt-vertica = [
"dbt-vertica<=1.5.4",
]
openlineage = [
"openlineage-integration-common",
"openlineage-airflow",
Expand Down Expand Up @@ -165,18 +169,18 @@ test = 'pytest -vv --durations=0 . -m "not integration" --ignore=tests/test_exam
test-cov = """pytest -vv --cov=cosmos --cov-report=term-missing --cov-report=xml --durations=0 -m "not integration" --ignore=tests/test_example_dags.py --ignore=tests/test_example_dags_no_connections.py"""
# we install using the following workaround to overcome installation conflicts, such as:
# apache-airflow 2.3.0 and dbt-core [0.13.0 - 1.5.2] and jinja2>=3.0.0 because these package versions have conflicting dependencies
test-integration-setup = """pip uninstall -y dbt-core dbt-databricks dbt-sqlite dbt-postgres dbt-sqlite; \
test-integration-setup = """pip uninstall dbt-postgres dbt-databricks dbt-vertica; \
rm -rf airflow.*; \
airflow db init; \
pip install 'dbt-core' 'dbt-databricks' 'dbt-postgres' 'openlineage-airflow'"""
test-integration = """pytest -vv \
pip install 'dbt-core' 'dbt-databricks' 'dbt-postgres' 'dbt-vertica' 'openlineage-airflow'"""
test-integration = """rm -rf dbt/jaffle_shop/dbt_packages;
pytest -vv \
--cov=cosmos \
--cov-report=term-missing \
--cov-report=xml \
--durations=0 \
-m integration \
-k 'not (sqlite or example_cosmos_sources or example_cosmos_python_models or example_virtualenv or cosmos_manifest_example)'
"""
-k 'not (sqlite or example_cosmos_sources or example_cosmos_python_models or example_virtualenv or cosmos_manifest_example)'"""
test-integration-expensive = """pytest -vv \
--cov=cosmos \
--cov-report=term-missing \
Expand Down
191 changes: 191 additions & 0 deletions tests/profiles/vertica/test_vertica_user_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"Tests for the vertica profile."

from unittest.mock import patch

import pytest
from airflow.models.connection import Connection

from cosmos.profiles import get_automatic_profile_mapping
from cosmos.profiles.vertica.user_pass import (
VerticaUserPasswordProfileMapping,
)


@pytest.fixture()
def mock_vertica_conn(): # type: ignore
"""
Sets the connection as an environment variable.
"""
conn = Connection(
conn_id="my_vertica_connection",
conn_type="vertica",
host="my_host",
login="my_user",
password="my_password",
port=5433,
schema="my_schema",
extra='{"database": "my_database"}',
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn


@pytest.fixture()
def mock_vertica_conn_custom_port(): # type: ignore
"""
Sets the connection as an environment variable.
"""
conn = Connection(
conn_id="my_vertica_connection",
conn_type="vertica",
host="my_host",
login="my_user",
password="my_password",
port=7472,
schema="my_schema",
extra='{"database": "my_database"}',
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn


def test_connection_claiming() -> None:
"""
Tests that the vertica profile mapping claims the correct connection type.
"""
# should only claim when:
# - conn_type == vertica
# and the following exist:
# - host
# - user
# - password
# - port
# - database or database
# - schema
potential_values = {
"conn_type": "vertica",
"host": "my_host",
"login": "my_user",
"password": "my_password",
"schema": "my_schema",
"extra": '{"database": "my_database"}',
}

# if we're missing any of the values, it shouldn't claim
for key in potential_values:
values = potential_values.copy()
del values[key]
conn = Connection(**values) # type: ignore

print("testing with", values)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = VerticaUserPasswordProfileMapping(conn)
assert not profile_mapping.can_claim_connection()

# also test when there's no database
conn = Connection(**potential_values) # type: ignore
conn.extra = ""
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = VerticaUserPasswordProfileMapping(conn)
assert not profile_mapping.can_claim_connection()

# if we have them all, it should claim
conn = Connection(**potential_values) # type: ignore
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = VerticaUserPasswordProfileMapping(conn)
assert profile_mapping.can_claim_connection()


def test_profile_mapping_selected(
mock_vertica_conn: Connection,
) -> None:
"""
Tests that the correct profile mapping is selected.
"""
profile_mapping = get_automatic_profile_mapping(
mock_vertica_conn.conn_id,
{"schema": "my_schema"},
)
assert isinstance(profile_mapping, VerticaUserPasswordProfileMapping)


def test_mock_profile() -> None:
"""
Tests that the mock profile port value get set correctly.
"""
profile = VerticaUserPasswordProfileMapping("mock_conn_id")
assert profile.mock_profile.get("port") == 5433


def test_profile_mapping_keeps_custom_port(mock_vertica_conn_custom_port: Connection) -> None:
profile = VerticaUserPasswordProfileMapping(mock_vertica_conn_custom_port.conn_id, {"schema": "my_schema"})
assert profile.profile["port"] == 7472


def test_profile_args(
mock_vertica_conn: Connection,
) -> None:
"""
Tests that the profile values get set correctly.
"""
profile_mapping = get_automatic_profile_mapping(
mock_vertica_conn.conn_id,
profile_args={"schema": "my_schema"},
)
assert profile_mapping.profile_args == {
"schema": "my_schema",
}

assert profile_mapping.profile == {
"type": mock_vertica_conn.conn_type,
"host": mock_vertica_conn.host,
"user": mock_vertica_conn.login,
"password": "{{ env_var('COSMOS_CONN_VERTICA_PASSWORD') }}",
"port": mock_vertica_conn.port,
"schema": "my_schema",
"database": mock_vertica_conn.extra_dejson.get("database"),
}


def test_profile_args_overrides(
mock_vertica_conn: Connection,
) -> None:
"""
Tests that you can override the profile values.
"""
profile_mapping = get_automatic_profile_mapping(
mock_vertica_conn.conn_id,
profile_args={"schema": "my_schema", "database": "my_db_override"},
)
assert profile_mapping.profile_args == {
"schema": "my_schema",
"database": "my_db_override",
}

assert profile_mapping.profile == {
"type": mock_vertica_conn.conn_type,
"host": mock_vertica_conn.host,
"user": mock_vertica_conn.login,
"password": "{{ env_var('COSMOS_CONN_VERTICA_PASSWORD') }}",
"port": mock_vertica_conn.port,
"database": "my_db_override",
"schema": "my_schema",
}


def test_profile_env_vars(
mock_vertica_conn: Connection,
) -> None:
"""
Tests that the environment variables get set correctly.
"""
profile_mapping = get_automatic_profile_mapping(
mock_vertica_conn.conn_id,
profile_args={"schema": "my_schema"},
)
assert profile_mapping.env_vars == {
"COSMOS_CONN_VERTICA_PASSWORD": mock_vertica_conn.password,
}

0 comments on commit c31b3e9

Please sign in to comment.