-
Notifications
You must be signed in to change notification settings - Fork 177
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
435a699
commit c2315d3
Showing
5 changed files
with
227 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"Athena Airflow connection -> dbt profile mappings" | ||
|
||
from .access_key import AthenaAccessKeyProfileMapping | ||
|
||
__all__ = ["AthenaAccessKeyProfileMapping"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
"Maps Airflow AWS connections to a dbt Athena profile using an access key id and secret access key." | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
from ..base import BaseProfileMapping | ||
|
||
|
||
class AthenaAccessKeyProfileMapping(BaseProfileMapping): | ||
""" | ||
Maps Airflow AWS connections to a dbt Athena profile using an access key id and secret access key. | ||
https://docs.getdbt.com/docs/core/connect-data-platform/athena-setup | ||
https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/connections/aws.html | ||
""" | ||
|
||
airflow_connection_type: str = "aws" | ||
dbt_profile_type: str = "athena" | ||
|
||
required_fields = [ | ||
"aws_access_key_id", | ||
"aws_secret_access_key", | ||
"database", | ||
"region_name", | ||
"s3_staging_dir", | ||
"schema", | ||
] | ||
secret_fields = [ | ||
"aws_secret_access_key", | ||
] | ||
airflow_param_mapping = { | ||
"aws_access_key_id": "login", | ||
"aws_secret_access_key": "password", | ||
"aws_profile_name": "extra.aws_profile_name", | ||
"database": "extra.database", | ||
"debug_query_state": "extra.debug_query_state", | ||
"lf_tags_database": "extra.lf_tags_database", | ||
"num_retries": "extra.num_retries", | ||
"poll_interval": "extra.poll_interval", | ||
"region_name": "extra.region_name", | ||
"s3_data_dir": "extra.s3_data_dir", | ||
"s3_data_naming": "extra.s3_data_naming", | ||
"s3_staging_dir": "extra.s3_staging_dir", | ||
"schema": "extra.schema", | ||
"seed_s3_upload_args": "extra.seed_s3_upload_args", | ||
"work_group": "extra.work_group", | ||
} | ||
|
||
@property | ||
def profile(self) -> dict[str, Any | None]: | ||
"Gets profile. The password is stored in an environment variable." | ||
profile = { | ||
**self.mapped_params, | ||
**self.profile_args, | ||
# aws_secret_access_key should always get set as env var | ||
"aws_secret_access_key": self.get_env_var_format("aws_secret_access_key"), | ||
} | ||
return self.filter_null(profile) | ||
|
||
@property | ||
def mock_profile(self) -> dict[str, Any | None]: | ||
"Gets mock profile." | ||
parent_mock = super().mock_profile | ||
|
||
return { | ||
**parent_mock, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
"Tests for the Athena profile." | ||
|
||
import json | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
from airflow.models.connection import Connection | ||
|
||
from cosmos.profiles import get_automatic_profile_mapping | ||
from cosmos.profiles.athena.access_key import AthenaAccessKeyProfileMapping | ||
|
||
|
||
@pytest.fixture() | ||
def mock_athena_conn(): # type: ignore | ||
""" | ||
Sets the connection as an environment variable. | ||
""" | ||
conn = Connection( | ||
conn_id="my_athena_connection", | ||
conn_type="aws", | ||
login="my_aws_access_key_id", | ||
password="my_aws_secret_key", | ||
extra=json.dumps( | ||
{ | ||
"database": "my_database", | ||
"region_name": "my_region", | ||
"s3_staging_dir": "s3://my_bucket/dbt/", | ||
"schema": "my_schema", | ||
} | ||
) | ||
) | ||
|
||
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): | ||
yield conn | ||
|
||
def test_athena_connection_claiming() -> None: | ||
""" | ||
Tests that the Athena profile mapping claims the correct connection type. | ||
""" | ||
# should only claim when: | ||
# - conn_type == aws | ||
# and the following exist: | ||
# - login | ||
# - password | ||
# - database | ||
# - region_name | ||
# - s3_staging_dir | ||
# - schema | ||
potential_values = { | ||
"conn_type": "aws", | ||
"login": "my_aws_access_key_id", | ||
"password": "my_aws_secret_key", | ||
"extra": json.dumps( | ||
{ | ||
"database": "my_database", | ||
"region_name": "my_region", | ||
"s3_staging_dir": "s3://my_bucket/dbt/", | ||
"schema": "my_schema", | ||
} | ||
) | ||
} | ||
|
||
# if we're missing any of the values, it shouldn't claim | ||
for key in potential_values: | ||
values = potential_values.copy() | ||
del values[key] | ||
conn = Connection(**values) # type: ignore | ||
|
||
print("testing with", values) | ||
|
||
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): | ||
# should raise an InvalidMappingException | ||
profile_mapping = AthenaAccessKeyProfileMapping(conn, {}) | ||
assert not profile_mapping.can_claim_connection() | ||
|
||
# if we have them all, it should claim | ||
conn = Connection(**potential_values) # type: ignore | ||
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): | ||
profile_mapping = AthenaAccessKeyProfileMapping(conn, {}) | ||
assert profile_mapping.can_claim_connection() | ||
|
||
def test_athena_profile_mapping_selected( | ||
mock_athena_conn: Connection, | ||
) -> None: | ||
""" | ||
Tests that the correct profile mapping is selected for Athena. | ||
""" | ||
profile_mapping = get_automatic_profile_mapping( | ||
mock_athena_conn.conn_id, | ||
) | ||
assert isinstance(profile_mapping, AthenaAccessKeyProfileMapping) | ||
|
||
def test_athena_profile_args( | ||
mock_athena_conn: Connection, | ||
) -> None: | ||
""" | ||
Tests that the profile values get set correctly for Athena. | ||
""" | ||
profile_mapping = get_automatic_profile_mapping( | ||
mock_athena_conn.conn_id, | ||
) | ||
|
||
assert profile_mapping.profile == { | ||
"type": "athena", | ||
"aws_access_key_id": mock_athena_conn.login, | ||
"aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}", | ||
"database": mock_athena_conn.extra_dejson.get("database"), | ||
"region_name": mock_athena_conn.extra_dejson.get("region_name"), | ||
"s3_staging_dir": mock_athena_conn.extra_dejson.get("s3_staging_dir"), | ||
"schema": mock_athena_conn.extra_dejson.get("schema"), | ||
} | ||
|
||
def test_athena_profile_args_overrides( | ||
mock_athena_conn: Connection, | ||
) -> None: | ||
""" | ||
Tests that you can override the profile values for Athena. | ||
""" | ||
profile_mapping = get_automatic_profile_mapping( | ||
mock_athena_conn.conn_id, | ||
profile_args={"schema": "my_custom_schema", "database": "my_custom_db"}, | ||
) | ||
assert profile_mapping.profile_args == { | ||
"schema": "my_custom_schema", | ||
"database": "my_custom_db", | ||
} | ||
|
||
assert profile_mapping.profile == { | ||
"type": "athena", | ||
"aws_access_key_id": mock_athena_conn.login, | ||
"aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}", | ||
"database": "my_custom_db", | ||
"region_name": mock_athena_conn.extra_dejson.get("region_name"), | ||
"s3_staging_dir": mock_athena_conn.extra_dejson.get("s3_staging_dir"), | ||
"schema": "my_custom_schema", | ||
} | ||
|
||
def test_athena_profile_env_vars( | ||
mock_athena_conn: Connection, | ||
) -> None: | ||
""" | ||
Tests that the environment variables get set correctly for Athena. | ||
""" | ||
profile_mapping = get_automatic_profile_mapping( | ||
mock_athena_conn.conn_id, | ||
) | ||
assert profile_mapping.env_vars == { | ||
"COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY": mock_athena_conn.password, | ||
} |